https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/140572
>From df54d59d29e8afc04740e86281bce6be5dd157da Mon Sep 17 00:00:00 2001 From: Momchil Velikov <momchil.veli...@arm.com> Date: Fri, 16 May 2025 15:47:36 +0000 Subject: [PATCH 1/2] [MLIR] Add apply_patterns.vector.arm_sve.lower_contraction TD Op --- .../mlir/Dialect/ArmSVE/CMakeLists.txt | 1 + .../TransformOps/ArmSVEVectorTransformOps.h | 31 +++ .../TransformOps/ArmSVEVectorTransformOps.td | 26 ++ .../ArmSVE/TransformOps/CMakeLists.txt | 6 + mlir/include/mlir/InitAllExtensions.h | 2 + mlir/lib/Dialect/ArmSVE/CMakeLists.txt | 1 + .../TransformOps/ArmSVEVectorTransformOps.cpp | 54 ++++ .../ArmSVE/TransformOps/CMakeLists.txt | 19 ++ .../Vector/CPU/ArmSVE/vector-smmla.mlir | 263 ++++++++++-------- .../Vector/CPU/ArmSVE/vector-summla.mlir | 123 ++++---- .../Vector/CPU/ArmSVE/vector-ummla.mlir | 138 +++++---- .../Vector/CPU/ArmSVE/vector-usmmla.mlir | 138 +++++---- 12 files changed, 512 insertions(+), 290 deletions(-) create mode 100644 mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h create mode 100644 mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td create mode 100644 mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt create mode 100644 mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp create mode 100644 mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt index 9f57627c321fb..cb1e9d01821a2 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h new file mode 100644 index 0000000000000..7f22cd1fe6435 --- /dev/null +++ 
b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h @@ -0,0 +1,31 @@ +//===- ArmSVEVectorTransformOps.h - Vector transform ops --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H +#define MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H + +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" + +//===----------------------------------------------------------------------===// +// ArmSVE Vector Transform Operations +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h.inc" + +namespace mlir { +class DialectRegistry; + +namespace arm_sve { +void registerTransformDialectExtension(DialectRegistry &registry); + +} // namespace arm_sve +} // namespace mlir + +#endif // MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td new file mode 100644 index 0000000000000..81b59340f3b0d --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td @@ -0,0 +1,26 @@ +//===- ArmSVEVectorTransformOps.td - Arm SVE transform ops--*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef ARMSVE_VECTOR_TRANSFORM_OPS +#define ARMSVE_VECTOR_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformAttrs.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" + +def ApplyArmSVELowerContractionPatternsOp + : Op<Transform_Dialect, "apply_patterns.vector.arm_sve.lower_contraction", + [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> { + let description = [{ + Indicates that vector contraction-like operations should be lowered to + finer-grained vector primitives using the ArmSVE dialect. + }]; + + let assemblyFormat = "attr-dict"; +} + +#endif // ARMSVE_VECTOR_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..ce8d8fea7f188 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS ArmSVEVectorTransformOps.td) +mlir_tablegen(ArmSVEVectorTransformOps.h.inc -gen-op-decls) +mlir_tablegen(ArmSVEVectorTransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRArmSVEVectorTransformOpsIncGen) + +add_mlir_doc(ArmSVEVectorTransformOps ArmSVEVectorTransformOps Dialects/ -gen-op-doc) diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index 1e3f7c649a8bd..419dffcf9113b 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -33,6 +33,7 @@ #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h" #include
"mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" @@ -104,6 +105,7 @@ inline void registerAllExtensions(DialectRegistry &registry) { transform::registerLoopExtension(registry); transform::registerPDLExtension(registry); vector::registerTransformDialectExtension(registry); + arm_sve::registerTransformDialectExtension(registry); // Translation extensions need to be registered by calling // `registerAllToLLVMIRTranslations` (see All.h). diff --git a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt index 9f57627c321fb..cb1e9d01821a2 100644 --- a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp new file mode 100644 index 0000000000000..b2ca4fc1eaa8c --- /dev/null +++ b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp @@ -0,0 +1,54 @@ +//===- ArmSVEVectorTransformOps.cpp - Implementation transform ops -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h" + +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Apply...PatternsOp +//===----------------------------------------------------------------------===// + +void transform::ApplyArmSVELowerContractionPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns); +} + +//===----------------------------------------------------------------------===// +// Transform op registration +//===----------------------------------------------------------------------===// + +namespace { +class ArmSVEVectorTransformDialectExtension + : public transform::TransformDialectExtension< + ArmSVEVectorTransformDialectExtension> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + ArmSVEVectorTransformDialectExtension) + + ArmSVEVectorTransformDialectExtension() { + declareGeneratedDialect<arm_sve::ArmSVEDialect>(); + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp.inc" + >(); + } +}; +} // namespace + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp.inc" + +void mlir::arm_sve::registerTransformDialectExtension( + DialectRegistry &registry) { + registry.addExtensions<ArmSVEVectorTransformDialectExtension>(); +} diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..8771826e08913 --- /dev/null +++ b/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt @@ -0,0 +1,19
@@ +add_mlir_dialect_library(MLIRArmSVEVectorTransformOps + ArmSVEVectorTransformOps.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE/TransformOps + + DEPENDS + MLIRArmSVEVectorTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRVectorDialect + MLIRTransformDialect + MLIRArmSVEDialect + MLIRArmSVETransforms + ) + \ No newline at end of file diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir index af0cb37e2d249..3991038761e8d 100644 --- a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir +++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' | FileCheck %s +// RUN: mlir-opt %s --transform-interpreter | FileCheck %s #attrs = { indexing_maps = [ @@ -12,77 +12,82 @@ // CHECK-LABEL: @test_vector_contract_to_smmla +// CHECK-SAME: %[[LHS:arg0]]: vector<4x8xi8> +// CHECK-SAME: %[[RHS:arg1]]: vector<[4]x8xi8> +// CHECK-SAME: %[[ACC:arg2]]: vector<4x[4]xi32> + +// CHECK: %[[P0:[0-9]+]] = ub.poison : vector<[8]xi32> +// CHECK-NEXT: %[[P1:[0-9]+]] = ub.poison : vector<4x[4]xi32> +// CHECK-NEXT: %[[P2:[0-9]+]] = ub.poison : vector<[16]xi8> + // Extract LHS rows 0 and 1, concatenate, turn into scalable vector -// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> -// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8> +// CHECK: %[[T3:[0-9]+]] = vector.extract %[[LHS]][0] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T4:[0-9]+]] = vector.extract %[[LHS]][1] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T5:[0-9]+]] = vector.shuffle %[[T3]], %[[T4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8> +// CHECK-NEXT: %[[T6:[0-9]+]] = vector.scalable.insert %[[T5]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8> -// Replicate across the entire length of the scalabale vector -// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> +// Replicate across the entire length of the scalable vector +// CHECK-NEXT: %[[LHS_0:[0-9]+]] = arm_sve.dupq_lane %[[T6]][0] : vector<[16]xi8> -// Same for LHS rows 2 and 4 -// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> -// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8> -// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> +// Same for LHS rows 2 and 3 +// CHECK-NEXT: %[[T8:[0-9]+]] = vector.extract %[[LHS]][2] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T9:[0-9]+]] = vector.extract %[[LHS]][3] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T10:[0-9]+]] = vector.shuffle %[[T8]], %[[T9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8> +// CHECK-NEXT: %[[T11:[0-9]+]] = vector.scalable.insert %[[T10]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8> +// CHECK-NEXT: %[[LHS_1:[0-9]+]] = arm_sve.dupq_lane %[[T11]][0] : vector<[16]xi8> // Extract sub-tiles from the RHS -// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8> -// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from 
vector<[32]xi8> -// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[T13:[0-9]+]] = vector.shape_cast %[[RHS]] : vector<[4]x8xi8> to vector<[32]xi8> +// CHECK-NEXT: %[[RHS_0:[0-9]+]] = vector.scalable.extract %[[T13]][0] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[RHS_1:[0-9]+]] = vector.scalable.extract %[[T13]][16] : vector<[16]xi8> from vector<[32]xi8> // Extract accumulator rows 0 and 1 and pack (into "registers") -// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> -// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32> -// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32> -// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32> - -// Same for accumulator rows 2 and 3. 
-// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> -// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32> -// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32> -// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T16:[0-9]+]] = vector.extract %[[ACC]][0] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T17:[0-9]+]] = vector.extract %[[ACC]][1] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T18:[0-9]+]] = vector.bitcast %[[T16]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T19:[0-9]+]] = vector.bitcast %[[T17]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T20:[0-9]+]] = vector.interleave %[[T18]], %[[T19]] : vector<[2]xi64> -> vector<[4]xi64> +// CHECK-NEXT: %[[T21:[0-9]+]] = vector.bitcast %[[T20]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[ACC_0:[0-9]+]] = vector.scalable.extract %[[T21]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[ACC_1:[0-9]+]] = vector.scalable.extract %[[T21]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Same for accumulator rows 2 and 3 +// CHECK-NEXT: %[[T24:[0-9]+]] = vector.extract %[[ACC]][2] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T25:[0-9]+]] = vector.extract %[[ACC]][3] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T26:[0-9]+]] = vector.bitcast %[[T24]] : vector<[4]xi32> to vector<[2]xi64> +// 
CHECK-NEXT: %[[T27:[0-9]+]] = vector.bitcast %[[T25]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T28:[0-9]+]] = vector.interleave %[[T26]], %[[T27]] : vector<[2]xi64> -> vector<[4]xi64> +// CHECK-NEXT: %[[T29:[0-9]+]] = vector.bitcast %[[T28]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[ACC_2:[0-9]+]] = vector.scalable.extract %[[T29]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[ACC_3:[0-9]+]] = vector.scalable.extract %[[T29]][4] : vector<[4]xi32> from vector<[8]xi32> // Do the sub-tile matrix multiplications -// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.smmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> -// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.smmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> -// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.smmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> -// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.smmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> - -// Unpack (from "registers") and insert in the output result rows 0 and 1 -// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64> -// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: 
%[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[PACK_RES_00:[0-9]+]] = arm_sve.smmla %[[ACC_0]], %[[LHS_0]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32> +// CHECK-NEXT: %[[PACK_RES_01:[0-9]+]] = arm_sve.smmla %[[ACC_1]], %[[LHS_0]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32> +// CHECK-NEXT: %[[PACK_RES_10:[0-9]+]] = arm_sve.smmla %[[ACC_2]], %[[LHS_1]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32> +// CHECK-NEXT: %[[PACK_RES_11:[0-9]+]] = arm_sve.smmla %[[ACC_3]], %[[LHS_1]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32> + +// Unpack (from "registers") and insert in the output result rows 0 and 1 +// CHECK-NEXT: %[[T36:[0-9]+]] = vector.scalable.insert %[[PACK_RES_00]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T37:[0-9]+]] = vector.scalable.insert %[[PACK_RES_01]], %[[T36]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T38:[0-9]+]] = vector.bitcast %[[T37]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %res1, %res2 = vector.deinterleave %[[T38]] : vector<[4]xi64> -> vector<[2]xi64> +// CHECK-NEXT: %[[UNPACK_RES_0:[0-9]+]] = vector.bitcast %res1 : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[UNPACK_RES_1:[0-9]+]] = vector.bitcast %res2 : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[TMP_OUT_0:[0-9]+]] = vector.insert %[[UNPACK_RES_0]], %[[P1]] [0] : vector<[4]xi32> into vector<4x[4]xi32> +// CHECK-NEXT: %[[TMP_OUT_1:[0-9]+]] = vector.insert %[[UNPACK_RES_1]], %[[TMP_OUT_0]] [1] : vector<[4]xi32> into vector<4x[4]xi32> // Same for result rows 2 and 3 -// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : 
vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64> -// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>> - +// CHECK-NEXT: %[[T43:[0-9]+]] = vector.scalable.insert %[[PACK_RES_10]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T44:[0-9]+]] = vector.scalable.insert %[[PACK_RES_11]], %[[T43]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T45:[0-9]+]] = vector.bitcast %[[T44]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %res1_0, %res2_1 = vector.deinterleave %[[T45]] : vector<[4]xi64> -> vector<[2]xi64> +// CHECK-NEXT: %[[UNPACK_RES_2:[0-9]+]] = vector.bitcast %res1_0 : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[UNPACK_RES_3:[0-9]+]] = vector.bitcast %res2_1 : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[TMP_OUT_2:[0-9]+]] = vector.insert %[[UNPACK_RES_2]], %[[TMP_OUT_1]] [2] : vector<[4]xi32> into vector<4x[4]xi32> +// CHECK-NEXT: %[[OUT:[0-9]+]] = vector.insert %[[UNPACK_RES_3]], %[[TMP_OUT_2]] [3] : vector<[4]xi32> into vector<4x[4]xi32> + +// CHECK-NEXT: return %[[OUT]] : vector<4x[4]xi32> 
func.func @test_vector_contract_to_smmla(%lhs: vector<4x8xi8>, %rhs: vector<[4]x8xi8>, %acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> { @@ -97,76 +102,82 @@ func.func @test_vector_contract_to_smmla(%lhs: vector<4x8xi8>, // CHECK-LABEL: @test_vector_contract_to_smmla_implicit_sext +// CHECK-SAME: %[[LHS:arg0]]: vector<4x8xi8> +// CHECK-SAME: %[[RHS:arg1]]: vector<[4]x8xi8> +// CHECK-SAME: %[[ACC:arg2]]: vector<4x[4]xi32> + +// CHECK: %[[P0:[0-9]+]] = ub.poison : vector<[8]xi32> +// CHECK-NEXT: %[[P1:[0-9]+]] = ub.poison : vector<4x[4]xi32> +// CHECK-NEXT: %[[P2:[0-9]+]] = ub.poison : vector<[16]xi8> + // Extract LHS rows 0 and 1, concatenate, turn into scalable vector -// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> -// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8> +// CHECK: %[[T3:[0-9]+]] = vector.extract %[[LHS]][0] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T4:[0-9]+]] = vector.extract %[[LHS]][1] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T5:[0-9]+]] = vector.shuffle %[[T3]], %[[T4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8> +// CHECK-NEXT: %[[T6:[0-9]+]] = vector.scalable.insert %[[T5]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8> -// Replicate across the entire length of the scalabale vector -// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> +// Replicate across the entire length of the scalable vector +// CHECK-NEXT: %[[LHS_0:[0-9]+]] = arm_sve.dupq_lane %[[T6]][0] : vector<[16]xi8> -// Same for LHS rows 2 and 4 -// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] :
!llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> -// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8> -// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> +// Same for LHS rows 2 and 3 +// CHECK-NEXT: %[[T8:[0-9]+]] = vector.extract %[[LHS]][2] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T9:[0-9]+]] = vector.extract %[[LHS]][3] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T10:[0-9]+]] = vector.shuffle %[[T8]], %[[T9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8> +// CHECK-NEXT: %[[T11:[0-9]+]] = vector.scalable.insert %[[T10]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8> +// CHECK-NEXT: %[[LHS_1:[0-9]+]] = arm_sve.dupq_lane %[[T11]][0] : vector<[16]xi8> // Extract sub-tiles from the RHS -// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8> -// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8> -// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[T13:[0-9]+]] = vector.shape_cast %[[RHS]] : vector<[4]x8xi8> to vector<[32]xi8> +// CHECK-NEXT: %[[RHS_0:[0-9]+]] = vector.scalable.extract %[[T13]][0] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[RHS_1:[0-9]+]] = vector.scalable.extract %[[T13]][16] : vector<[16]xi8> from vector<[32]xi8> // Extract accumulator rows 0 and 1 and pack (into "registers") -// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 
x vector<[4]xi32>> -// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> -// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32> -// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32> -// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32> - -// Same for accumulator rows 2 and 3. -// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> -// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32> -// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32> -// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T16:[0-9]+]] = vector.extract %[[ACC]][0] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T17:[0-9]+]] = vector.extract %[[ACC]][1] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T18:[0-9]+]] = vector.bitcast %[[T16]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T19:[0-9]+]] = vector.bitcast %[[T17]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T20:[0-9]+]] = vector.interleave %[[T18]], 
%[[T19]] : vector<[2]xi64> -> vector<[4]xi64> +// CHECK-NEXT: %[[T21:[0-9]+]] = vector.bitcast %[[T20]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[ACC_0:[0-9]+]] = vector.scalable.extract %[[T21]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[ACC_1:[0-9]+]] = vector.scalable.extract %[[T21]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Same for accumulator rows 2 and 3 +// CHECK-NEXT: %[[T24:[0-9]+]] = vector.extract %[[ACC]][2] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T25:[0-9]+]] = vector.extract %[[ACC]][3] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T26:[0-9]+]] = vector.bitcast %[[T24]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T27:[0-9]+]] = vector.bitcast %[[T25]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T28:[0-9]+]] = vector.interleave %[[T26]], %[[T27]] : vector<[2]xi64> -> vector<[4]xi64> +// CHECK-NEXT: %[[T29:[0-9]+]] = vector.bitcast %[[T28]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[ACC_2:[0-9]+]] = vector.scalable.extract %[[T29]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[ACC_3:[0-9]+]] = vector.scalable.extract %[[T29]][4] : vector<[4]xi32> from vector<[8]xi32> // Do the sub-tile matrix multiplications -// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.smmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> -// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.smmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> -// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.smmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> -// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.smmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> - -// Unpack (from "registers") and insert in the output result rows 0 and 1 -// CHECK-NEXT: 
%[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64> -// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[PACK_RES_00:[0-9]+]] = arm_sve.smmla %[[ACC_0]], %[[LHS_0]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32> +// CHECK-NEXT: %[[PACK_RES_01:[0-9]+]] = arm_sve.smmla %[[ACC_1]], %[[LHS_0]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32> +// CHECK-NEXT: %[[PACK_RES_10:[0-9]+]] = arm_sve.smmla %[[ACC_2]], %[[LHS_1]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32> +// CHECK-NEXT: %[[PACK_RES_11:[0-9]+]] = arm_sve.smmla %[[ACC_3]], %[[LHS_1]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32> + +// Unpack (from "registers") and insert in the output result rows 0 and 1 +// CHECK-NEXT: %[[T36:[0-9]+]] = vector.scalable.insert %[[PACK_RES_00]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T37:[0-9]+]] = vector.scalable.insert %[[PACK_RES_01]], %[[T36]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T38:[0-9]+]] = vector.bitcast 
%[[T37]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %res1, %res2 = vector.deinterleave %[[T38]] : vector<[4]xi64> -> vector<[2]xi64> +// CHECK-NEXT: %[[UNPACK_RES_0:[0-9]+]] = vector.bitcast %res1 : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[UNPACK_RES_1:[0-9]+]] = vector.bitcast %res2 : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[TMP_OUT_0:[0-9]+]] = vector.insert %[[UNPACK_RES_0]], %[[P1]] [0] : vector<[4]xi32> into vector<4x[4]xi32> +// CHECK-NEXT: %[[TMP_OUT_1:[0-9]+]] = vector.insert %[[UNPACK_RES_1]], %[[TMP_OUT_0]] [1] : vector<[4]xi32> into vector<4x[4]xi32> // Same for result rows 2 and 3 -// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64> -// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T43:[0-9]+]] = vector.scalable.insert %[[PACK_RES_10]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T44:[0-9]+]] = vector.scalable.insert %[[PACK_RES_11]], %[[T43]][4] : vector<[4]xi32> into vector<[8]xi32> 
+// CHECK-NEXT: %[[T45:[0-9]+]] = vector.bitcast %[[T44]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %res1_0, %res2_1 = vector.deinterleave %[[T45]] : vector<[4]xi64> -> vector<[2]xi64> +// CHECK-NEXT: %[[UNPACK_RES_2:[0-9]+]] = vector.bitcast %res1_0 : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[UNPACK_RES_3:[0-9]+]] = vector.bitcast %res2_1 : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[TMP_OUT_2:[0-9]+]] = vector.insert %[[UNPACK_RES_2]], %[[TMP_OUT_1]] [2] : vector<[4]xi32> into vector<4x[4]xi32> +// CHECK-NEXT: %[[OUT:[0-9]+]] = vector.insert %[[UNPACK_RES_3]], %[[TMP_OUT_2]] [3] : vector<[4]xi32> into vector<4x[4]xi32> + +// CHECK-NEXT: return %[[OUT]] : vector<4x[4]xi32> // Test a variant where the sign-extension of the operands is // implicit. The output is identical to the one of the previous test. @@ -179,3 +190,15 @@ func.func @test_vector_contract_to_smmla_implicit_sext(%lhs: vector<4x8xi8>, return %0 : vector<4x[4]xi32> } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func"> + + transform.apply_patterns to %func { + transform.apply_patterns.vector.arm_sve.lower_contraction + } : !transform.op<"func.func"> + + transform.yield + } +} diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir index b6285d068b0f8..ffcb437f189f9 100644 --- a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir +++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s +// RUN: mlir-opt %s -transform-interpreter | FileCheck %s #packed_maps = [ affine_map<(d0, d1, d2) -> (d0, d2)>, @@ -8,66 +8,77 @@ // CHECK-LABEL: 
@test_vector_contract_to_usmmla_rev -// Extract LHS rows 0 and 1, concatenate, turn into scalable vector -// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T1:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T1]][1] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> -// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T5:[0-9]+]][0] : vector<16xi8> into vector<[16]xi8> +// CHECK-SAME: %[[LHS:arg0]]: vector<4x8xi8> +// CHECK-SAME: %[[RHS:arg1]]: vector<[4]x8xi8> +// CHECK-SAME: %[[ACC:arg2]]: vector<4x[4]xi32> + +// CHECK: [[P0:[0-9]+]] = ub.poison : vector<[8]xi32> +// CHECK-NEXT: [[P1:[0-9]+]] = ub.poison : vector<4x[4]xi32> +// CHECK-NEXT: [[P2:[0-9]+]] = ub.poison : vector<[16]xi8> -// Replicate across the entire length of the scalabale vector -// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> +// Extract LHS rows 0 and 1, concatenate, turn into scalable vector +// CHECK: %[[T3:[0-9]+]] = vector.extract %[[LHS]][0] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T4:[0-9]+]] = vector.extract %[[LHS]][1] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T5:[0-9]+]] = vector.shuffle %[[T3]], %[[T4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8> +// CHECK-NEXT: %[[T6:[0-9]+]] = vector.scalable.insert %[[T5]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8> -// Same for LHS rows 2 and 4 -// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T1]][2] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T1]][3] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> -// CHECK-NEXT: %[[T14:[0-9]+]] = 
llvm.intr.vector.insert %[[T13]], %[[T5]][0] : vector<16xi8> into vector<[16]xi8> -// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> +// Replicate across the entire length of the scalable vector +// CHECK-NEXT: %[[LHS_0:[0-9]+]] = arm_sve.dupq_lane %[[T6]][0] : vector<[16]xi8> +// Same for LHS rows 2 and 3 +// CHECK-NEXT: %[[T8:[0-9]+]] = vector.extract %[[LHS]][2] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T9:[0-9]+]] = vector.extract %[[LHS]][3] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T10:[0-9]+]] = vector.shuffle %[[T8]], %[[T9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8> +// CHECK-NEXT: %[[T11:[0-9]+]] = vector.scalable.insert %[[T10]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8> +// CHECK-NEXT: %[[LHS_1:[0-9]+]] = arm_sve.dupq_lane %[[T11]][0] : vector<[16]xi8> // Extract sub-tiles from the RHS -// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8> -// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8> -// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[T13:[0-9]+]] = vector.shape_cast %[[RHS]] : vector<[4]x8xi8> to vector<[32]xi8> +// CHECK-NEXT: %[[RHS_0:[0-9]+]] = vector.scalable.extract %[[T13]][0] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[RHS_1:[0-9]+]] = vector.scalable.extract %[[T13]][16] : vector<[16]xi8> from vector<[32]xi8> // Extract accumulator rows 0 and 1 and pack (into "registers") -// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T0:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T0]][1] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T21:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T19]], %[[T20]]) : (vector<[4]xi32>, vector<[4]xi32>) -> 
vector<[8]xi32> -// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.intr.vector.extract %[[T21]][0] : vector<[4]xi32> from vector<[8]xi32> -// CHECK-NEXT: %[[T23:[0-9]+]] = llvm.intr.vector.extract %[[T21]][4] : vector<[4]xi32> from vector<[8]xi32> - -// Same for accumulator rows 2 and 3. -// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.extractvalue %[[T0]][2] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.extractvalue %[[T0]][3] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T26:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T24]], %[[T25]]) : (vector<[4]xi32>, vector<[4]xi32>) -> vector<[8]xi32> -// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.intr.vector.extract %[[T26]][0] : vector<[4]xi32> from vector<[8]xi32> -// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.intr.vector.extract %[[T26]][4] : vector<[4]xi32> from vector<[8]xi32> +// Note the lack of bitcasts (i.e. no interleaving by pairs). This has the effect of transposing the tile. +// CHECK-NEXT: %[[T16:[0-9]+]] = vector.extract %[[ACC]][0] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T17:[0-9]+]] = vector.extract %[[ACC]][1] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T20:[0-9]+]] = vector.interleave %[[T16]], %[[T17]] : vector<[4]xi32> -> vector<[8]xi32> +// CHECK-NEXT: %[[ACC_0:[0-9]+]] = vector.scalable.extract %[[T20]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[ACC_1:[0-9]+]] = vector.scalable.extract %[[T20]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Same for accumulator rows 2 and 3 +// CHECK-NEXT: %[[T24:[0-9]+]] = vector.extract %[[ACC]][2] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T25:[0-9]+]] = vector.extract %[[ACC]][3] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T28:[0-9]+]] = vector.interleave %[[T24]], %[[T25]] : vector<[4]xi32> -> vector<[8]xi32> +// CHECK-NEXT: %[[ACC_2:[0-9]+]] = vector.scalable.extract %[[T28]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[ACC_3:[0-9]+]] = vector.scalable.extract 
%[[T28]][4] : vector<[4]xi32> from vector<[8]xi32> // Do the sub-tile matrix multiplications -// CHECK-NEXT: %[[T29:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T22]], %[[T17]], %[[T10]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> -// CHECK-NEXT: %[[T30:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T23]], %[[T18]], %[[T10]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> -// CHECK-NEXT: %[[T31:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T27]], %[[T17]], %[[T15]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> -// CHECK-NEXT: %[[T32:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T28]], %[[T18]], %[[T15]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> - -// Unpack (from "registers") and insert in the output result rows 0 and 1 -// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.insert %[[T29]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.insert %[[T30]], %[[T33]][4] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T35:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T34]]) : (vector<[8]xi32>) -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)> -// CHECK-NEXT: %[[T36:[0-9]+]] = llvm.extractvalue %[[T35]][0] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)> -// CHECK-NEXT: %[[T37:[0-9]+]] = llvm.extractvalue %[[T35]][1] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)> -// CHECK-NEXT: %[[T38:[0-9]+]] = llvm.insertvalue %[[T36]], %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.insertvalue %[[T37]], %[[T38]][1] : !llvm.array<4 x vector<[4]xi32>> +// The accumulator is transposed, the operands are in the opposite order, +// thus the result is obtained transposed too. 
+// +// CHECK-NEXT: %[[PACK_RES_00:[0-9]+]] = arm_sve.usmmla %[[ACC_0]], %[[RHS_0]], %[[LHS_0]] : vector<[16]xi8> to vector<[4]xi32> +// CHECK-NEXT: %[[PACK_RES_01:[0-9]+]] = arm_sve.usmmla %[[ACC_1]], %[[RHS_1]], %[[LHS_0]] : vector<[16]xi8> to vector<[4]xi32> +// CHECK-NEXT: %[[PACK_RES_10:[0-9]+]] = arm_sve.usmmla %[[ACC_2]], %[[RHS_0]], %[[LHS_1]] : vector<[16]xi8> to vector<[4]xi32> +// CHECK-NEXT: %[[PACK_RES_11:[0-9]+]] = arm_sve.usmmla %[[ACC_3]], %[[RHS_1]], %[[LHS_1]] : vector<[16]xi8> to vector<[4]xi32> + +// Unpack (from "registers") and insert in the output result rows 0 and 1 +// Again, no bitcast/interleave-by-pairs, i.e. the transposed result of sub-tile +// multiplications is unpacked into the correct output layout. +// +// CHECK-NEXT: %[[T36:[0-9]+]] = vector.scalable.insert %[[PACK_RES_00]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T37:[0-9]+]] = vector.scalable.insert %[[PACK_RES_01]], %[[T36]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[UNPACK_RES_0:[_a-z0-9]+]], %[[UNPACK_RES_1:[_a-z0-9]+]] = vector.deinterleave %[[T37]] : vector<[8]xi32> -> vector<[4]xi32> +// CHECK-NEXT: %[[TMP_OUT_0:[0-9]+]] = vector.insert %[[UNPACK_RES_0]], %[[P1]] [0] : vector<[4]xi32> into vector<4x[4]xi32> +// CHECK-NEXT: %[[TMP_OUT_1:[0-9]+]] = vector.insert %[[UNPACK_RES_1]], %[[TMP_OUT_0]] [1] : vector<[4]xi32> into vector<4x[4]xi32> // Same for result rows 2 and 3 -// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T31]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.intr.vector.insert %[[T32]], %[[T40]][4] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[8]xi32>) -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)> -// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)> -// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : 
!llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)> -// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.insertvalue %[[T43]], %[[T39]][2] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.insertvalue %[[T44]], %[[T45]][3] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T47:[0-9]+]] = builtin.unrealized_conversion_cast %[[T46]] : !llvm.array<4 x vector<[4]xi32>> to vector<4x[4]xi32> +// CHECK-NEXT: %[[T43:[0-9]+]] = vector.scalable.insert %[[PACK_RES_10]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T44:[0-9]+]] = vector.scalable.insert %[[PACK_RES_11]], %[[T43]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[UNPACK_RES_2:[_a-z0-9]+]], %[[UNPACK_RES_3:[_a-z0-9]+]] = vector.deinterleave %[[T44]] : vector<[8]xi32> -> vector<[4]xi32> +// CHECK-NEXT: %[[TMP_OUT_2:[0-9]+]] = vector.insert %[[UNPACK_RES_2]], %[[TMP_OUT_1]] [2] : vector<[4]xi32> into vector<4x[4]xi32> +// CHECK-NEXT: %[[OUT:[0-9]+]] = vector.insert %[[UNPACK_RES_3]], %[[TMP_OUT_2]] [3] : vector<[4]xi32> into vector<4x[4]xi32> + +// CHECK-NEXT: return %[[OUT]] : vector<4x[4]xi32> func.func @test_vector_contract_to_usmmla_rev( %lhs: vector<4x8xi8>, @@ -83,3 +94,15 @@ func.func @test_vector_contract_to_usmmla_rev( return %2 : vector<4x[4]xi32> } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func"> + + transform.apply_patterns to %func { + transform.apply_patterns.vector.arm_sve.lower_contraction + } : !transform.op<"func.func"> + + transform.yield + } +} diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir index cde57842295f7..9f63a0c1688f5 100644 --- a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir +++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir @@ -1,4 
+1,4 @@ -// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s +// RUN: mlir-opt %s -transform-interpreter --split-input-file | FileCheck %s #packed_maps = [ affine_map<(d0, d1, d2) -> (d0, d2)>, @@ -8,76 +8,82 @@ // CHECK-LABEL: @test_vector_contract_to_ummla +// CHECK-SAME: %[[LHS:arg0]]: vector<4x8xi8> +// CHECK-SAME: %[[RHS:arg1]]: vector<[4]x8xi8> +// CHECK-SAME: %[[ACC:arg2]]: vector<4x[4]xi32> + +// CHECK: [[P0:[0-9]+]] = ub.poison : vector<[8]xi32> +// CHECK-NEXT: [[P1:[0-9]+]] = ub.poison : vector<4x[4]xi32> +// CHECK-NEXT: [[P2:[0-9]+]] = ub.poison : vector<[16]xi8> + // Extract LHS rows 0 and 1, concatenate, turn into scalable vector -// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> -// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8> +// CHECK: %[[T3:[0-9]+]] = vector.extract %[[LHS]][0] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T4:[0-9]+]] = vector.extract %[[LHS]][1] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T5:[0-9]+]] = vector.shuffle %[[T3]], %[[T4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8> +// CHECK-NEXT: %[[T6:[0-9]+]] = vector.scalable.insert %[[T5]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8> -// Replicate across the entire length of the scalabale vector -// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> +// Replicate across the entire length of the scalable vector +// CHECK-NEXT: %[[LHS_0:[0-9]+]] = arm_sve.dupq_lane %[[T6]][0] : vector<[16]xi8> -// Same for LHS rows 2 and 4 -// CHECK-NEXT: 
%[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> -// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8> -// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> +// Same for LHS rows 2 and 3 +// CHECK-NEXT: %[[T8:[0-9]+]] = vector.extract %[[LHS]][2] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T9:[0-9]+]] = vector.extract %[[LHS]][3] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T10:[0-9]+]] = vector.shuffle %[[T8]], %[[T9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8> +// CHECK-NEXT: %[[T11:[0-9]+]] = vector.scalable.insert %[[T10]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8> +// CHECK-NEXT: %[[LHS_1:[0-9]+]] = arm_sve.dupq_lane %[[T11]][0] : vector<[16]xi8> // Extract sub-tiles from the RHS -// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8> -// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8> -// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[T13:[0-9]+]] = vector.shape_cast %[[RHS]] : vector<[4]x8xi8> to vector<[32]xi8> +// CHECK-NEXT: %[[RHS_0:[0-9]+]] = vector.scalable.extract %[[T13]][0] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[RHS_1:[0-9]+]] = vector.scalable.extract %[[T13]][16] : vector<[16]xi8> from vector<[32]xi8> // Extract accumulator rows 0 and 1 and pack (into "registers") -// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: 
%[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> -// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32> -// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32> -// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32> - -// Same for accumulator rows 2 and 3. -// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> -// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32> -// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32> -// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T16:[0-9]+]] = vector.extract %[[ACC]][0] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T17:[0-9]+]] = vector.extract %[[ACC]][1] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T18:[0-9]+]] = vector.bitcast %[[T16]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T19:[0-9]+]] = vector.bitcast %[[T17]] : vector<[4]xi32> to vector<[2]xi64> +// 
CHECK-NEXT: %[[T20:[0-9]+]] = vector.interleave %[[T18]], %[[T19]] : vector<[2]xi64> -> vector<[4]xi64> +// CHECK-NEXT: %[[T21:[0-9]+]] = vector.bitcast %[[T20]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[ACC_0:[0-9]+]] = vector.scalable.extract %[[T21]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[ACC_1:[0-9]+]] = vector.scalable.extract %[[T21]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Same for accumulator rows 2 and 3 +// CHECK-NEXT: %[[T24:[0-9]+]] = vector.extract %[[ACC]][2] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T25:[0-9]+]] = vector.extract %[[ACC]][3] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T26:[0-9]+]] = vector.bitcast %[[T24]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T27:[0-9]+]] = vector.bitcast %[[T25]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T28:[0-9]+]] = vector.interleave %[[T26]], %[[T27]] : vector<[2]xi64> -> vector<[4]xi64> +// CHECK-NEXT: %[[T29:[0-9]+]] = vector.bitcast %[[T28]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[ACC_2:[0-9]+]] = vector.scalable.extract %[[T29]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[ACC_3:[0-9]+]] = vector.scalable.extract %[[T29]][4] : vector<[4]xi32> from vector<[8]xi32> // Do the sub-tile matrix multiplications -// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.ummla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> -// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.ummla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> -// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.ummla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> -// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.ummla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> - -// Unpack (from "registers") and insert in 
the output result rows 0 and 1 -// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64> -// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[PACK_RES_00:[0-9]+]] = arm_sve.ummla %[[ACC_0]], %[[LHS_0]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32> +// CHECK-NEXT: %[[PACK_RES_01:[0-9]+]] = arm_sve.ummla %[[ACC_1]], %[[LHS_0]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32> +// CHECK-NEXT: %[[PACK_RES_10:[0-9]+]] = arm_sve.ummla %[[ACC_2]], %[[LHS_1]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32> +// CHECK-NEXT: %[[PACK_RES_11:[0-9]+]] = arm_sve.ummla %[[ACC_3]], %[[LHS_1]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32> + +// Unpack (from "registers") and insert in the output result rows 0 and 1 +// CHECK-NEXT: %[[T36:[0-9]+]] = vector.scalable.insert %[[PACK_RES_00]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T37:[0-9]+]] = vector.scalable.insert %[[PACK_RES_01]], %[[T36]][4] : vector<[4]xi32> into vector<[8]xi32> +// 
CHECK-NEXT: %[[T38:[0-9]+]] = vector.bitcast %[[T37]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %res1, %res2 = vector.deinterleave %[[T38]] : vector<[4]xi64> -> vector<[2]xi64> +// CHECK-NEXT: %[[UNPACK_RES_0:[0-9]+]] = vector.bitcast %res1 : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[UNPACK_RES_1:[0-9]+]] = vector.bitcast %res2 : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[TMP_OUT_0:[0-9]+]] = vector.insert %[[UNPACK_RES_0]], %[[P1]] [0] : vector<[4]xi32> into vector<4x[4]xi32> +// CHECK-NEXT: %[[TMP_OUT_1:[0-9]+]] = vector.insert %[[UNPACK_RES_1]], %[[TMP_OUT_0]] [1] : vector<[4]xi32> into vector<4x[4]xi32> // Same for result rows 2 and 3 -// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64> -// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T43:[0-9]+]] = vector.scalable.insert %[[PACK_RES_10]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T44:[0-9]+]] = vector.scalable.insert %[[PACK_RES_11]], 
%[[T43]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T45:[0-9]+]] = vector.bitcast %[[T44]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %res1_0, %res2_1 = vector.deinterleave %[[T45]] : vector<[4]xi64> -> vector<[2]xi64> +// CHECK-NEXT: %[[UNPACK_RES_2:[0-9]+]] = vector.bitcast %res1_0 : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[UNPACK_RES_3:[0-9]+]] = vector.bitcast %res2_1 : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[TMP_OUT_2:[0-9]+]] = vector.insert %[[UNPACK_RES_2]], %[[TMP_OUT_1]] [2] : vector<[4]xi32> into vector<4x[4]xi32> +// CHECK-NEXT: %[[OUT:[0-9]+]] = vector.insert %[[UNPACK_RES_3]], %[[TMP_OUT_2]] [3] : vector<[4]xi32> into vector<4x[4]xi32> + +// CHECK-NEXT: return %[[OUT]] : vector<4x[4]xi32> func.func @test_vector_contract_to_ummla(%lhs: vector<4x8xi8>, %rhs: vector<[4]x8xi8>, @@ -92,3 +98,15 @@ func.func @test_vector_contract_to_ummla(%lhs: vector<4x8xi8>, return %2 : vector<4x[4]xi32> } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func"> + + transform.apply_patterns to %func { + transform.apply_patterns.vector.arm_sve.lower_contraction + } : !transform.op<"func.func"> + + transform.yield + } +} diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir index d0eef9fb9769c..f9d8b604150e6 100644 --- a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir +++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s +// RUN: mlir-opt %s -transform-interpreter | FileCheck %s #packed_maps = [ affine_map<(d0, d1, d2) -> (d0, d2)>, @@ -8,76 +8,82 @@ // CHECK-LABEL: @test_vector_contract_to_usmmla +// 
CHECK-SAME: %[[LHS:arg0]]: vector<4x8xi8> +// CHECK-SAME: %[[RHS:arg1]]: vector<[4]x8xi8> +// CHECK-SAME: %[[ACC:arg2]]: vector<4x[4]xi32> + +// CHECK: [[P0:[0-9]+]] = ub.poison : vector<[8]xi32> +// CHECK-NEXT: [[P1:[0-9]+]] = ub.poison : vector<4x[4]xi32> +// CHECK-NEXT: [[P2:[0-9]+]] = ub.poison : vector<[16]xi8> + // Extract LHS rows 0 and 1, concatenate, turn into scalable vector -// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> -// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8> +// CHECK: %[[T3:[0-9]+]] = vector.extract %[[LHS]][0] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T4:[0-9]+]] = vector.extract %[[LHS]][1] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T5:[0-9]+]] = vector.shuffle %[[T3]], %[[T4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8> +// CHECK-NEXT: %[[T6:[0-9]+]] = vector.scalable.insert %[[T5]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8> -// Replicate across the entire length of the scalabale vector -// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> +// Replicate across the entire length of the scalable vector +// CHECK-NEXT: %[[LHS_0:[0-9]+]] = arm_sve.dupq_lane %[[T6]][0] : vector<[16]xi8> -// Same for LHS rows 2 and 4 -// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> -// CHECK-NEXT: 
%[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8> -// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> +// Same for LHS rows 2 and 3 +// CHECK-NEXT: %[[T8:[0-9]+]] = vector.extract %[[LHS]][2] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T9:[0-9]+]] = vector.extract %[[LHS]][3] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T10:[0-9]+]] = vector.shuffle %[[T8]], %[[T9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8> +// CHECK-NEXT: %[[T11:[0-9]+]] = vector.scalable.insert %[[T10]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8> +// CHECK-NEXT: %[[LHS_1:[0-9]+]] = arm_sve.dupq_lane %[[T11]][0] : vector<[16]xi8> // Extract sub-tiles from the RHS -// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8> -// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8> -// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[T13:[0-9]+]] = vector.shape_cast %[[RHS]] : vector<[4]x8xi8> to vector<[32]xi8> +// CHECK-NEXT: %[[RHS_0:[0-9]+]] = vector.scalable.extract %[[T13]][0] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[RHS_1:[0-9]+]] = vector.scalable.extract %[[T13]][16] : vector<[16]xi8> from vector<[32]xi8> // Extract accumulator rows 0 and 1 and pack (into "registers") -// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], 
%[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> -// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32> -// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32> -// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32> - -// Same for accumulator rows 2 and 3. -// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> -// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32> -// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32> -// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T16:[0-9]+]] = vector.extract %[[ACC]][0] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T17:[0-9]+]] = vector.extract %[[ACC]][1] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T18:[0-9]+]] = vector.bitcast %[[T16]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T19:[0-9]+]] = vector.bitcast %[[T17]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T20:[0-9]+]] = vector.interleave %[[T18]], %[[T19]] : vector<[2]xi64> -> vector<[4]xi64> +// CHECK-NEXT: %[[T21:[0-9]+]] = vector.bitcast %[[T20]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[ACC_0:[0-9]+]] = vector.scalable.extract %[[T21]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[ACC_1:[0-9]+]] = 
vector.scalable.extract %[[T21]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Same for accumulator rows 2 and 3 +// CHECK-NEXT: %[[T24:[0-9]+]] = vector.extract %[[ACC]][2] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T25:[0-9]+]] = vector.extract %[[ACC]][3] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T26:[0-9]+]] = vector.bitcast %[[T24]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T27:[0-9]+]] = vector.bitcast %[[T25]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T28:[0-9]+]] = vector.interleave %[[T26]], %[[T27]] : vector<[2]xi64> -> vector<[4]xi64> +// CHECK-NEXT: %[[T29:[0-9]+]] = vector.bitcast %[[T28]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[ACC_2:[0-9]+]] = vector.scalable.extract %[[T29]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[ACC_3:[0-9]+]] = vector.scalable.extract %[[T29]][4] : vector<[4]xi32> from vector<[8]xi32> // Do the sub-tile matrix multiplications -// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> -// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> -// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> -// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> - -// Unpack (from "registers") and insert in the output result rows 0 and 1 -// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : 
vector<[8]xi32> to vector<[4]xi64> -// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[PACK_RES_00:[0-9]+]] = arm_sve.usmmla %[[ACC_0]], %[[LHS_0]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32> +// CHECK-NEXT: %[[PACK_RES_01:[0-9]+]] = arm_sve.usmmla %[[ACC_1]], %[[LHS_0]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32> +// CHECK-NEXT: %[[PACK_RES_10:[0-9]+]] = arm_sve.usmmla %[[ACC_2]], %[[LHS_1]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32> +// CHECK-NEXT: %[[PACK_RES_11:[0-9]+]] = arm_sve.usmmla %[[ACC_3]], %[[LHS_1]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32> + +// Unpack (from "registers") and insert in the output result rows 0 and 1 +// CHECK-NEXT: %[[T36:[0-9]+]] = vector.scalable.insert %[[PACK_RES_00]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T37:[0-9]+]] = vector.scalable.insert %[[PACK_RES_01]], %[[T36]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T38:[0-9]+]] = vector.bitcast %[[T37]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %res1, %res2 = vector.deinterleave %[[T38]] : vector<[4]xi64> -> vector<[2]xi64> +// CHECK-NEXT: %[[UNPACK_RES_0:[0-9]+]] = vector.bitcast %res1 : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[UNPACK_RES_1:[0-9]+]] 
= vector.bitcast %res2 : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[TMP_OUT_0:[0-9]+]] = vector.insert %[[UNPACK_RES_0]], %[[P1]] [0] : vector<[4]xi32> into vector<4x[4]xi32> +// CHECK-NEXT: %[[TMP_OUT_1:[0-9]+]] = vector.insert %[[UNPACK_RES_1]], %[[TMP_OUT_0]] [1] : vector<[4]xi32> into vector<4x[4]xi32> // Same for result rows 2 and 3 -// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64> -// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T43:[0-9]+]] = vector.scalable.insert %[[PACK_RES_10]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T44:[0-9]+]] = vector.scalable.insert %[[PACK_RES_11]], %[[T43]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T45:[0-9]+]] = vector.bitcast %[[T44]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %res1_0, %res2_1 = vector.deinterleave %[[T45]] : vector<[4]xi64> -> vector<[2]xi64> +// CHECK-NEXT: %[[UNPACK_RES_2:[0-9]+]] = vector.bitcast %res1_0 : vector<[2]xi64> to 
vector<[4]xi32> +// CHECK-NEXT: %[[UNPACK_RES_3:[0-9]+]] = vector.bitcast %res2_1 : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[TMP_OUT_2:[0-9]+]] = vector.insert %[[UNPACK_RES_2]], %[[TMP_OUT_1]] [2] : vector<[4]xi32> into vector<4x[4]xi32> +// CHECK-NEXT: %[[OUT:[0-9]+]] = vector.insert %[[UNPACK_RES_3]], %[[TMP_OUT_2]] [3] : vector<[4]xi32> into vector<4x[4]xi32> + +// CHECK-NEXT: return %[[OUT]] : vector<4x[4]xi32> func.func @test_vector_contract_to_usmmla( %lhs: vector<4x8xi8>, @@ -93,3 +99,15 @@ func.func @test_vector_contract_to_usmmla( return %2 : vector<4x[4]xi32> } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func"> + + transform.apply_patterns to %func { + transform.apply_patterns.vector.arm_sve.lower_contraction + } : !transform.op<"func.func"> + + transform.yield + } +} \ No newline at end of file >From 4ab77651e49bbd5d5e699ae4e282824a407a0a66 Mon Sep 17 00:00:00 2001 From: Momchil Velikov <momchil.veli...@arm.com> Date: Tue, 27 May 2025 15:56:43 +0000 Subject: [PATCH 2/2] [fixup] Rename the operation --- .../Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td | 2 +- mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir | 4 ++-- mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir | 4 ++-- mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir | 4 ++-- mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir | 6 +++--- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td index 81b59340f3b0d..53784982be6dc 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td +++ 
b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td @@ -13,7 +13,7 @@ include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" def ApplyArmSVELowerContractionPatternsOp - : Op<Transform_Dialect, "apply_patterns.vector.arm_sve.lower_contraction", + : Op<Transform_Dialect, "apply_patterns.arm_sve.vector_contract_to_i8mm", [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> { let description = [{ Indicates that vector contraction-like operations should be lowered to diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir index 3991038761e8d..404ef87380c9c 100644 --- a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir +++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir @@ -194,9 +194,9 @@ func.func @test_vector_contract_to_smmla_implicit_sext(%lhs: vector<4x8xi8>, module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func"> - + transform.apply_patterns to %func { - transform.apply_patterns.vector.arm_sve.lower_contraction + transform.apply_patterns.arm_sve.vector_contract_to_i8mm } : !transform.op<"func.func"> transform.yield diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir index ffcb437f189f9..518e995b698ad 100644 --- a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir +++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir @@ -98,9 +98,9 @@ func.func @test_vector_contract_to_usmmla_rev( module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { %func = transform.structured.match ops{["func.func"]} in %module : 
(!transform.any_op) -> !transform.op<"func.func"> - + transform.apply_patterns to %func { - transform.apply_patterns.vector.arm_sve.lower_contraction + transform.apply_patterns.arm_sve.vector_contract_to_i8mm } : !transform.op<"func.func"> transform.yield diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir index 9f63a0c1688f5..703182490f660 100644 --- a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir +++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir @@ -102,9 +102,9 @@ func.func @test_vector_contract_to_ummla(%lhs: vector<4x8xi8>, module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func"> - + transform.apply_patterns to %func { - transform.apply_patterns.vector.arm_sve.lower_contraction + transform.apply_patterns.arm_sve.vector_contract_to_i8mm } : !transform.op<"func.func"> transform.yield diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir index f9d8b604150e6..28156f6ef2700 100644 --- a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir +++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir @@ -103,11 +103,11 @@ func.func @test_vector_contract_to_usmmla( module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func"> - + transform.apply_patterns to %func { - transform.apply_patterns.vector.arm_sve.lower_contraction + transform.apply_patterns.arm_sve.vector_contract_to_i8mm } : !transform.op<"func.func"> transform.yield } -} \ No newline at end of file +} _______________________________________________ 
llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits