[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/135636 >From aa8a667f206874af3b26811ec04d58be12ad43de Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Tue, 8 Apr 2025 14:43:54 + Subject: [PATCH 1/3] [MLIR][ArmSVE] Add initial lowering of `vector.contract` to SVE `*MMLA` instructions --- mlir/include/mlir/Conversion/Passes.td| 4 + .../Dialect/ArmSVE/Transforms/Transforms.h| 3 + .../Conversion/VectorToLLVM/CMakeLists.txt| 1 + .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 7 + .../LowerContractionToSMMLAPattern.cpp| 5 +- .../Dialect/ArmSVE/Transforms/CMakeLists.txt | 1 + .../LowerContractionToSVEI8MMPattern.cpp | 304 ++ .../Vector/CPU/ArmSVE/vector-smmla.mlir | 94 ++ .../Vector/CPU/ArmSVE/vector-summla.mlir | 85 + .../Vector/CPU/ArmSVE/vector-ummla.mlir | 94 ++ .../Vector/CPU/ArmSVE/vector-usmmla.mlir | 95 ++ .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir | 117 +++ .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir | 159 + .../CPU/ArmSVE/contraction-summla-4x8x4.mlir | 118 +++ .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir | 119 +++ .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir | 117 +++ 16 files changed, 1322 insertions(+), 1 deletion(-) create mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 10557658d5d7d..b496ee0114910 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1431,6 +1431,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " "dialect.">, +Option<"armI8MM", "enable-arm-i8mm", + "bool", /*default=*/"false", + "Enables the use of Arm FEAT_I8MM instructions while lowering " + "the vector dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h index 8665c8224cc45..232e2be29e574 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h @@ -20,6 +20,9 @@ class RewritePatternSet; void populateArmSVELegalizeForLLVMExportPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns); +void populateLowerContractionToSVEI8MMPatternPatterns( +RewritePatternSet &patterns); + /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM /// intrinsics. 
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target); diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt index 330474a718e30..8e2620029c354 100644 --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass MLIRVectorToLLVM MLIRArmNeonDialect + MLIRArmNeonTransforms MLIRArmSVEDialect MLIRArmSVETransforms MLIRAMXDialect diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 0ee6dce9ee94b..293e01a5bf4d4 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmNeon/Transforms.h" #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPattern
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/135636 >From aa8a667f206874af3b26811ec04d58be12ad43de Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Tue, 8 Apr 2025 14:43:54 + Subject: [PATCH 1/3] [MLIR][ArmSVE] Add initial lowering of `vector.contract` to SVE `*MMLA` instructions --- mlir/include/mlir/Conversion/Passes.td| 4 + .../Dialect/ArmSVE/Transforms/Transforms.h| 3 + .../Conversion/VectorToLLVM/CMakeLists.txt| 1 + .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 7 + .../LowerContractionToSMMLAPattern.cpp| 5 +- .../Dialect/ArmSVE/Transforms/CMakeLists.txt | 1 + .../LowerContractionToSVEI8MMPattern.cpp | 304 ++ .../Vector/CPU/ArmSVE/vector-smmla.mlir | 94 ++ .../Vector/CPU/ArmSVE/vector-summla.mlir | 85 + .../Vector/CPU/ArmSVE/vector-ummla.mlir | 94 ++ .../Vector/CPU/ArmSVE/vector-usmmla.mlir | 95 ++ .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir | 117 +++ .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir | 159 + .../CPU/ArmSVE/contraction-summla-4x8x4.mlir | 118 +++ .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir | 119 +++ .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir | 117 +++ 16 files changed, 1322 insertions(+), 1 deletion(-) create mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 10557658d5d7d..b496ee0114910 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1431,6 +1431,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " "dialect.">, +Option<"armI8MM", "enable-arm-i8mm", + "bool", /*default=*/"false", + "Enables the use of Arm FEAT_I8MM instructions while lowering " + "the vector dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h index 8665c8224cc45..232e2be29e574 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h @@ -20,6 +20,9 @@ class RewritePatternSet; void populateArmSVELegalizeForLLVMExportPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns); +void populateLowerContractionToSVEI8MMPatternPatterns( +RewritePatternSet &patterns); + /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM /// intrinsics. 
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target); diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt index 330474a718e30..8e2620029c354 100644 --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass MLIRVectorToLLVM MLIRArmNeonDialect + MLIRArmNeonTransforms MLIRArmSVEDialect MLIRArmSVETransforms MLIRAMXDialect diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 0ee6dce9ee94b..293e01a5bf4d4 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmNeon/Transforms.h" #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPattern
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) +return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) +return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) +return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. 
+enum class MMLA { + Signed, // smmla + Unsigned,// ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multiply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: +// The accumulator comes transposed and the result will be transposed +// later, so all we have to do here is swap the operands. +return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern +: public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, +PatternRewriter &rewriter) const override { + +Location loc = op.getLoc(); +mlir::VectorType lhsType = op.getLhsType(); +mlir::VectorType rhsType = op.getRhsType(); + +// For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we +// eventually expect from MMT4D. M and N dimensions must be even and at +// least 2. +if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() || +rhsType.getRank() != 2) + return failure(); + +if (lhsType.isScalable() || !rhsType.isScalable()) + return failure(); + +// M, N, and K are the conventional names for matrix dimensions in the +// context of matrix multiplication. +auto M = lhsType.getDimSize(0); +auto N = rhsType.getDimSize(0); +auto K = rhsType.getDimSize(1); + +if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 || +N % 2 != 0 || !rhsType.getScalableDims()[0]) + return failure(); + +// Check permutation maps. 
For now only accept +// lhs: (d0, d1, d2) -> (d0, d2) +// rhs: (d0, d1, d2) -> (d1, d2) +// acc: (d0, d1, d2) -> (d0, d1) +// Note: RHS is transposed. +if (op.getIndexingMapsArray()[0] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u}, + op.getContext()) || +op.getIndexingMapsArray()[1] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u}, + op.getContext()) || +op.getIndexingMapsArray()[2] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u}, + op.getContext())) + return failure(); + +// Check iterator types for matrix multiplication. +auto itTypes = op.getIteratorTypesArray(); +if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel || +itTypes[1] != vector::IteratorType
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/135636 >From f397467bc167d94a28a919a45c009a8f08b6351b Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Tue, 8 Apr 2025 14:43:54 + Subject: [PATCH 1/2] [MLIR][ArmSVE] Add initial lowering of `vector.contract` to SVE `*MMLA` instructions --- mlir/include/mlir/Conversion/Passes.td| 4 + .../Dialect/ArmSVE/Transforms/Transforms.h| 3 + .../Conversion/VectorToLLVM/CMakeLists.txt| 1 + .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 7 + .../LowerContractionToSMMLAPattern.cpp| 5 +- .../Dialect/ArmSVE/Transforms/CMakeLists.txt | 1 + .../LowerContractionToSVEI8MMPattern.cpp | 304 ++ .../Vector/CPU/ArmSVE/vector-smmla.mlir | 94 ++ .../Vector/CPU/ArmSVE/vector-summla.mlir | 85 + .../Vector/CPU/ArmSVE/vector-ummla.mlir | 94 ++ .../Vector/CPU/ArmSVE/vector-usmmla.mlir | 95 ++ .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir | 117 +++ .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir | 159 + .../CPU/ArmSVE/contraction-summla-4x8x4.mlir | 118 +++ .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir | 119 +++ .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir | 117 +++ 16 files changed, 1322 insertions(+), 1 deletion(-) create mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 10557658d5d7d..b496ee0114910 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1431,6 +1431,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " "dialect.">, +Option<"armI8MM", "enable-arm-i8mm", + "bool", /*default=*/"false", + "Enables the use of Arm FEAT_I8MM instructions while lowering " + "the vector dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h index 8665c8224cc45..232e2be29e574 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h @@ -20,6 +20,9 @@ class RewritePatternSet; void populateArmSVELegalizeForLLVMExportPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns); +void populateLowerContractionToSVEI8MMPatternPatterns( +RewritePatternSet &patterns); + /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM /// intrinsics. 
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target); diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt index 330474a718e30..8e2620029c354 100644 --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass MLIRVectorToLLVM MLIRArmNeonDialect + MLIRArmNeonTransforms MLIRArmSVEDialect MLIRArmSVETransforms MLIRAMXDialect diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 0ee6dce9ee94b..293e01a5bf4d4 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmNeon/Transforms.h" #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPattern
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/135636 >From f397467bc167d94a28a919a45c009a8f08b6351b Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Tue, 8 Apr 2025 14:43:54 + Subject: [PATCH 1/2] [MLIR][ArmSVE] Add initial lowering of `vector.contract` to SVE `*MMLA` instructions --- mlir/include/mlir/Conversion/Passes.td| 4 + .../Dialect/ArmSVE/Transforms/Transforms.h| 3 + .../Conversion/VectorToLLVM/CMakeLists.txt| 1 + .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 7 + .../LowerContractionToSMMLAPattern.cpp| 5 +- .../Dialect/ArmSVE/Transforms/CMakeLists.txt | 1 + .../LowerContractionToSVEI8MMPattern.cpp | 304 ++ .../Vector/CPU/ArmSVE/vector-smmla.mlir | 94 ++ .../Vector/CPU/ArmSVE/vector-summla.mlir | 85 + .../Vector/CPU/ArmSVE/vector-ummla.mlir | 94 ++ .../Vector/CPU/ArmSVE/vector-usmmla.mlir | 95 ++ .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir | 117 +++ .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir | 159 + .../CPU/ArmSVE/contraction-summla-4x8x4.mlir | 118 +++ .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir | 119 +++ .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir | 117 +++ 16 files changed, 1322 insertions(+), 1 deletion(-) create mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 10557658d5d7d..b496ee0114910 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1431,6 +1431,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " "dialect.">, +Option<"armI8MM", "enable-arm-i8mm", + "bool", /*default=*/"false", + "Enables the use of Arm FEAT_I8MM instructions while lowering " + "the vector dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h index 8665c8224cc45..232e2be29e574 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h @@ -20,6 +20,9 @@ class RewritePatternSet; void populateArmSVELegalizeForLLVMExportPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns); +void populateLowerContractionToSVEI8MMPatternPatterns( +RewritePatternSet &patterns); + /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM /// intrinsics. 
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target); diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt index 330474a718e30..8e2620029c354 100644 --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass MLIRVectorToLLVM MLIRArmNeonDialect + MLIRArmNeonTransforms MLIRArmSVEDialect MLIRArmSVETransforms MLIRAMXDialect diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 0ee6dce9ee94b..293e01a5bf4d4 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmNeon/Transforms.h" #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPattern
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) +return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) +return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) +return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. 
+enum class MMLA { + Signed, // smmla + Unsigned,// ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multiply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: +// The accumulator comes transposed and the result will be transposed +// later, so all we have to do here is swap the operands. +return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern +: public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, +PatternRewriter &rewriter) const override { + +Location loc = op.getLoc(); +mlir::VectorType lhsType = op.getLhsType(); +mlir::VectorType rhsType = op.getRhsType(); + +// For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we +// eventually expect from MMT4D. M and N dimensions must be even and at momchil-velikov wrote: Done (in the top-level description). https://github.com/llvm/llvm-project/pull/135636 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) +return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) +return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) +return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. 
+enum class MMLA { + Signed, // smmla + Unsigned,// ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multiply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: +// The accumulator comes transposed and the result will be transposed +// later, so all we have to do here is swap the operands. +return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern momchil-velikov wrote: Done. https://github.com/llvm/llvm-project/pull/135636 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) +return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) +return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) +return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. 
+enum class MMLA { + Signed, // smmla + Unsigned,// ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: +// The accumulator comes transposed and the result will be transposed +// later, so all we have to do here is swap the operands. +return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern +: public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, +PatternRewriter &rewriter) const override { + +Location loc = op.getLoc(); +mlir::VectorType lhsType = op.getLhsType(); +mlir::VectorType rhsType = op.getRhsType(); + +// For now handle LHS and RHS<8x[N]> - these are the types we +// eventually expect from MMT4D. M and N dimensions must be even and at +// least 2. +if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() || +rhsType.getRank() != 2) + return failure(); + +if (lhsType.isScalable() || !rhsType.isScalable()) + return failure(); + +// M, N, and K are the conventional names for matrix dimensions in the +// context of matrix multiplication. +auto M = lhsType.getDimSize(0); +auto N = rhsType.getDimSize(0); +auto K = rhsType.getDimSize(1); + +if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 || +N % 2 != 0 || !rhsType.getScalableDims()[0]) + return failure(); + +// Check permutation maps. 
For now only accept +// lhs: (d0, d1, d2) -> (d0, d2) +// rhs: (d0, d1, d2) -> (d1, d2) +// acc: (d0, d1, d2) -> (d0, d1) +// Note: RHS is transposed. +if (op.getIndexingMapsArray()[0] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u}, + op.getContext()) || +op.getIndexingMapsArray()[1] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u}, + op.getContext()) || +op.getIndexingMapsArray()[2] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u}, + op.getContext())) + return failure(); + +// Check iterator types for matrix multiplication. +auto itTypes = op.getIteratorTypesArray(); +if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel || +itTypes[1] != vector::IteratorType
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) +return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) +return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) +return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. 
+enum class MMLA { + Signed, // smmla + Unsigned,// ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: +// The accumulator comes transposed and the result will be transposed +// later, so all we have to do here is swap the operands. +return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern +: public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, +PatternRewriter &rewriter) const override { + +Location loc = op.getLoc(); +mlir::VectorType lhsType = op.getLhsType(); +mlir::VectorType rhsType = op.getRhsType(); + +// For now handle LHS and RHS<8x[N]> - these are the types we +// eventually expect from MMT4D. M and N dimensions must be even and at +// least 2. +if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() || +rhsType.getRank() != 2) + return failure(); + +if (lhsType.isScalable() || !rhsType.isScalable()) + return failure(); + +// M, N, and K are the conventional names for matrix dimensions in the +// context of matrix multiplication. +auto M = lhsType.getDimSize(0); +auto N = rhsType.getDimSize(0); +auto K = rhsType.getDimSize(1); + +if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 || +N % 2 != 0 || !rhsType.getScalableDims()[0]) + return failure(); + +// Check permutation maps. 
For now only accept +// lhs: (d0, d1, d2) -> (d0, d2) +// rhs: (d0, d1, d2) -> (d1, d2) +// acc: (d0, d1, d2) -> (d0, d1) +// Note: RHS is transposed. +if (op.getIndexingMapsArray()[0] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u}, + op.getContext()) || +op.getIndexingMapsArray()[1] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u}, + op.getContext()) || +op.getIndexingMapsArray()[2] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u}, + op.getContext())) + return failure(); + +// Check iterator types for matrix multiplication. +auto itTypes = op.getIteratorTypesArray(); +if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel || +itTypes[1] != vector::IteratorType
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) +return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) +return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) +return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. 
+enum class MMLA { + Signed, // smmla + Unsigned,// ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: +// The accumulator comes transposed and the result will be transposed +// later, so all we have to do here is swap the operands. +return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern +: public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, +PatternRewriter &rewriter) const override { + +Location loc = op.getLoc(); +mlir::VectorType lhsType = op.getLhsType(); +mlir::VectorType rhsType = op.getRhsType(); + +// For now handle LHS and RHS<8x[N]> - these are the types we +// eventually expect from MMT4D. M and N dimensions must be even and at banach-space wrote: Perhaps just expand this comment a bit (e.g. by noting that MMT4D is the main use-case ATM)? https://github.com/llvm/llvm-project/pull/135636 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) +return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) +return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) +return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. 
+enum class MMLA { + Signed, // smmla + Unsigned,// ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: +// The accumulator comes transposed and the result will be transposed +// later, so all we have to do here is swap the operands. +return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern +: public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, +PatternRewriter &rewriter) const override { + +Location loc = op.getLoc(); +mlir::VectorType lhsType = op.getLhsType(); +mlir::VectorType rhsType = op.getRhsType(); + +// For now handle LHS and RHS<8x[N]> - these are the types we +// eventually expect from MMT4D. M and N dimensions must be even and at momchil-velikov wrote: There's no dependency on MMT4D - this comment is merely a rationale why we have chosen these particular operand shapes. https://github.com/llvm/llvm-project/pull/135636 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. momchil-velikov wrote: The `vector.contract` implicitly sign-extends its operands, so it does not need to be accompanied by explicit extend operations. I'll add code to handle this case too. https://github.com/llvm/llvm-project/pull/135636 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) +return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) +return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) +return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. 
+enum class MMLA { + Signed, // smmla + Unsigned,// ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: +// The accumulator comes transposed and the result will be transposed +// later, so all we have to do here is swap the operands. +return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern +: public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, +PatternRewriter &rewriter) const override { + +Location loc = op.getLoc(); +mlir::VectorType lhsType = op.getLhsType(); +mlir::VectorType rhsType = op.getRhsType(); + +// For now handle LHS and RHS<8x[N]> - these are the types we +// eventually expect from MMT4D. M and N dimensions must be even and at +// least 2. +if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() || momchil-velikov wrote: Done https://github.com/llvm/llvm-project/pull/135636 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> momchil-velikov wrote: Because it does not need to be checked at runtime. But fair enough, it should be a `static_assert`, as there's no ambiguity to resolve, and such use is quite common across the code base, e.g. ``` static_assert(llvm::is_one_of::value, "applies to only pack or unpack operations"); ``` https://github.com/llvm/llvm-project/pull/135636 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
https://github.com/dcaballe commented: Really excited to see this! I'll take a look in the next iteration. Thanks! https://github.com/llvm/llvm-project/pull/135636 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) +return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) +return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) +return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. 
+enum class MMLA { + Signed, // smmla + Unsigned,// ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: +// The accumulator comes transposed and the result will be transposed +// later, so all we have to do here is swap the operands. +return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern +: public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, +PatternRewriter &rewriter) const override { + +Location loc = op.getLoc(); +mlir::VectorType lhsType = op.getLhsType(); +mlir::VectorType rhsType = op.getRhsType(); + +// For now handle LHS and RHS<8x[N]> - these are the types we +// eventually expect from MMT4D. M and N dimensions must be even and at +// least 2. +if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() || +rhsType.getRank() != 2) + return failure(); + +if (lhsType.isScalable() || !rhsType.isScalable()) + return failure(); + +// M, N, and K are the conventional names for matrix dimensions in the +// context of matrix multiplication. +auto M = lhsType.getDimSize(0); +auto N = rhsType.getDimSize(0); +auto K = rhsType.getDimSize(1); + +if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 || +N % 2 != 0 || !rhsType.getScalableDims()[0]) + return failure(); + +// Check permutation maps. 
For now only accept +// lhs: (d0, d1, d2) -> (d0, d2) +// rhs: (d0, d1, d2) -> (d1, d2) +// acc: (d0, d1, d2) -> (d0, d1) +// Note: RHS is transposed. +if (op.getIndexingMapsArray()[0] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u}, + op.getContext()) || +op.getIndexingMapsArray()[1] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u}, + op.getContext()) || +op.getIndexingMapsArray()[2] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u}, + op.getContext())) + return failure(); + +// Check iterator types for matrix multiplication. +auto itTypes = op.getIteratorTypesArray(); +if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel || +itTypes[1] != vector::IteratorType
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
https://github.com/banach-space commented: Thanks Momchil - this is great! I skimmed through the pattern logic, and it's very neatly written. It's actually quite easy to follow, despite the underlying logic being a bit convoluted - well done! I've left a few minor suggestions, but nothing major. Also, it seems like we should be able to extend this fairly easily to support NEON as well. Worth thinking about 🙂 Now, overall this patch is quite large, and I’d suggest extracting the end-to-end / integration tests into a separate PR. Additionally, the remaining tests currently use `--convert-vector-to-llvm=`, which lowers all the way to LLVM (i.e., it exercises a lot of patterns). Instead, I’d recommend testing `LowerContractionToSVEI8MMPattern` in isolation and only verifying that the correct sequence of ArmSVE ops (plus some Vector ops) is generated - for example: ```mlir (...) %33 = arm_sve.smmla %23, %7, %15 : vector<[16]xi8> to vector<[4]xi32> %34 = arm_sve.smmla %24, %7, %16 : vector<[16]xi8> to vector<[4]xi32> %35 = arm_sve.smmla %31, %13, %15 : vector<[16]xi8> to vector<[4]xi32> %36 = arm_sve.smmla %32, %13, %16 : vector<[16]xi8> to vector<[4]xi32> ``` That way, we will: * reduce noise in the test output (by focusing on a single pattern), * simplify expected output (fewer ops to match), * avoid re-testing functionality already covered elsewhere (e.g., `arm_sve.smmla` → `arm_sve.intr.smmla` lowering). Btw, this is already looking great, and I know I’m asking for a bit of a rewrite (especially around the tests), but I really think it’ll help with long-term maintainability. https://github.com/llvm/llvm-project/pull/135636 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> banach-space wrote: Why not simply `isa(v.getDefiningOp())` inside the function instead of this? That's more common from what I've seen (there's very little SFINAE in the Dialect code). https://github.com/llvm/llvm-project/pull/135636 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) +return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) +return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) +return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. 
+enum class MMLA { + Signed, // smmla + Unsigned,// ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: +// The accumulator comes transposed and the result will be transposed +// later, so all we have to do here is swap the operands. +return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern +: public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, +PatternRewriter &rewriter) const override { + +Location loc = op.getLoc(); +mlir::VectorType lhsType = op.getLhsType(); +mlir::VectorType rhsType = op.getRhsType(); + +// For now handle LHS and RHS<8x[N]> - these are the types we +// eventually expect from MMT4D. M and N dimensions must be even and at +// least 2. +if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() || banach-space wrote: IIRC, inputs to `vector.contract` are required to be vectors, hence `lhsType.hasRank()` should always be true, no? https://github.com/llvm/llvm-project/pull/135636 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) +return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) +return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) +return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. 
+enum class MMLA { + Signed, // smmla + Unsigned,// ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: +// The accumulator comes transposed and the result will be transposed +// later, so all we have to do here is swap the operands. +return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern +: public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, +PatternRewriter &rewriter) const override { + +Location loc = op.getLoc(); +mlir::VectorType lhsType = op.getLhsType(); +mlir::VectorType rhsType = op.getRhsType(); + +// For now handle LHS and RHS<8x[N]> - these are the types we +// eventually expect from MMT4D. M and N dimensions must be even and at +// least 2. +if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() || +rhsType.getRank() != 2) + return failure(); + +if (lhsType.isScalable() || !rhsType.isScalable()) + return failure(); + +// M, N, and K are the conventional names for matrix dimensions in the +// context of matrix multiplication. +auto M = lhsType.getDimSize(0); +auto N = rhsType.getDimSize(0); +auto K = rhsType.getDimSize(1); + +if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 || +N % 2 != 0 || !rhsType.getScalableDims()[0]) + return failure(); + +// Check permutation maps. 
For now only accept +// lhs: (d0, d1, d2) -> (d0, d2) +// rhs: (d0, d1, d2) -> (d1, d2) +// acc: (d0, d1, d2) -> (d0, d1) +// Note: RHS is transposed. +if (op.getIndexingMapsArray()[0] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u}, + op.getContext()) || +op.getIndexingMapsArray()[1] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u}, + op.getContext()) || +op.getIndexingMapsArray()[2] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u}, + op.getContext())) + return failure(); + +// Check iterator types for matrix multiplication. +auto itTypes = op.getIteratorTypesArray(); +if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel || +itTypes[1] != vector::IteratorType
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) +return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) +return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) +return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. 
+enum class MMLA { + Signed, // smmla + Unsigned,// ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multiply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: +// The accumulator comes transposed and the result will be transposed +// later, so all we have to do here is swap the operands. +return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern +: public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, +PatternRewriter &rewriter) const override { + +Location loc = op.getLoc(); +mlir::VectorType lhsType = op.getLhsType(); +mlir::VectorType rhsType = op.getRhsType(); + +// For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we +// eventually expect from MMT4D. M and N dimensions must be even and at +// least 2. +if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() || +rhsType.getRank() != 2) + return failure(); banach-space wrote: Could you use `notifyMatchFailure` with some descriptive error message instead? Thanks! Same comment for other instances of `failure`. https://github.com/llvm/llvm-project/pull/135636 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
https://github.com/banach-space edited https://github.com/llvm/llvm-project/pull/135636 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- banach-space wrote: ```suggestion //===--===//``` https://github.com/llvm/llvm-project/pull/135636 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) +return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) +return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) +return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. 
+enum class MMLA { + Signed, // smmla + Unsigned,// ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: +// The accumulator comes transposed and the result will be transposed +// later, so all we have to do here is swap the operands. +return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern +: public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, +PatternRewriter &rewriter) const override { + +Location loc = op.getLoc(); +mlir::VectorType lhsType = op.getLhsType(); +mlir::VectorType rhsType = op.getRhsType(); + +// For now handle LHS and RHS<8x[N]> - these are the types we +// eventually expect from MMT4D. M and N dimensions must be even and at +// least 2. +if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() || +rhsType.getRank() != 2) + return failure(); + +if (lhsType.isScalable() || !rhsType.isScalable()) + return failure(); + +// M, N, and K are the conventional names for matrix dimensions in the +// context of matrix multiplication. +auto M = lhsType.getDimSize(0); +auto N = rhsType.getDimSize(0); +auto K = rhsType.getDimSize(1); + +if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 || +N % 2 != 0 || !rhsType.getScalableDims()[0]) + return failure(); + +// Check permutation maps. 
For now only accept +// lhs: (d0, d1, d2) -> (d0, d2) +// rhs: (d0, d1, d2) -> (d1, d2) +// acc: (d0, d1, d2) -> (d0, d1) +// Note: RHS is transposed. +if (op.getIndexingMapsArray()[0] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u}, + op.getContext()) || +op.getIndexingMapsArray()[1] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u}, + op.getContext()) || +op.getIndexingMapsArray()[2] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u}, + op.getContext())) + return failure(); + +// Check iterator types for matrix multiplication. +auto itTypes = op.getIteratorTypesArray(); +if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel || +itTypes[1] != vector::IteratorType
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) +return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) +return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) +return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. 
+enum class MMLA { + Signed, // smmla + Unsigned,// ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: +// The accumulator comes transposed and the result will be transposed +// later, so all we have to do here is swap the operands. +return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern +: public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, +PatternRewriter &rewriter) const override { + +Location loc = op.getLoc(); +mlir::VectorType lhsType = op.getLhsType(); +mlir::VectorType rhsType = op.getRhsType(); + +// For now handle LHS and RHS<8x[N]> - these are the types we +// eventually expect from MMT4D. M and N dimensions must be even and at banach-space wrote: [nit] We shouldn't be concerned with MMT4D in this dialect - it's a much higher-level abstraction and this logic should be valid irrespective of how the input is generated. https://github.com/llvm/llvm-project/pull/135636 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) +return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) +return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) +return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. 
+enum class MMLA { + Signed, // smmla + Unsigned,// ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: +// The accumulator comes transposed and the result will be transposed +// later, so all we have to do here is swap the operands. +return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern banach-space wrote: It's a very long pattern. Could you document the high-level logic? https://github.com/llvm/llvm-project/pull/135636 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. banach-space wrote: Could you add a note that `vector.contract` needs to be accompanied by `arith.extsi` (or `arith.extui`) Ops? Also, is I8MM the official name? Shouldn't that be FEAT_I8MM? Basically, could we document a bit more? https://github.com/llvm/llvm-project/pull/135636 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
momchil-velikov wrote: > One high-level question - would sharing some code between NEON and SVE be possible? No, I can't see that happening in a way that results in less, simpler, or easier-to-maintain code. However, it might be possible to add Neon lowering to this patch and see if the result is any good. https://github.com/llvm/llvm-project/pull/135636 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/135636 >From 8e87a7f3b1438d9542d28c90eb9593ebe8cf6500 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Tue, 8 Apr 2025 14:43:54 + Subject: [PATCH] [MLIR][ArmSVE] Add initial lowering of `vector.contract` to SVE `*MMLA` instructions --- mlir/include/mlir/Conversion/Passes.td| 4 + .../Dialect/ArmSVE/Transforms/Transforms.h| 3 + .../Conversion/VectorToLLVM/CMakeLists.txt| 1 + .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 7 + .../LowerContractionToSMMLAPattern.cpp| 5 +- .../Dialect/ArmSVE/Transforms/CMakeLists.txt | 1 + .../LowerContractionToSVEI8MMPattern.cpp | 304 ++ .../Vector/CPU/ArmSVE/vector-smmla.mlir | 94 ++ .../Vector/CPU/ArmSVE/vector-summla.mlir | 85 + .../Vector/CPU/ArmSVE/vector-ummla.mlir | 94 ++ .../Vector/CPU/ArmSVE/vector-usmmla.mlir | 95 ++ .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir | 117 +++ .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir | 159 + .../CPU/ArmSVE/contraction-summla-4x8x4.mlir | 118 +++ .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir | 119 +++ .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir | 117 +++ 16 files changed, 1322 insertions(+), 1 deletion(-) create mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index bbba495e613b2..930d8b44abca0 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " "dialect.">, +Option<"armI8MM", "enable-arm-i8mm", + "bool", /*default=*/"false", + "Enables the use of Arm FEAT_I8MM instructions while lowering " + "the vector dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h index 8665c8224cc45..232e2be29e574 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h @@ -20,6 +20,9 @@ class RewritePatternSet; void populateArmSVELegalizeForLLVMExportPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns); +void populateLowerContractionToSVEI8MMPatternPatterns( +RewritePatternSet &patterns); + /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM /// intrinsics. 
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target); diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt index 330474a718e30..8e2620029c354 100644 --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass MLIRVectorToLLVM MLIRArmNeonDialect + MLIRArmNeonTransforms MLIRArmSVEDialect MLIRArmSVETransforms MLIRAMXDialect diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 7082b92c95d1d..1e6c8122b1d0e 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmNeon/Transforms.h" #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPatterns(pa
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/135636 >From 8e87a7f3b1438d9542d28c90eb9593ebe8cf6500 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Tue, 8 Apr 2025 14:43:54 + Subject: [PATCH] [MLIR][ArmSVE] Add initial lowering of `vector.contract` to SVE `*MMLA` instructions --- mlir/include/mlir/Conversion/Passes.td| 4 + .../Dialect/ArmSVE/Transforms/Transforms.h| 3 + .../Conversion/VectorToLLVM/CMakeLists.txt| 1 + .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 7 + .../LowerContractionToSMMLAPattern.cpp| 5 +- .../Dialect/ArmSVE/Transforms/CMakeLists.txt | 1 + .../LowerContractionToSVEI8MMPattern.cpp | 304 ++ .../Vector/CPU/ArmSVE/vector-smmla.mlir | 94 ++ .../Vector/CPU/ArmSVE/vector-summla.mlir | 85 + .../Vector/CPU/ArmSVE/vector-ummla.mlir | 94 ++ .../Vector/CPU/ArmSVE/vector-usmmla.mlir | 95 ++ .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir | 117 +++ .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir | 159 + .../CPU/ArmSVE/contraction-summla-4x8x4.mlir | 118 +++ .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir | 119 +++ .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir | 117 +++ 16 files changed, 1322 insertions(+), 1 deletion(-) create mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index bbba495e613b2..930d8b44abca0 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " "dialect.">, +Option<"armI8MM", "enable-arm-i8mm", + "bool", /*default=*/"false", + "Enables the use of Arm FEAT_I8MM instructions while lowering " + "the vector dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h index 8665c8224cc45..232e2be29e574 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h @@ -20,6 +20,9 @@ class RewritePatternSet; void populateArmSVELegalizeForLLVMExportPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns); +void populateLowerContractionToSVEI8MMPatternPatterns( +RewritePatternSet &patterns); + /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM /// intrinsics. 
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target); diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt index 330474a718e30..8e2620029c354 100644 --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass MLIRVectorToLLVM MLIRArmNeonDialect + MLIRArmNeonTransforms MLIRArmSVEDialect MLIRArmSVETransforms MLIRAMXDialect diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 7082b92c95d1d..1e6c8122b1d0e 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmNeon/Transforms.h" #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPatterns(pa
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
https://github.com/banach-space commented: Thanks! This one is a bit longer, so I may need to wait until Thursday before I can review it. One high-level question: would it be possible to share some code between NEON and SVE? https://github.com/llvm/llvm-project/pull/135636 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/135636 Supersedes https://github.com/llvm/llvm-project/pull/135359 >From 2e61d3ee7b9ac88ae1be8ca248dad1a0880ccff4 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Tue, 8 Apr 2025 14:43:54 + Subject: [PATCH] [MLIR][ArmSVE] Add initial lowering of `vector.contract` to SVE `*MMLA` instructions --- mlir/include/mlir/Conversion/Passes.td| 4 + .../Dialect/ArmSVE/Transforms/Transforms.h| 3 + .../Conversion/VectorToLLVM/CMakeLists.txt| 1 + .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 7 + .../LowerContractionToSMMLAPattern.cpp| 5 +- .../Dialect/ArmSVE/Transforms/CMakeLists.txt | 1 + .../LowerContractionToSVEI8MMPattern.cpp | 304 ++ .../Vector/CPU/ArmSVE/vector-smmla.mlir | 94 ++ .../Vector/CPU/ArmSVE/vector-summla.mlir | 85 + .../Vector/CPU/ArmSVE/vector-ummla.mlir | 94 ++ .../Vector/CPU/ArmSVE/vector-usmmla.mlir | 95 ++ .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir | 117 +++ .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir | 159 + .../CPU/ArmSVE/contraction-summla-4x8x4.mlir | 118 +++ .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir | 119 +++ .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir | 117 +++ 16 files changed, 1322 insertions(+), 1 deletion(-) create mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index bbba495e613b2..930d8b44abca0 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " "dialect.">, +Option<"armI8MM", "enable-arm-i8mm", + "bool", /*default=*/"false", + "Enables the use of Arm FEAT_I8MM instructions while lowering " + "the vector dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h index 8665c8224cc45..232e2be29e574 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h @@ -20,6 +20,9 @@ class RewritePatternSet; void populateArmSVELegalizeForLLVMExportPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns); +void populateLowerContractionToSVEI8MMPatternPatterns( +RewritePatternSet &patterns); + /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM /// intrinsics. 
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target); diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt index 330474a718e30..8e2620029c354 100644 --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass MLIRVectorToLLVM MLIRArmNeonDialect + MLIRArmNeonTransforms MLIRArmSVEDialect MLIRArmSVETransforms MLIRAMXDialect diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 7082b92c95d1d..1e6c8122b1d0e 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmNeon/Transforms.h" #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
llvmbot wrote: @llvm/pr-subscribers-mlir Author: Momchil Velikov (momchil-velikov) Changes Supersedes https://github.com/llvm/llvm-project/pull/135359 --- Patch is 77.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135636.diff 16 Files Affected: - (modified) mlir/include/mlir/Conversion/Passes.td (+4) - (modified) mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h (+3) - (modified) mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt (+1) - (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+7) - (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-1) - (modified) mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt (+1) - (added) mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp (+304) - (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir (+94) - (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir (+85) - (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir (+94) - (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir (+95) - (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir (+117) - (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir (+159) - (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir (+118) - (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir (+119) - (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir (+117) ``diff diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index bbba495e613b2..930d8b44abca0 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " "dialect.">, 
+Option<"armI8MM", "enable-arm-i8mm", + "bool", /*default=*/"false", + "Enables the use of Arm FEAT_I8MM instructions while lowering " + "the vector dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h index 8665c8224cc45..232e2be29e574 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h @@ -20,6 +20,9 @@ class RewritePatternSet; void populateArmSVELegalizeForLLVMExportPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns); +void populateLowerContractionToSVEI8MMPatternPatterns( +RewritePatternSet &patterns); + /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM /// intrinsics. void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target); diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt index 330474a718e30..8e2620029c354 100644 --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass MLIRVectorToLLVM MLIRArmNeonDialect + MLIRArmNeonTransforms MLIRArmSVEDialect MLIRArmSVETransforms MLIRAMXDialect diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 7082b92c95d1d..1e6c8122b1d0e 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmNeon/Transforms.h" #include 
"mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPatterns(patterns); populateVectorRankReducingFMAPattern(patterns); populateVectorGatherLoweringPatterns(patterns); +if (armI8MM) { + if (armNeon) +arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns); + if (armSVE) +populateLowerContractionToSVEI8MMPatternPatterns(patterns); +} (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp index 2a1271dfd6bdf..e807b
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
llvmbot wrote: @llvm/pr-subscribers-mlir-neon Author: Momchil Velikov (momchil-velikov) Changes Supersedes https://github.com/llvm/llvm-project/pull/135359 --- Patch is 77.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135636.diff 16 Files Affected: - (modified) mlir/include/mlir/Conversion/Passes.td (+4) - (modified) mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h (+3) - (modified) mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt (+1) - (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+7) - (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-1) - (modified) mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt (+1) - (added) mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp (+304) - (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir (+94) - (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir (+85) - (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir (+94) - (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir (+95) - (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir (+117) - (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir (+159) - (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir (+118) - (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir (+119) - (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir (+117) ``diff diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index bbba495e613b2..930d8b44abca0 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " 
"dialect.">, +Option<"armI8MM", "enable-arm-i8mm", + "bool", /*default=*/"false", + "Enables the use of Arm FEAT_I8MM instructions while lowering " + "the vector dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h index 8665c8224cc45..232e2be29e574 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h @@ -20,6 +20,9 @@ class RewritePatternSet; void populateArmSVELegalizeForLLVMExportPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns); +void populateLowerContractionToSVEI8MMPatternPatterns( +RewritePatternSet &patterns); + /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM /// intrinsics. void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target); diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt index 330474a718e30..8e2620029c354 100644 --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass MLIRVectorToLLVM MLIRArmNeonDialect + MLIRArmNeonTransforms MLIRArmSVEDialect MLIRArmSVETransforms MLIRAMXDialect diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 7082b92c95d1d..1e6c8122b1d0e 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmNeon/Transforms.h" #include 
"mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPatterns(patterns); populateVectorRankReducingFMAPattern(patterns); populateVectorGatherLoweringPatterns(patterns); +if (armI8MM) { + if (armNeon) +arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns); + if (armSVE) +populateLowerContractionToSVEI8MMPatternPatterns(patterns); +} (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp index 2a1271dfd6bdf..
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
llvmbot wrote: @llvm/pr-subscribers-mlir-sve Author: Momchil Velikov (momchil-velikov) Changes Supersedes https://github.com/llvm/llvm-project/pull/135359 --- Patch is 77.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135636.diff 16 Files Affected: - (modified) mlir/include/mlir/Conversion/Passes.td (+4) - (modified) mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h (+3) - (modified) mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt (+1) - (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+7) - (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-1) - (modified) mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt (+1) - (added) mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp (+304) - (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir (+94) - (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir (+85) - (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir (+94) - (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir (+95) - (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir (+117) - (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir (+159) - (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir (+118) - (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir (+119) - (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir (+117) ``diff diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index bbba495e613b2..930d8b44abca0 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " 
"dialect.">, +Option<"armI8MM", "enable-arm-i8mm", + "bool", /*default=*/"false", + "Enables the use of Arm FEAT_I8MM instructions while lowering " + "the vector dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h index 8665c8224cc45..232e2be29e574 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h @@ -20,6 +20,9 @@ class RewritePatternSet; void populateArmSVELegalizeForLLVMExportPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns); +void populateLowerContractionToSVEI8MMPatternPatterns( +RewritePatternSet &patterns); + /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM /// intrinsics. void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target); diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt index 330474a718e30..8e2620029c354 100644 --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass MLIRVectorToLLVM MLIRArmNeonDialect + MLIRArmNeonTransforms MLIRArmSVEDialect MLIRArmSVETransforms MLIRAMXDialect diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 7082b92c95d1d..1e6c8122b1d0e 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmNeon/Transforms.h" #include 
"mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPatterns(patterns); populateVectorRankReducingFMAPattern(patterns); populateVectorGatherLoweringPatterns(patterns); +if (armI8MM) { + if (armNeon) +arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns); + if (armSVE) +populateLowerContractionToSVEI8MMPatternPatterns(patterns); +} (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp index 2a1271dfd6bdf..e