llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) <details> <summary>Changes</summary> - [mlir][vector][nfc] Update comments in vector-transpose.mlir - [mlir][ArmSME] Remove `ConvertIllegalShapeCastOpsToTransposes` --- Full diff: https://github.com/llvm/llvm-project/pull/139706.diff 4 Files Affected: - (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (-54) - (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+1-12) - (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (-45) - (modified) mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir (+37-3) ``````````diff diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp index 95965872f4098..51750f0bb9694 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp @@ -724,59 +724,6 @@ struct LiftIllegalVectorTransposeToMemory } }; -/// A rewrite to turn unit dim transpose-like vector.shape_casts into -/// vector.transposes. The shape_cast has to be from an illegal vector type to a -/// legal one (as defined by isLegalVectorType). -/// -/// The reasoning for this is if we've got to this pass and we still have -/// shape_casts of illegal types, then they likely will not cancel out. Turning -/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to -/// eliminate them. 
-/// -/// Example: -/// -/// BEFORE: -/// ```mlir -/// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32> -/// ``` -/// -/// AFTER: -/// ```mlir -/// %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> -/// ``` -struct ConvertIllegalShapeCastOpsToTransposes - : public OpRewritePattern<vector::ShapeCastOp> { - using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, - PatternRewriter &rewriter) const override { - auto sourceType = shapeCastOp.getSourceVectorType(); - auto resultType = shapeCastOp.getResultVectorType(); - if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType)) - return rewriter.notifyMatchFailure(shapeCastOp, - kMatchFailureNotIllegalToLegal); - - // Note: If we know that `sourceType` is an illegal vector type (and 2D) - // then dim 0 is scalable and dim 1 is fixed. - if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1) - return rewriter.notifyMatchFailure( - shapeCastOp, "expected source to be a 2D scalable vector with a " - "trailing unit dim"); - - auto loc = shapeCastOp.getLoc(); - auto transpose = rewriter.create<vector::TransposeOp>( - loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0}); - - if (resultType.getRank() == 1) - rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType, - transpose); - else - rewriter.replaceOp(shapeCastOp, transpose); - - return success(); - } -}; - /// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use /// the ZA state. This workaround rewrite to support these transposes when ZA is /// available. 
@@ -943,7 +890,6 @@ struct VectorLegalizationPass RewritePatternSet rewritePatterns(context); rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks, LiftIllegalVectorTransposeToMemory, - ConvertIllegalShapeCastOpsToTransposes, LowerIllegalTransposeStoreViaZA>(context); if (failed( applyPatternsGreedily(getOperation(), std::move(rewritePatterns)))) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index f6c3c6a61afb6..83a287d29d773 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5617,18 +5617,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { // shape_cast(transpose(x)) -> shape_cast(x) if (auto transpose = getSource().getDefiningOp<TransposeOp>()) { - // This folder does - // shape_cast(transpose) -> shape_cast - // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does - // shape_cast -> shape_cast(transpose) - // i.e. the complete opposite. When paired, these 2 patterns can cause - // infinite cycles in pattern rewriting. - // ConvertIllegalShapeCastOpsToTransposes only matches on scalable - // vectors, so by disabling this folder for scalable vectors the - // cycle is avoided. - // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is - // still needed. If it's not, then we can fold here. 
- if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) { + if (isOrderPreserving(transpose)) { setOperand(transpose.getVector()); return getResult(); } diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir index d56df9814f173..6e6615c243d2a 100644 --- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir +++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir @@ -491,51 +491,6 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v // ----- -// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d( -// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>) -func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> { - // CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> - %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32> - return %0 : vector<1x[4]xf32> -} - -// ----- - -// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d( -// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>) -func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> { - // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> - // CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32> - %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32> - return %0 : vector<[4]xf32> -} - -// ----- - -// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory -func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> { - // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32> - // CHECK-NOT: vector.shape_cast - %pad = arith.constant 0.0 : f32 - %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32> - %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32> - return 
%cast : vector<1x[4]xf32> -} - -// ----- - -// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory -func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> { - // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32> - // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32> - %pad = arith.constant 0.0 : f32 - %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32> - %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32> - return %cast : vector<[4]xf32> -} - -// ----- - // CHECK-LABEL: @multi_tile_splat func.func @multi_tile_splat() -> vector<[8]x[8]xi32> { diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir index e47578bc80719..625b4a9c53e42 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir @@ -161,6 +161,25 @@ func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8 // ----- +// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows: +// 1 -> 0 +// 2 -> 4 +// Because 0 < 4, this permutation is order preserving and effectively a shape_cast. 
+// (same as the example above, but one of the dims is scalable) +// CHECK-LABEL: @transpose_shape_cast_scalable +// CHECK-SAME: %[[ARG:.*]]: vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> { +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] : +// CHECK-SAME: vector<1x[4]x4x1x1xi8> to vector<[4]x4xi8> +// CHECK: return %[[SHAPE_CAST]] : vector<[4]x4xi8> +func.func @transpose_shape_cast_scalable(%arg : vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> { + %0 = vector.transpose %arg, [1, 0, 3, 4, 2] + : vector<1x[4]x4x1x1xi8> to vector<[4]x1x1x1x4xi8> + %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4xi8> to vector<[4]x4xi8> + return %1 : vector<[4]x4xi8> +} + +// ----- + // In this test, the mapping of non-unit dimensions (1 and 2) is as follows: // 1 -> 2 // 2 -> 1 @@ -225,11 +244,26 @@ func.func @transpose_of_shape_cast(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi // ----- -// Scalable dimensions should be treated as non-unit dimensions. -// CHECK-LABEL: @transpose_of_shape_cast_scalable +// CHECK-LABEL: @shape_cast_transpose_scalable +// CHECK-SAME: %[[ARG:.*]]: vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> { +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] : +// CHECK-SAME: vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8> +// CHECK: return %[[SHAPE_CAST]] : vector<[6]x1x1xi8> +func.func @shape_cast_transpose_scalable(%arg : vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> { + %0 = vector.shape_cast %arg : vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8> + %1 = vector.transpose %0, [0, 2, 1] + : vector<[6]x1x1xi8> to vector<[6]x1x1xi8> + return %1 : vector<[6]x1x1xi8> +} + +// ----- + +// Scalable 1 dimensions (i.e. [1]) should be treated as non-unit dimensions +// (hence no folding). 
+// CHECK-LABEL: @negative_shape_cast_transpose_scalable_unit // CHECK: vector.shape_cast // CHECK: vector.transpose -func.func @transpose_of_shape_cast_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> { +func.func @negative_shape_cast_transpose_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> { %0 = vector.shape_cast %arg : vector<[1]x4x1xi8> to vector<[1]x4xi8> %1 = vector.transpose %0, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8> return %1 : vector<4x[1]xi8> `````````` </details> https://github.com/llvm/llvm-project/pull/139706 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits