[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-05-15 Thread Momchil Velikov via llvm-branch-commits

https://github.com/momchil-velikov updated 
https://github.com/llvm/llvm-project/pull/135636

>From aa8a667f206874af3b26811ec04d58be12ad43de Mon Sep 17 00:00:00 2001
From: Momchil Velikov 
Date: Tue, 8 Apr 2025 14:43:54 +
Subject: [PATCH 1/3] [MLIR][ArmSVE] Add initial lowering of `vector.contract`
 to SVE `*MMLA` instructions

---
 mlir/include/mlir/Conversion/Passes.td|   4 +
 .../Dialect/ArmSVE/Transforms/Transforms.h|   3 +
 .../Conversion/VectorToLLVM/CMakeLists.txt|   1 +
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   7 +
 .../LowerContractionToSMMLAPattern.cpp|   5 +-
 .../Dialect/ArmSVE/Transforms/CMakeLists.txt  |   1 +
 .../LowerContractionToSVEI8MMPattern.cpp  | 304 ++
 .../Vector/CPU/ArmSVE/vector-smmla.mlir   |  94 ++
 .../Vector/CPU/ArmSVE/vector-summla.mlir  |  85 +
 .../Vector/CPU/ArmSVE/vector-ummla.mlir   |  94 ++
 .../Vector/CPU/ArmSVE/vector-usmmla.mlir  |  95 ++
 .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir   | 117 +++
 .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir   | 159 +
 .../CPU/ArmSVE/contraction-summla-4x8x4.mlir  | 118 +++
 .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir   | 119 +++
 .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir  | 117 +++
 16 files changed, 1322 insertions(+), 1 deletion(-)
 create mode 100644 
mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td 
b/mlir/include/mlir/Conversion/Passes.td
index 10557658d5d7d..b496ee0114910 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1431,6 +1431,10 @@ def ConvertVectorToLLVMPass : 
Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
+Option<"armI8MM", "enable-arm-i8mm",
+   "bool", /*default=*/"false",
+   "Enables the use of Arm FEAT_I8MM instructions while lowering "
+   "the vector dialect.">,
 Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h 
b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
 const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEI8MMPatternPatterns(
+RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt 
b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp 
b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 0ee6dce9ee94b..293e01a5bf4d4 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
 populateVectorStepLoweringPattern

[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-05-15 Thread Momchil Velikov via llvm-branch-commits

https://github.com/momchil-velikov updated 
https://github.com/llvm/llvm-project/pull/135636

>From aa8a667f206874af3b26811ec04d58be12ad43de Mon Sep 17 00:00:00 2001
From: Momchil Velikov 
Date: Tue, 8 Apr 2025 14:43:54 +
Subject: [PATCH 1/3] [MLIR][ArmSVE] Add initial lowering of `vector.contract`
 to SVE `*MMLA` instructions

---
 mlir/include/mlir/Conversion/Passes.td|   4 +
 .../Dialect/ArmSVE/Transforms/Transforms.h|   3 +
 .../Conversion/VectorToLLVM/CMakeLists.txt|   1 +
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   7 +
 .../LowerContractionToSMMLAPattern.cpp|   5 +-
 .../Dialect/ArmSVE/Transforms/CMakeLists.txt  |   1 +
 .../LowerContractionToSVEI8MMPattern.cpp  | 304 ++
 .../Vector/CPU/ArmSVE/vector-smmla.mlir   |  94 ++
 .../Vector/CPU/ArmSVE/vector-summla.mlir  |  85 +
 .../Vector/CPU/ArmSVE/vector-ummla.mlir   |  94 ++
 .../Vector/CPU/ArmSVE/vector-usmmla.mlir  |  95 ++
 .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir   | 117 +++
 .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir   | 159 +
 .../CPU/ArmSVE/contraction-summla-4x8x4.mlir  | 118 +++
 .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir   | 119 +++
 .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir  | 117 +++
 16 files changed, 1322 insertions(+), 1 deletion(-)
 create mode 100644 
mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td 
b/mlir/include/mlir/Conversion/Passes.td
index 10557658d5d7d..b496ee0114910 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1431,6 +1431,10 @@ def ConvertVectorToLLVMPass : 
Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
+Option<"armI8MM", "enable-arm-i8mm",
+   "bool", /*default=*/"false",
+   "Enables the use of Arm FEAT_I8MM instructions while lowering "
+   "the vector dialect.">,
 Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h 
b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
 const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEI8MMPatternPatterns(
+RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt 
b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp 
b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 0ee6dce9ee94b..293e01a5bf4d4 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
 populateVectorStepLoweringPattern

[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-05-14 Thread Momchil Velikov via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the extension.
+template 
+inline std::enable_if_t<(std::is_base_of_v ||
+ std::is_base_of_v),
+std::optional>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null(v.getDefiningOp());
+  if (!extOp)
+return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+return {};
+
+  auto outTy = dyn_cast(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,  // smmla
+  Unsigned,// ummla
+  Mixed,   // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+// The accumulator comes transposed and the result will be transposed
+// later, so all we have to do here is swap the operands.
+return rewriter.create(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern
+: public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+PatternRewriter &rewriter) const override {
+
+Location loc = op.getLoc();
+mlir::VectorType lhsType = op.getLhsType();
+mlir::VectorType rhsType = op.getRhsType();
+
+// For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+// eventually expect from MMT4D. M and N dimensions must be even and at
+// least 2.
+if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+rhsType.getRank() != 2)
+  return failure();
+
+if (lhsType.isScalable() || !rhsType.isScalable())
+  return failure();
+
+// M, N, and K are the conventional names for matrix dimensions in the
+// context of matrix multiplication.
+auto M = lhsType.getDimSize(0);
+auto N = rhsType.getDimSize(0);
+auto K = rhsType.getDimSize(1);
+
+if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+N % 2 != 0 || !rhsType.getScalableDims()[0])
+  return failure();
+
+// Check permutation maps. For now only accept
+//   lhs: (d0, d1, d2) -> (d0, d2)
+//   rhs: (d0, d1, d2) -> (d1, d2)
+//   acc: (d0, d1, d2) -> (d0, d1)
+// Note: RHS is transposed.
+if (op.getIndexingMapsArray()[0] !=
+AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+ op.getContext()) ||
+op.getIndexingMapsArray()[1] !=
+AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+ op.getContext()) ||
+op.getIndexingMapsArray()[2] !=
+AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+ op.getContext()))
+  return failure();
+
+// Check iterator types for matrix multiplication.
+auto itTypes = op.getIteratorTypesArray();
+if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+itTypes[1] != vector::IteratorType

[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-05-14 Thread Momchil Velikov via llvm-branch-commits

https://github.com/momchil-velikov updated 
https://github.com/llvm/llvm-project/pull/135636

>From f397467bc167d94a28a919a45c009a8f08b6351b Mon Sep 17 00:00:00 2001
From: Momchil Velikov 
Date: Tue, 8 Apr 2025 14:43:54 +
Subject: [PATCH 1/2] [MLIR][ArmSVE] Add initial lowering of `vector.contract`
 to SVE `*MMLA` instructions

---
 mlir/include/mlir/Conversion/Passes.td|   4 +
 .../Dialect/ArmSVE/Transforms/Transforms.h|   3 +
 .../Conversion/VectorToLLVM/CMakeLists.txt|   1 +
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   7 +
 .../LowerContractionToSMMLAPattern.cpp|   5 +-
 .../Dialect/ArmSVE/Transforms/CMakeLists.txt  |   1 +
 .../LowerContractionToSVEI8MMPattern.cpp  | 304 ++
 .../Vector/CPU/ArmSVE/vector-smmla.mlir   |  94 ++
 .../Vector/CPU/ArmSVE/vector-summla.mlir  |  85 +
 .../Vector/CPU/ArmSVE/vector-ummla.mlir   |  94 ++
 .../Vector/CPU/ArmSVE/vector-usmmla.mlir  |  95 ++
 .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir   | 117 +++
 .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir   | 159 +
 .../CPU/ArmSVE/contraction-summla-4x8x4.mlir  | 118 +++
 .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir   | 119 +++
 .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir  | 117 +++
 16 files changed, 1322 insertions(+), 1 deletion(-)
 create mode 100644 
mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td 
b/mlir/include/mlir/Conversion/Passes.td
index 10557658d5d7d..b496ee0114910 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1431,6 +1431,10 @@ def ConvertVectorToLLVMPass : 
Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
+Option<"armI8MM", "enable-arm-i8mm",
+   "bool", /*default=*/"false",
+   "Enables the use of Arm FEAT_I8MM instructions while lowering "
+   "the vector dialect.">,
 Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h 
b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
 const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEI8MMPatternPatterns(
+RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt 
b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp 
b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 0ee6dce9ee94b..293e01a5bf4d4 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
 populateVectorStepLoweringPattern

[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-05-14 Thread Momchil Velikov via llvm-branch-commits

https://github.com/momchil-velikov updated 
https://github.com/llvm/llvm-project/pull/135636

>From f397467bc167d94a28a919a45c009a8f08b6351b Mon Sep 17 00:00:00 2001
From: Momchil Velikov 
Date: Tue, 8 Apr 2025 14:43:54 +
Subject: [PATCH 1/2] [MLIR][ArmSVE] Add initial lowering of `vector.contract`
 to SVE `*MMLA` instructions

---
 mlir/include/mlir/Conversion/Passes.td|   4 +
 .../Dialect/ArmSVE/Transforms/Transforms.h|   3 +
 .../Conversion/VectorToLLVM/CMakeLists.txt|   1 +
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   7 +
 .../LowerContractionToSMMLAPattern.cpp|   5 +-
 .../Dialect/ArmSVE/Transforms/CMakeLists.txt  |   1 +
 .../LowerContractionToSVEI8MMPattern.cpp  | 304 ++
 .../Vector/CPU/ArmSVE/vector-smmla.mlir   |  94 ++
 .../Vector/CPU/ArmSVE/vector-summla.mlir  |  85 +
 .../Vector/CPU/ArmSVE/vector-ummla.mlir   |  94 ++
 .../Vector/CPU/ArmSVE/vector-usmmla.mlir  |  95 ++
 .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir   | 117 +++
 .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir   | 159 +
 .../CPU/ArmSVE/contraction-summla-4x8x4.mlir  | 118 +++
 .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir   | 119 +++
 .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir  | 117 +++
 16 files changed, 1322 insertions(+), 1 deletion(-)
 create mode 100644 
mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td 
b/mlir/include/mlir/Conversion/Passes.td
index 10557658d5d7d..b496ee0114910 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1431,6 +1431,10 @@ def ConvertVectorToLLVMPass : 
Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
+Option<"armI8MM", "enable-arm-i8mm",
+   "bool", /*default=*/"false",
+   "Enables the use of Arm FEAT_I8MM instructions while lowering "
+   "the vector dialect.">,
 Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h 
b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
 const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEI8MMPatternPatterns(
+RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt 
b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp 
b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 0ee6dce9ee94b..293e01a5bf4d4 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
 populateVectorStepLoweringPattern

[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-05-14 Thread Momchil Velikov via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the extension.
+template 
+inline std::enable_if_t<(std::is_base_of_v ||
+ std::is_base_of_v),
+std::optional>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null(v.getDefiningOp());
+  if (!extOp)
+return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+return {};
+
+  auto outTy = dyn_cast(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,  // smmla
+  Unsigned,// ummla
+  Mixed,   // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+// The accumulator comes transposed and the result will be transposed
+// later, so all we have to do here is swap the operands.
+return rewriter.create(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern
+: public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+PatternRewriter &rewriter) const override {
+
+Location loc = op.getLoc();
+mlir::VectorType lhsType = op.getLhsType();
+mlir::VectorType rhsType = op.getRhsType();
+
+// For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+// eventually expect from MMT4D. M and N dimensions must be even and at

momchil-velikov wrote:

Done (in the top-level description).

https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-05-14 Thread Momchil Velikov via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the extension.
+template 
+inline std::enable_if_t<(std::is_base_of_v ||
+ std::is_base_of_v),
+std::optional>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null(v.getDefiningOp());
+  if (!extOp)
+return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+return {};
+
+  auto outTy = dyn_cast(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,  // smmla
+  Unsigned,// ummla
+  Mixed,   // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+// The accumulator comes transposed and the result will be transposed
+// later, so all we have to do here is swap the operands.
+return rewriter.create(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern

momchil-velikov wrote:

Done.

https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-05-14 Thread Momchil Velikov via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the 
extension.
+template 
+inline std::enable_if_t<(std::is_base_of_v ||
+ std::is_base_of_v),
+std::optional>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null(v.getDefiningOp());
+  if (!extOp)
+return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+return {};
+
+  auto outTy = dyn_cast(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,  // smmla
+  Unsigned,// ummla
+  Mixed,   // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+// The accumulator comes transposed and the result will be transposed
+// later, so all we have to do here is swap the operands.
+return rewriter.create(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern
+: public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+PatternRewriter &rewriter) const override {
+
+Location loc = op.getLoc();
+mlir::VectorType lhsType = op.getLhsType();
+mlir::VectorType rhsType = op.getRhsType();
+
+// For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+// eventually expect from MMT4D. M and N dimensions must be even and at
+// least 2.
+if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+rhsType.getRank() != 2)
+  return failure();
+
+if (lhsType.isScalable() || !rhsType.isScalable())
+  return failure();
+
+// M, N, and K are the conventional names for matrix dimensions in the
+// context of matrix multiplication.
+auto M = lhsType.getDimSize(0);
+auto N = rhsType.getDimSize(0);
+auto K = rhsType.getDimSize(1);
+
+if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+N % 2 != 0 || !rhsType.getScalableDims()[0])
+  return failure();
+
+// Check permutation maps. For now only accept
+//   lhs: (d0, d1, d2) -> (d0, d2)
+//   rhs: (d0, d1, d2) -> (d1, d2)
+//   acc: (d0, d1, d2) -> (d0, d1)
+// Note: RHS is transposed.
+if (op.getIndexingMapsArray()[0] !=
+AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+ op.getContext()) ||
+op.getIndexingMapsArray()[1] !=
+AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+ op.getContext()) ||
+op.getIndexingMapsArray()[2] !=
+AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+ op.getContext()))
+  return failure();
+
+// Check iterator types for matrix multiplication.
+auto itTypes = op.getIteratorTypesArray();
+if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+itTypes[1] != vector::IteratorType

[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-05-14 Thread Momchil Velikov via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the 
extension.
+template 
+inline std::enable_if_t<(std::is_base_of_v ||
+ std::is_base_of_v),
+std::optional>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null(v.getDefiningOp());
+  if (!extOp)
+return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+return {};
+
+  auto outTy = dyn_cast(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,  // smmla
+  Unsigned,// ummla
+  Mixed,   // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+// The accumulator comes transposed and the result will be transposed
+// later, so all we have to do here is swap the operands.
+return rewriter.create(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern
+: public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+PatternRewriter &rewriter) const override {
+
+Location loc = op.getLoc();
+mlir::VectorType lhsType = op.getLhsType();
+mlir::VectorType rhsType = op.getRhsType();
+
+// For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+// eventually expect from MMT4D. M and N dimensions must be even and at
+// least 2.
+if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+rhsType.getRank() != 2)
+  return failure();
+
+if (lhsType.isScalable() || !rhsType.isScalable())
+  return failure();
+
+// M, N, and K are the conventional names for matrix dimensions in the
+// context of matrix multiplication.
+auto M = lhsType.getDimSize(0);
+auto N = rhsType.getDimSize(0);
+auto K = rhsType.getDimSize(1);
+
+if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+N % 2 != 0 || !rhsType.getScalableDims()[0])
+  return failure();
+
+// Check permutation maps. For now only accept
+//   lhs: (d0, d1, d2) -> (d0, d2)
+//   rhs: (d0, d1, d2) -> (d1, d2)
+//   acc: (d0, d1, d2) -> (d0, d1)
+// Note: RHS is transposed.
+if (op.getIndexingMapsArray()[0] !=
+AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+ op.getContext()) ||
+op.getIndexingMapsArray()[1] !=
+AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+ op.getContext()) ||
+op.getIndexingMapsArray()[2] !=
+AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+ op.getContext()))
+  return failure();
+
+// Check iterator types for matrix multiplication.
+auto itTypes = op.getIteratorTypesArray();
+if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+itTypes[1] != vector::IteratorType

[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-05-14 Thread Andrzej Warzyński via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the 
extension.
+template 
+inline std::enable_if_t<(std::is_base_of_v ||
+ std::is_base_of_v),
+std::optional>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null(v.getDefiningOp());
+  if (!extOp)
+return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+return {};
+
+  auto outTy = dyn_cast(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,  // smmla
+  Unsigned,// ummla
+  Mixed,   // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+// The accumulator comes transposed and the result will be transposed
+// later, so all we have to do here is swap the operands.
+return rewriter.create(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern
+: public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+PatternRewriter &rewriter) const override {
+
+Location loc = op.getLoc();
+mlir::VectorType lhsType = op.getLhsType();
+mlir::VectorType rhsType = op.getRhsType();
+
+// For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+// eventually expect from MMT4D. M and N dimensions must be even and at

banach-space wrote:

Perhaps just expand this comment a bit (e.g. by noting that MMT4D is the main 
use-case ATM)?

https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-05-14 Thread Momchil Velikov via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the 
extension.
+template 
+inline std::enable_if_t<(std::is_base_of_v ||
+ std::is_base_of_v),
+std::optional>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null(v.getDefiningOp());
+  if (!extOp)
+return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+return {};
+
+  auto outTy = dyn_cast(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,  // smmla
+  Unsigned,// ummla
+  Mixed,   // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+// The accumulator comes transposed and the result will be transposed
+// later, so all we have to do here is swap the operands.
+return rewriter.create(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern
+: public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+PatternRewriter &rewriter) const override {
+
+Location loc = op.getLoc();
+mlir::VectorType lhsType = op.getLhsType();
+mlir::VectorType rhsType = op.getRhsType();
+
+// For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+// eventually expect from MMT4D. M and N dimensions must be even and at

momchil-velikov wrote:

There's no dependency on MMT4D - this comment is merely the rationale for why
we chose these particular operand shapes.

https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-05-14 Thread Momchil Velikov via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.

momchil-velikov wrote:

The `vector.contract` implicitly sign-extends its operands, so it does not need
to be accompanied by explicit extend operations. I'll add code to handle this
case too.

https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-05-14 Thread Momchil Velikov via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the 
extension.
+template 
+inline std::enable_if_t<(std::is_base_of_v ||
+ std::is_base_of_v),
+std::optional>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null(v.getDefiningOp());
+  if (!extOp)
+return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+return {};
+
+  auto outTy = dyn_cast(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,  // smmla
+  Unsigned,// ummla
+  Mixed,   // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+// The accumulator comes transposed and the result will be transposed
+// later, so all we have to do here is swap the operands.
+return rewriter.create(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern
+: public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+PatternRewriter &rewriter) const override {
+
+Location loc = op.getLoc();
+mlir::VectorType lhsType = op.getLhsType();
+mlir::VectorType rhsType = op.getRhsType();
+
+// For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+// eventually expect from MMT4D. M and N dimensions must be even and at
+// least 2.
+if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||

momchil-velikov wrote:

Done

https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-05-14 Thread Momchil Velikov via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the 
extension.
+template 
+inline std::enable_if_t<(std::is_base_of_v ||
+ std::is_base_of_v),
+std::optional>

momchil-velikov wrote:

Because it does not need to be checked at runtime. But fair enough, it should 
be a `static_assert`, as there's no ambiguity to resolve, and such use is quite 
common across the code base, e.g.
```
  static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
"applies to only pack or unpack operations");
```

https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-05-09 Thread Diego Caballero via llvm-branch-commits

https://github.com/dcaballe commented:

Really excited to see this! I'll take a look in the next iteration. Thanks!

https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-23 Thread Andrzej Warzyński via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the 
extension.
+template 
+inline std::enable_if_t<(std::is_base_of_v ||
+ std::is_base_of_v),
+std::optional>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null(v.getDefiningOp());
+  if (!extOp)
+return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+return {};
+
+  auto outTy = dyn_cast(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,  // smmla
+  Unsigned,// ummla
+  Mixed,   // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+// The accumulator comes transposed and the result will be transposed
+// later, so all we have to do here is swap the operands.
+return rewriter.create(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern
+: public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+PatternRewriter &rewriter) const override {
+
+Location loc = op.getLoc();
+mlir::VectorType lhsType = op.getLhsType();
+mlir::VectorType rhsType = op.getRhsType();
+
+// For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+// eventually expect from MMT4D. M and N dimensions must be even and at
+// least 2.
+if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+rhsType.getRank() != 2)
+  return failure();
+
+if (lhsType.isScalable() || !rhsType.isScalable())
+  return failure();
+
+// M, N, and K are the conventional names for matrix dimensions in the
+// context of matrix multiplication.
+auto M = lhsType.getDimSize(0);
+auto N = rhsType.getDimSize(0);
+auto K = rhsType.getDimSize(1);
+
+if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+N % 2 != 0 || !rhsType.getScalableDims()[0])
+  return failure();
+
+// Check permutation maps. For now only accept
+//   lhs: (d0, d1, d2) -> (d0, d2)
+//   rhs: (d0, d1, d2) -> (d1, d2)
+//   acc: (d0, d1, d2) -> (d0, d1)
+// Note: RHS is transposed.
+if (op.getIndexingMapsArray()[0] !=
+AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+ op.getContext()) ||
+op.getIndexingMapsArray()[1] !=
+AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+ op.getContext()) ||
+op.getIndexingMapsArray()[2] !=
+AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+ op.getContext()))
+  return failure();
+
+// Check iterator types for matrix multiplication.
+auto itTypes = op.getIteratorTypesArray();
+if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+itTypes[1] != vector::IteratorType

[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-23 Thread Andrzej Warzyński via llvm-branch-commits

https://github.com/banach-space commented:

Thanks Momchil - this is great!

I skimmed through the pattern logic, and it's very neatly written. It's 
actually quite easy to follow, despite the underlying logic being a bit 
convoluted - well done! I've left a few minor suggestions, but nothing major.

Also, it seems like we should be able to extend this fairly easily to support 
NEON as well. Worth thinking about 🙂

Now, overall this patch is quite large, and I’d suggest extracting the 
end-to-end / integration tests into a separate PR. Additionally, the remaining 
tests currently use `--convert-vector-to-llvm=`, which lowers all the way to 
LLVM (i.e., it exercises a lot of patterns). Instead, I’d recommend testing 
`LowerContractionToSVEI8MMPattern` in isolation and only verifying that the 
correct sequence of ArmSVE ops (plus some Vector ops) is generated - for 
example:

```mlir
(...)
%33 = arm_sve.smmla %23, %7, %15 : vector<[16]xi8> to vector<[4]xi32>
%34 = arm_sve.smmla %24, %7, %16 : vector<[16]xi8> to vector<[4]xi32>
%35 = arm_sve.smmla %31, %13, %15 : vector<[16]xi8> to vector<[4]xi32>
%36 = arm_sve.smmla %32, %13, %16 : vector<[16]xi8> to vector<[4]xi32>
```

That way, we will:
* reduce noise in the test output (by focusing on a single pattern),
* simplify expected output (fewer ops to match),
* avoid re-testing functionality already covered elsewhere (e.g., 
`arm_sve.smmla` → `arm_sve.intr.smmla` lowering).

Btw, this is already looking great, and I know I’m asking for a bit of a 
rewrite (especially around the tests), but I really think it’ll help with 
long-term maintainability.

https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-23 Thread Andrzej Warzyński via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the 
extension.
+template 
+inline std::enable_if_t<(std::is_base_of_v ||
+ std::is_base_of_v),
+std::optional>

banach-space wrote:

Why not simply use `isa<arith::ExtSIOp, arith::ExtUIOp>(v.getDefiningOp())`
inside the function instead of this? That's more common from what I've seen
(there's very little SFINAE in the Dialect code).

https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-23 Thread Andrzej Warzyński via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the 
extension.
+template 
+inline std::enable_if_t<(std::is_base_of_v ||
+ std::is_base_of_v),
+std::optional>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null(v.getDefiningOp());
+  if (!extOp)
+return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+return {};
+
+  auto outTy = dyn_cast(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,  // smmla
+  Unsigned,// ummla
+  Mixed,   // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+// The accumulator comes transposed and the result will be transposed
+// later, so all we have to do here is swap the operands.
+return rewriter.create(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern
+: public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+PatternRewriter &rewriter) const override {
+
+Location loc = op.getLoc();
+mlir::VectorType lhsType = op.getLhsType();
+mlir::VectorType rhsType = op.getRhsType();
+
+// For now handle LHS and RHS<8x[N]> - these are the types we
+// eventually expect from MMT4D. M and N dimensions must be even and at
+// least 2.
+if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||

banach-space wrote:

IIRC, inputs to `vector.contract` are required to be vectors, hence 
`lhsType.hasRank()` should always be true, no?

https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-23 Thread Andrzej Warzyński via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the 
extension.
+template 
+inline std::enable_if_t<(std::is_base_of_v ||
+ std::is_base_of_v),
+std::optional>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null(v.getDefiningOp());
+  if (!extOp)
+return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+return {};
+
+  auto outTy = dyn_cast(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,  // smmla
+  Unsigned,// ummla
+  Mixed,   // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+// The accumulator comes transposed and the result will be transposed
+// later, so all we have to do here is swap the operands.
+return rewriter.create(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern
+: public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+PatternRewriter &rewriter) const override {
+
+Location loc = op.getLoc();
+mlir::VectorType lhsType = op.getLhsType();
+mlir::VectorType rhsType = op.getRhsType();
+
+// For now handle LHS and RHS<8x[N]> - these are the types we
+// eventually expect from MMT4D. M and N dimensions must be even and at
+// least 2.
+if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+rhsType.getRank() != 2)
+  return failure();
+
+if (lhsType.isScalable() || !rhsType.isScalable())
+  return failure();
+
+// M, N, and K are the conventional names for matrix dimensions in the
+// context of matrix multiplication.
+auto M = lhsType.getDimSize(0);
+auto N = rhsType.getDimSize(0);
+auto K = rhsType.getDimSize(1);
+
+if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+N % 2 != 0 || !rhsType.getScalableDims()[0])
+  return failure();
+
+// Check permutation maps. For now only accept
+//   lhs: (d0, d1, d2) -> (d0, d2)
+//   rhs: (d0, d1, d2) -> (d1, d2)
+//   acc: (d0, d1, d2) -> (d0, d1)
+// Note: RHS is transposed.
+if (op.getIndexingMapsArray()[0] !=
+AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+ op.getContext()) ||
+op.getIndexingMapsArray()[1] !=
+AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+ op.getContext()) ||
+op.getIndexingMapsArray()[2] !=
+AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+ op.getContext()))
+  return failure();
+
+// Check iterator types for matrix multiplication.
+auto itTypes = op.getIteratorTypesArray();
+if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+itTypes[1] != vector::IteratorType

[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-23 Thread Andrzej Warzyński via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the 
extension.
+template 
+inline std::enable_if_t<(std::is_base_of_v ||
+ std::is_base_of_v),
+std::optional>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null(v.getDefiningOp());
+  if (!extOp)
+return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+return {};
+
+  auto outTy = dyn_cast(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,  // smmla
+  Unsigned,// ummla
+  Mixed,   // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+// The accumulator comes transposed and the result will be transposed
+// later, so all we have to do here is swap the operands.
+return rewriter.create(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern
+: public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+PatternRewriter &rewriter) const override {
+
+Location loc = op.getLoc();
+mlir::VectorType lhsType = op.getLhsType();
+mlir::VectorType rhsType = op.getRhsType();
+
+// For now handle LHS and RHS<8x[N]> - these are the types we
+// eventually expect from MMT4D. M and N dimensions must be even and at
+// least 2.
+if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+rhsType.getRank() != 2)
+  return failure();

banach-space wrote:

Could you use `notifyMatchFailure` with some descriptive error message instead? 
Thanks! Same comment applies to the other instances of `failure`.

https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-23 Thread Andrzej Warzyński via llvm-branch-commits

https://github.com/banach-space edited 
https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-23 Thread Andrzej Warzyński via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---

banach-space wrote:

```suggestion
//===----------------------------------------------------------------------===//
```

https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-23 Thread Andrzej Warzyński via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the 
extension.
+template 
+inline std::enable_if_t<(std::is_base_of_v ||
+ std::is_base_of_v),
+std::optional>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null(v.getDefiningOp());
+  if (!extOp)
+return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+return {};
+
+  auto outTy = dyn_cast(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,  // smmla
+  Unsigned,// ummla
+  Mixed,   // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+// The accumulator comes transposed and the result will be transposed
+// later, so all we have to do here is swap the operands.
+return rewriter.create(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern
+: public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+PatternRewriter &rewriter) const override {
+
+Location loc = op.getLoc();
+mlir::VectorType lhsType = op.getLhsType();
+mlir::VectorType rhsType = op.getRhsType();
+
+// For now handle LHS and RHS<8x[N]> - these are the types we
+// eventually expect from MMT4D. M and N dimensions must be even and at
+// least 2.
+if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+rhsType.getRank() != 2)
+  return failure();
+
+if (lhsType.isScalable() || !rhsType.isScalable())
+  return failure();
+
+// M, N, and K are the conventional names for matrix dimensions in the
+// context of matrix multiplication.
+auto M = lhsType.getDimSize(0);
+auto N = rhsType.getDimSize(0);
+auto K = rhsType.getDimSize(1);
+
+if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+N % 2 != 0 || !rhsType.getScalableDims()[0])
+  return failure();
+
+// Check permutation maps. For now only accept
+//   lhs: (d0, d1, d2) -> (d0, d2)
+//   rhs: (d0, d1, d2) -> (d1, d2)
+//   acc: (d0, d1, d2) -> (d0, d1)
+// Note: RHS is transposed.
+if (op.getIndexingMapsArray()[0] !=
+AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+ op.getContext()) ||
+op.getIndexingMapsArray()[1] !=
+AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+ op.getContext()) ||
+op.getIndexingMapsArray()[2] !=
+AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+ op.getContext()))
+  return failure();
+
+// Check iterator types for matrix multiplication.
+auto itTypes = op.getIteratorTypesArray();
+if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+itTypes[1] != vector::IteratorType

[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-23 Thread Andrzej Warzyński via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the 
extension.
+template 
+inline std::enable_if_t<(std::is_base_of_v ||
+ std::is_base_of_v),
+std::optional>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null(v.getDefiningOp());
+  if (!extOp)
+return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+return {};
+
+  auto outTy = dyn_cast(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,  // smmla
+  Unsigned,// ummla
+  Mixed,   // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+// The accumulator comes transposed and the result will be transposed
+// later, so all we have to do here is swap the operands.
+return rewriter.create(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern
+: public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+PatternRewriter &rewriter) const override {
+
+Location loc = op.getLoc();
+mlir::VectorType lhsType = op.getLhsType();
+mlir::VectorType rhsType = op.getRhsType();
+
+// For now handle LHS and RHS<8x[N]> - these are the types we
+// eventually expect from MMT4D. M and N dimensions must be even and at

banach-space wrote:

[nit] We shouldn't be concerned with MMT4D in this dialect - it's a much 
higher-level abstraction and this logic should be valid irrespective of how the 
input is generated.

https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-23 Thread Andrzej Warzyński via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the 
extension.
+template 
+inline std::enable_if_t<(std::is_base_of_v ||
+ std::is_base_of_v),
+std::optional>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null(v.getDefiningOp());
+  if (!extOp)
+return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+return {};
+
+  auto outTy = dyn_cast(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,  // smmla
+  Unsigned,// ummla
+  Mixed,   // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+return rewriter.create(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+// The accumulator comes transposed and the result will be transposed
+// later, so all we have to do here is swap the operands.
+return rewriter.create(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern

banach-space wrote:

It's a very long pattern. Could you document the high-level logic? 

https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-23 Thread Andrzej Warzyński via llvm-branch-commits


@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.

banach-space wrote:

Could you add a note that `vector.contract` needs to be accompanied by 
`arith.extsi` (or `arith.extui`) ops? Also, is I8MM the official name? 
Shouldn't that be FEAT_I8MM? 

Basically, could we document a bit more?

https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-15 Thread Momchil Velikov via llvm-branch-commits

momchil-velikov wrote:

> One high-level question - would sharing some code between NEON and SVE be 
> possible?

No, I can't see that happening in a way that results in less, simpler, or 
easier-to-maintain code.
However, it might be possible to add the Neon lowering to this patch and see if 
the result is any good.


https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-15 Thread Momchil Velikov via llvm-branch-commits

https://github.com/momchil-velikov updated 
https://github.com/llvm/llvm-project/pull/135636

>From 8e87a7f3b1438d9542d28c90eb9593ebe8cf6500 Mon Sep 17 00:00:00 2001
From: Momchil Velikov 
Date: Tue, 8 Apr 2025 14:43:54 +
Subject: [PATCH] [MLIR][ArmSVE] Add initial lowering of `vector.contract` to
 SVE `*MMLA` instructions

---
 mlir/include/mlir/Conversion/Passes.td|   4 +
 .../Dialect/ArmSVE/Transforms/Transforms.h|   3 +
 .../Conversion/VectorToLLVM/CMakeLists.txt|   1 +
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   7 +
 .../LowerContractionToSMMLAPattern.cpp|   5 +-
 .../Dialect/ArmSVE/Transforms/CMakeLists.txt  |   1 +
 .../LowerContractionToSVEI8MMPattern.cpp  | 304 ++
 .../Vector/CPU/ArmSVE/vector-smmla.mlir   |  94 ++
 .../Vector/CPU/ArmSVE/vector-summla.mlir  |  85 +
 .../Vector/CPU/ArmSVE/vector-ummla.mlir   |  94 ++
 .../Vector/CPU/ArmSVE/vector-usmmla.mlir  |  95 ++
 .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir   | 117 +++
 .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir   | 159 +
 .../CPU/ArmSVE/contraction-summla-4x8x4.mlir  | 118 +++
 .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir   | 119 +++
 .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir  | 117 +++
 16 files changed, 1322 insertions(+), 1 deletion(-)
 create mode 100644 
mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td 
b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..930d8b44abca0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : 
Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
+Option<"armI8MM", "enable-arm-i8mm",
+   "bool", /*default=*/"false",
+   "Enables the use of Arm FEAT_I8MM instructions while lowering "
+   "the vector dialect.">,
 Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h 
b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
 const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEI8MMPatternPatterns(
+RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt 
b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp 
b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..1e6c8122b1d0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
 populateVectorStepLoweringPatterns(pa

[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-15 Thread Momchil Velikov via llvm-branch-commits

https://github.com/momchil-velikov updated 
https://github.com/llvm/llvm-project/pull/135636

>From 8e87a7f3b1438d9542d28c90eb9593ebe8cf6500 Mon Sep 17 00:00:00 2001
From: Momchil Velikov 
Date: Tue, 8 Apr 2025 14:43:54 +
Subject: [PATCH] [MLIR][ArmSVE] Add initial lowering of `vector.contract` to
 SVE `*MMLA` instructions

---
 mlir/include/mlir/Conversion/Passes.td|   4 +
 .../Dialect/ArmSVE/Transforms/Transforms.h|   3 +
 .../Conversion/VectorToLLVM/CMakeLists.txt|   1 +
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   7 +
 .../LowerContractionToSMMLAPattern.cpp|   5 +-
 .../Dialect/ArmSVE/Transforms/CMakeLists.txt  |   1 +
 .../LowerContractionToSVEI8MMPattern.cpp  | 304 ++
 .../Vector/CPU/ArmSVE/vector-smmla.mlir   |  94 ++
 .../Vector/CPU/ArmSVE/vector-summla.mlir  |  85 +
 .../Vector/CPU/ArmSVE/vector-ummla.mlir   |  94 ++
 .../Vector/CPU/ArmSVE/vector-usmmla.mlir  |  95 ++
 .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir   | 117 +++
 .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir   | 159 +
 .../CPU/ArmSVE/contraction-summla-4x8x4.mlir  | 118 +++
 .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir   | 119 +++
 .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir  | 117 +++
 16 files changed, 1322 insertions(+), 1 deletion(-)
 create mode 100644 
mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td 
b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..930d8b44abca0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : 
Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
+Option<"armI8MM", "enable-arm-i8mm",
+   "bool", /*default=*/"false",
+   "Enables the use of Arm FEAT_I8MM instructions while lowering "
+   "the vector dialect.">,
 Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h 
b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
 const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEI8MMPatternPatterns(
+RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt 
b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp 
b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..1e6c8122b1d0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
 populateVectorStepLoweringPatterns(pa

[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-14 Thread Andrzej Warzyński via llvm-branch-commits

https://github.com/banach-space commented:

Thanks! This one is a bit longer, so I may need to wait till Thursday before I 
can review.

One high-level question - would sharing some code between NEON and SVE be 
possible?

https://github.com/llvm/llvm-project/pull/135636
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-14 Thread Momchil Velikov via llvm-branch-commits

https://github.com/momchil-velikov created 
https://github.com/llvm/llvm-project/pull/135636

Supersedes https://github.com/llvm/llvm-project/pull/135359

>From 2e61d3ee7b9ac88ae1be8ca248dad1a0880ccff4 Mon Sep 17 00:00:00 2001
From: Momchil Velikov 
Date: Tue, 8 Apr 2025 14:43:54 +
Subject: [PATCH] [MLIR][ArmSVE] Add initial lowering of `vector.contract` to
 SVE `*MMLA` instructions

---
 mlir/include/mlir/Conversion/Passes.td|   4 +
 .../Dialect/ArmSVE/Transforms/Transforms.h|   3 +
 .../Conversion/VectorToLLVM/CMakeLists.txt|   1 +
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   7 +
 .../LowerContractionToSMMLAPattern.cpp|   5 +-
 .../Dialect/ArmSVE/Transforms/CMakeLists.txt  |   1 +
 .../LowerContractionToSVEI8MMPattern.cpp  | 304 ++
 .../Vector/CPU/ArmSVE/vector-smmla.mlir   |  94 ++
 .../Vector/CPU/ArmSVE/vector-summla.mlir  |  85 +
 .../Vector/CPU/ArmSVE/vector-ummla.mlir   |  94 ++
 .../Vector/CPU/ArmSVE/vector-usmmla.mlir  |  95 ++
 .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir   | 117 +++
 .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir   | 159 +
 .../CPU/ArmSVE/contraction-summla-4x8x4.mlir  | 118 +++
 .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir   | 119 +++
 .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir  | 117 +++
 16 files changed, 1322 insertions(+), 1 deletion(-)
 create mode 100644 
mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
 create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td 
b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..930d8b44abca0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : 
Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
+Option<"armI8MM", "enable-arm-i8mm",
+   "bool", /*default=*/"false",
+   "Enables the use of Arm FEAT_I8MM instructions while lowering "
+   "the vector dialect.">,
 Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h 
b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
 const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEI8MMPatternPatterns(
+RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt 
b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp 
b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..1e6c8122b1d0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::

[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-14 Thread via llvm-branch-commits

llvmbot wrote:




@llvm/pr-subscribers-mlir

Author: Momchil Velikov (momchil-velikov)


Changes

Supersedes https://github.com/llvm/llvm-project/pull/135359

---

Patch is 77.36 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/135636.diff


16 Files Affected:

- (modified) mlir/include/mlir/Conversion/Passes.td (+4) 
- (modified) mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h (+3) 
- (modified) mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt (+1) 
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+7) 
- (modified) 
mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-1) 
- (modified) mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt (+1) 
- (added) 
mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp (+304) 
- (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir (+94) 
- (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir (+85) 
- (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir (+94) 
- (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir (+95) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir 
(+117) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
 (+159) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir 
(+118) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir 
(+119) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir 
(+117) 


```diff
diff --git a/mlir/include/mlir/Conversion/Passes.td 
b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..930d8b44abca0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : 
Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
+Option<"armI8MM", "enable-arm-i8mm",
+   "bool", /*default=*/"false",
+   "Enables the use of Arm FEAT_I8MM instructions while lowering "
+   "the vector dialect.">,
 Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h 
b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
 const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEI8MMPatternPatterns(
+RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt 
b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp 
b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..1e6c8122b1d0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
 populateVectorStepLoweringPatterns(patterns);
 populateVectorRankReducingFMAPattern(patterns);
 populateVectorGatherLoweringPatterns(patterns);
+if (armI8MM) {
+  if (armNeon)
+arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+  if (armSVE)
+populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+}
 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git 
a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp 
b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 2a1271dfd6bdf..e807b

[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-14 Thread via llvm-branch-commits

llvmbot wrote:




@llvm/pr-subscribers-mlir-neon

Author: Momchil Velikov (momchil-velikov)


Changes

Supersedes https://github.com/llvm/llvm-project/pull/135359

---

Patch is 77.36 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/135636.diff


16 Files Affected:

- (modified) mlir/include/mlir/Conversion/Passes.td (+4) 
- (modified) mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h (+3) 
- (modified) mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt (+1) 
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+7) 
- (modified) 
mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-1) 
- (modified) mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt (+1) 
- (added) 
mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp (+304) 
- (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir (+94) 
- (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir (+85) 
- (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir (+94) 
- (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir (+95) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir 
(+117) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
 (+159) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir 
(+118) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir 
(+119) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir 
(+117) 


```diff
diff --git a/mlir/include/mlir/Conversion/Passes.td 
b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..930d8b44abca0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : 
Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
+Option<"armI8MM", "enable-arm-i8mm",
+   "bool", /*default=*/"false",
+   "Enables the use of Arm FEAT_I8MM instructions while lowering "
+   "the vector dialect.">,
 Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h 
b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
 const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEI8MMPatternPatterns(
+RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt 
b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp 
b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..1e6c8122b1d0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
 populateVectorStepLoweringPatterns(patterns);
 populateVectorRankReducingFMAPattern(patterns);
 populateVectorGatherLoweringPatterns(patterns);
+if (armI8MM) {
+  if (armNeon)
+arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+  if (armSVE)
+populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+}
 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git 
a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp 
b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 2a1271dfd6bdf..

[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

2025-04-14 Thread via llvm-branch-commits

llvmbot wrote:




@llvm/pr-subscribers-mlir-sve

Author: Momchil Velikov (momchil-velikov)


Changes

Supersedes https://github.com/llvm/llvm-project/pull/135359

---

Patch is 77.36 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/135636.diff


16 Files Affected:

- (modified) mlir/include/mlir/Conversion/Passes.td (+4) 
- (modified) mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h (+3) 
- (modified) mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt (+1) 
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+7) 
- (modified) 
mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-1) 
- (modified) mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt (+1) 
- (added) 
mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp (+304) 
- (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir (+94) 
- (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir (+85) 
- (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir (+94) 
- (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir (+95) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir 
(+117) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
 (+159) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir 
(+118) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir 
(+119) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir 
(+117) 


```diff
diff --git a/mlir/include/mlir/Conversion/Passes.td 
b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..930d8b44abca0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : 
Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
+Option<"armI8MM", "enable-arm-i8mm",
+   "bool", /*default=*/"false",
+   "Enables the use of Arm FEAT_I8MM instructions while lowering "
+   "the vector dialect.">,
 Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h 
b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
 const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEI8MMPatternPatterns(
+RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt 
b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp 
b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..1e6c8122b1d0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
 populateVectorStepLoweringPatterns(patterns);
 populateVectorRankReducingFMAPattern(patterns);
 populateVectorGatherLoweringPatterns(patterns);
+if (armI8MM) {
+  if (armNeon)
+arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+  if (armSVE)
+populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+}
 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git 
a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp 
b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 2a1271dfd6bdf..e