llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir-vector Author: Momchil Velikov (momchil-velikov) <details> <summary>Changes</summary> --- Patch is 73.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140572.diff 12 Files Affected: - (modified) mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt (+1) - (added) mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h (+31) - (added) mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td (+26) - (added) mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt (+6) - (modified) mlir/include/mlir/InitAllExtensions.h (+2) - (modified) mlir/lib/Dialect/ArmSVE/CMakeLists.txt (+1) - (added) mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp (+54) - (added) mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt (+19) - (modified) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir (+143-120) - (modified) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir (+73-50) - (modified) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir (+78-60) - (modified) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir (+78-60) ``````````diff diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt index 9f57627c321fb..cb1e9d01821a2 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h new file mode 100644 index 0000000000000..7f22cd1fe6435 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h @@ -0,0 +1,31 @@ +//===- ArmSVEVectorTransformOps.h - Vector transform ops --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H +#define MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H + +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" + +//===----------------------------------------------------------------------===// +// ArmSVE Vector Transform Operations +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h.inc" + +namespace mlir { +class DialectRegistry; + +namespace arm_sve { +void registerTransformDialectExtension(DialectRegistry ®istry); + +} // namespace arm_sve +} // namespace mlir + +#endif // MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td new file mode 100644 index 0000000000000..81b59340f3b0d --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td @@ -0,0 +1,26 @@ +//===- ArmSVEVectorTransformOps.td - Arm SVE transform ops--*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef ARMSVE_VECTOR_TRANSFORM_OPS +#define ARMSVE_VECTOR_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformAttrs.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" + +def ApplyArmSVELowerContractionPatternsOp + : Op<Transform_Dialect, "apply_patterns.vector.arm_sve.lower_contraction", + [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> { + let description = [{ + Indicates that vector contraction-like operations should be lowered to + finer-grained vector primitives using the ArmSVE dialect. + }]; + + let assemblyFormat = "attr-dict"; +} + +#endif // ARMSVE_VECTOR_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..ce8d8fea7f188 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS ArmSVEVectorTransformOps.td) +mlir_tablegen(ArmSVEVectorTransformOps.h.inc -gen-op-decls) +mlir_tablegen(ArmSVEVectorTransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRArmSVEVectorTransformOpsIncGen) + +add_mlir_doc(ArmSVEVectorTransformOps ArmSVEVectorTransformOps Dialects/ -gen-op-doc) diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index 37e4904cb48ed..767c7099accbb 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -34,6 +34,7 @@ #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h" #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" @@ -106,6 +107,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) { transform::registerLoopExtension(registry); transform::registerPDLExtension(registry); vector::registerTransformDialectExtension(registry); + arm_sve::registerTransformDialectExtension(registry); // Translation extensions need to be registered by calling // `registerAllToLLVMIRTranslations` (see All.h). diff --git a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt index 9f57627c321fb..cb1e9d01821a2 100644 --- a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp new file mode 100644 index 0000000000000..b2ca4fc1eaa8c --- /dev/null +++ b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp @@ -0,0 +1,54 @@ +//===- ArmSVEVectorTransformOps.cpp - Implementation transform ops -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h" + +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Apply...PatternsOp +//===----------------------------------------------------------------------===// + +void transform::ApplyArmSVELowerContractionPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns); +} + +//===----------------------------------------------------------------------===// +// Transform op registration +//===----------------------------------------------------------------------===// + +namespace { +class ArmSVEVectorTransformDialectExtension + : public transform::TransformDialectExtension< + ArmSVEVectorTransformDialectExtension> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + ArmSVEVectorTransformDialectExtension) + + ArmSVEVectorTransformDialectExtension() { + declareGeneratedDialect<arm_sve::ArmSVEDialect>(); + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp.inc" + >(); + } +}; +} // namespace + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp.inc" + +void mlir::arm_sve::registerTransformDialectExtension( + DialectRegistry ®istry) { + registry.addExtensions<ArmSVEVectorTransformDialectExtension>(); +} diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..8771826e08913 --- /dev/null +++ b/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_dialect_library(MLIRArmSVEVectorTransformOps + ArmSVEVectorTransformOps.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE/TransformOps + + DEPENDS + MLIRArmSVEVectorTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRVectorDialect + MLIRTransformDialect + MLIRArmSVEDialect + MLIRArmSVETransforms + ) + \ No newline at end of file diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir index af0cb37e2d249..3991038761e8d 100644 --- a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir +++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' | FileCheck %s +// RUN: mlir-opt %s --transform-interpreter | FileCheck %s #attrs = { indexing_maps = [ @@ -12,77 +12,82 @@ // CHECK-LABEL: @test_vector_contract_to_smmla +// CHECK-SAME: %[[LHS:arg0]]: vector<4x8xi8> +// CHECK-SAME: %[[RHS:arg1]]: vector<[4]x8xi8> +// CHECK-SAME: %[[ACC:arg2]]: vector<4x[4]xi32> + +// CHECK: [[P0:[0-9]+]] = ub.poison : vector<[8]xi32> +// CHECK-NEXT: [[P1:[0-9]+]] = ub.poison : vector<4x[4]xi32> +// CHECK-NEXT: [[P2:[0-9]+]] = ub.poison : vector<[16]xi8> + // Extract LHS rows 0 and 1, concatenate, turn into scalable vector -// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> -// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8> +// CHECK: %[[T3:[0-9]+]] = vector.extract %[[LHS]][0] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T4:[0-9]+]] = vector.extract %[[LHS]][1] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T5:[0-9]+]] = vector.shuffle %[[T3]], %[[T4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8> +// CHECK-NEXT: %[[T6:[0-9]+]] = vector.scalable.insert %[[T5]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8> -// Replicate across the entire length of the scalabale vector -// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> +// Replicate across the entire length of the scalable vector +// CHECK-NEXT: %[[LHS_0:[0-9]+]] = arm_sve.dupq_lane %[[T6]][0] : vector<[16]xi8> -// Same for LHS rows 2 and 4 -// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>> -// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> -// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8> -// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> +// Same for LHS rows 2 and 3 +// CHECK-NEXT: %[[T8:[0-9]+]] = vector.extract %[[LHS]][2] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T9:[0-9]+]] = vector.extract %[[LHS]][3] : vector<8xi8> from vector<4x8xi8> +// CHECK-NEXT: %[[T10:[0-9]+]] = vector.shuffle %[[T8]], %[[T9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8> +// CHECK-NEXT: %[[T11:[0-9]+]] = vector.scalable.insert %[[T10]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8> +// CHECK-NEXT: %[[LHS_1:[0-9]+]] = arm_sve.dupq_lane %[[T11]][0] : vector<[16]xi8> // Extract sub-tiles from the RHS -// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8> -// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8> -// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[T13:[0-9]+]] = vector.shape_cast %[[RHS]] : vector<[4]x8xi8> to vector<[32]xi8> +// CHECK-NEXT: %[[RHS_0:[0-9]+]] = vector.scalable.extract %[[T13]][0] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[RHS_1:[0-9]+]] = vector.scalable.extract %[[T13]][16] : vector<[16]xi8> from vector<[32]xi8> // Extract accumulator rows 0 and 1 and pack (into "registers") -// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> -// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32> -// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32> -// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32> - -// Same for accumulator rows 2 and 3. -// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64> -// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> -// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32> -// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32> -// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T16:[0-9]+]] = vector.extract %[[ACC]][0] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T17:[0-9]+]] = vector.extract %[[ACC]][1] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T18:[0-9]+]] = vector.bitcast %[[T16]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T19:[0-9]+]] = vector.bitcast %[[T17]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T20:[0-9]+]] = vector.interleave %[[T18]], %[[T19]] : vector<[2]xi64> -> vector<[4]xi64> +// CHECK-NEXT: %[[T21:[0-9]+]] = vector.bitcast %[[T20]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[ACC_0:[0-9]+]] = vector.scalable.extract %[[T21]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[ACC_1:[0-9]+]] = vector.scalable.extract %[[T21]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Same for accumulator rows 2 and 3 +// CHECK-NEXT: %[[T24:[0-9]+]] = vector.extract %[[ACC]][2] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T25:[0-9]+]] = vector.extract %[[ACC]][3] : vector<[4]xi32> from vector<4x[4]xi32> +// CHECK-NEXT: %[[T26:[0-9]+]] = vector.bitcast %[[T24]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T27:[0-9]+]] = vector.bitcast %[[T25]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T28:[0-9]+]] = vector.interleave %[[T26]], %[[T27]] : vector<[2]xi64> -> vector<[4]xi64> +// CHECK-NEXT: %[[T29:[0-9]+]] = vector.bitcast %[[T28]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[ACC_2:[0-9]+]] = vector.scalable.extract %[[T29]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[ACC_3:[0-9]+]] = vector.scalable.extract %[[T29]][4] : vector<[4]xi32> from vector<[8]xi32> // Do the sub-tile matrix multiplications -// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.smmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> -// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.smmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> -// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.smmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> -// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.smmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> - -// Unpack (from "registers") and insert in the output result rows 0 and 1 -// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64> -// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[PACK_RES_00:[0-9]+]] = arm_sve.smmla %[[ACC_0]], %[[LHS_0]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32> +// CHECK-NEXT: %[[PACK_RES_01:[0-9]+]] = arm_sve.smmla %[[ACC_1]], %[[LHS_0]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32> +// CHECK-NEXT: %[[PACK_RES_10:[0-9]+]] = arm_sve.smmla %[[ACC_2]], %[[LHS_1]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32> +// CHECK-NEXT: %[[PACK_RES_11:[0-9]+]] = arm_sve.smmla %[[ACC_3]], %[[LHS_1]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32> + +// Unpack (from "registers") and insert in the output result rows 0 and 1 +// CHECK-NEXT: %[[T36:[0-9]+]] = vector.scalable.insert %[[PACK_RES_00]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T37:[0-9]+]] = vector.scalable.insert %[[PACK_RES_01]], %[[T36]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T38:[0-9]+]] = vector.bitcast %[[T37]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %res1, %res2 = vector.deinterleave %[[T38]] : vector<[4]xi64> -> vector<[2]xi64> +// CHECK-NEXT: %[[UNPACK_RES_0:[0-9]+]] = vector.bitcast %res1 : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[UNPACK_RES_1:[0-9]+]] = vector.bitcast %res2 : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[TMP_OUT_0:[0-9]+]] = vector.insert %[[UNPACK_RES_0]], %[[P1]] [0] : vector<[4]xi32> into vector<4x[4]xi32> +// CHECK-NEXT: %[[TMP_OUT_1:[0-9]+]] = vector.insert %[[UNPACK_RES_1]], %[[TMP_OUT_0]] [1] : vector<[4]xi32> into vector<4x[4]xi32> // Same for result rows 2 and 3 -// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32> -// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64> -// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> -// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32> -// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>> -// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>> - +// CHECK-NEXT: %[[T43... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/140572 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits