https://github.com/llvmbot created https://github.com/llvm/llvm-project/pull/85942
Backport 0597644a6466ae9148b0b41cb8f95d5022e045c2 47bc565ca7990a2de20af4030baf08ac62739aca Requested by: @lhunloh >From d5933a73516f3bdfc37216d52278e0ca3d42859d Mon Sep 17 00:00:00 2001 From: Congcong Cai <congcongcai0...@163.com> Date: Tue, 5 Mar 2024 03:58:12 +0800 Subject: [PATCH 1/2] [mlir][transform] replace original op to loop ops (#83537) (cherry picked from commit 0597644a6466ae9148b0b41cb8f95d5022e045c2) --- mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 1 + .../TilingInterface/lower-to-loops-using-interface.mlir | 1 + 2 files changed, 2 insertions(+) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 140bdd1f2db361..be875297fc93ca 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2092,6 +2092,7 @@ DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne( scf::lowerToLoopsUsingSCFForOp(rewriter, target); if (failed(loops)) return emitDefaultDefiniteFailure(target); + rewriter.eraseOp(target); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir index 7969de0d456bb6..1b2c553b25ded0 100644 --- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir @@ -33,6 +33,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]] // CHECK: %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]] // CHECK: memref.store %[[ADDF]], %[[ARG2]][%[[IV0]], %[[IV1]]] +// CHECK-NOT: linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>) // ----- >From 6db21232a2f1d47c61fb3bf5985487ee695491f3 Mon Sep 17 00:00:00 2001 From: lhunloh 
<8047408+lhun...@users.noreply.github.com> Date: Wed, 6 Mar 2024 22:07:30 +0000 Subject: [PATCH 2/2] [MLIR] [Transforms] Let `transform.structured.convert_to_loops` return handles to loops (#83984) This lets `transform.structured.convert_to_loops` return handles to the generated loops, making this transformation more useful when nesting it with other transformations. This is modelled after SCF's `transform.loop.forall_to_for`, which returns handles to loops. The op was introduced in commit aa2a96a24ae3a8cc04635ab6ede474c5f2665053 with a note that such ops might move out of the `Linalg` dialect, but no reason was given for not returning handles. As far as I can see, this transform always returns loops. (cherry picked from commit 47bc565ca7990a2de20af4030baf08ac62739aca) --- .../Linalg/TransformOps/LinalgTransformOps.td | 22 +++--- .../TransformOps/LinalgTransformOps.cpp | 35 ++++++--- .../lower-to-loops-using-interface.mlir | 75 +++++++++++++++++-- 3 files changed, 101 insertions(+), 31 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index b139f1ef58b3a9..da7183dae75ffc 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1274,33 +1274,29 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize", }]; } +//===----------------------------------------------------------------------===// +// ConvertToLoopsOp +//===----------------------------------------------------------------------===// + def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops", [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, - TransformOpInterface, TransformEachOpTrait, + DeclareOpInterfaceMethods<TransformOpInterface>, ReportTrackingListenerFailuresOpTrait]> { let description = [{ For operations that implement the `TilingInterface`, and implement the 
`generateScalarImplementation` method, lowers the operation to - loops. This operation does not return any handles. + loops. The return handle points to all generated loops. + Fails if the payload ops cannot be lowered to loops. }]; let arguments = (ins TransformHandleTypeInterface:$target); - let results = (outs); + let results = (outs TransformHandleTypeInterface:$result); let assemblyFormat = [{ - $target attr-dict `:` type($target) - }]; - - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::transform::TransformRewriter &rewriter, - ::mlir::TilingInterface target, - ::mlir::transform::ApplyToEachResultList &results, - ::mlir::transform::TransformState &state); + $target attr-dict `:` functional-type(operands, results) }]; } - //===----------------------------------------------------------------------===// // DecomposeInterfaceOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index be875297fc93ca..905875ae43ce8a 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2083,16 +2083,31 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter, // ConvertToLoopsOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne( - transform::TransformRewriter &rewriter, TilingInterface target, - transform::ApplyToEachResultList &results, - transform::TransformState &state) { - rewriter.setInsertionPoint(target); - FailureOr<SmallVector<scf::ForOp>> loops = - scf::lowerToLoopsUsingSCFForOp(rewriter, target); - if (failed(loops)) - return emitDefaultDefiniteFailure(target); - rewriter.eraseOp(target); +DiagnosedSilenceableFailure 
+transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + SmallVector<Operation *> loops; + for (Operation *target : state.getPayloadOps(getTarget())) { + auto tilingOp = dyn_cast<TilingInterface>(*target); + if (!target) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() + << "expected the payload to implement TilingInterface"; + diag.attachNote(target->getLoc()) << "payload op"; + return diag; + } + rewriter.setInsertionPoint(target); + FailureOr<SmallVector<scf::ForOp>> generatedLoops = + scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp); + if (failed(generatedLoops)) + return emitDefaultDefiniteFailure(target); + for (scf::ForOp &loop : *generatedLoops) { + loops.push_back(loop.getOperation()); + } + rewriter.eraseOp(target); + } + results.set(cast<OpResult>(getResult()), loops); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir index 1b2c553b25ded0..8cbee3cbb758b2 100644 --- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir @@ -11,7 +11,8 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.structured.convert_to_loops %matmul : !transform.any_op + %0 = transform.structured.convert_to_loops %matmul + : (!transform.any_op) -> (!transform.any_op) transform.yield } } @@ -37,6 +38,57 @@ module attributes {transform.with_named_sequence} { // ----- +func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>, + %arg2 : memref<?x?xf32>, %arg3 : memref<?xf32>, %arg4 : 
memref<?xf32>) { + linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>) + outs(%arg2 : memref<?x?xf32>) + linalg.matvec ins(%arg0, %arg3 : memref<?x?xf32>, memref<?xf32>) + outs(%arg4 : memref<?xf32>) + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %linalg_ops = transform.structured.match interface{TilingInterface} in %arg1 + : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.convert_to_loops %linalg_ops + : (!transform.any_op) -> (!transform.any_op) + %1:5 = transform.split_handle %0 + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-LABEL: func @gemm +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32> +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: memref<?xf32> +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: memref<?xf32> +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]] +// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]] +// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]] +// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]] +// CHECK-DAG: %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[IV2]]] +// CHECK-DAG: %[[RHS:.+]] = memref.load %[[ARG1]][%[[IV2]], %[[IV1]]] +// CHECK-DAG: %[[OUT:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]]] +// CHECK: %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]] +// CHECK: %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]] +// CHECK: memref.store %[[ADDF]], %[[ARG2]][%[[IV0]], %[[IV1]]] +// CHECK: 
scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]] +// CHECK: scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]] +// CHECK-DAG: %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV3]], %[[IV4]]] +// CHECK-DAG: %[[RHS:.+]] = memref.load %[[ARG3]][%[[IV4]]] +// CHECK-DAG: %[[OUT:.+]] = memref.load %[[ARG4]][%[[IV3]]] +// CHECK: %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]] +// CHECK: %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]] +// CHECK: memref.store %[[ADDF]], %[[ARG4]][%[[IV3]]] + +// ----- + func.func @indexed_generic(%arg0 : memref<200x300xi32>, %arg1 : memref<300xi16>, %arg2 : memref<200xi8>, %arg3 : memref<300x200xi64>) { linalg.generic { @@ -66,7 +118,8 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.structured.convert_to_loops %generic : !transform.any_op + %0 = transform.structured.convert_to_loops %generic + : (!transform.any_op) -> (!transform.any_op) transform.yield } } @@ -111,7 +164,8 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.structured.convert_to_loops %conv : !transform.any_op + %0 = transform.structured.convert_to_loops %conv + : (!transform.any_op) -> (!transform.any_op) transform.yield } } @@ -165,7 +219,8 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %pool = transform.structured.match ops{["linalg.pooling_nhwc_max"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.structured.convert_to_loops %pool : !transform.any_op + %0 = transform.structured.convert_to_loops 
%pool + : (!transform.any_op) -> (!transform.any_op) transform.yield } } @@ -216,7 +271,8 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %map = transform.structured.match ops{["linalg.map"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.structured.convert_to_loops %map : !transform.any_op + %0 = transform.structured.convert_to_loops %map + : (!transform.any_op) -> (!transform.any_op) transform.yield } } @@ -248,7 +304,8 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %transpose = transform.structured.match ops{["linalg.transpose"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.structured.convert_to_loops %transpose : !transform.any_op + %0 = transform.structured.convert_to_loops %transpose + : (!transform.any_op) -> (!transform.any_op) transform.yield } } @@ -285,7 +342,8 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.structured.convert_to_loops %reduce : !transform.any_op + %0 = transform.structured.convert_to_loops %reduce + : (!transform.any_op) -> (!transform.any_op) transform.yield } } @@ -322,7 +380,8 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %broadcast = transform.structured.match ops{["linalg.broadcast"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.structured.convert_to_loops %broadcast : !transform.any_op + %0 = transform.structured.convert_to_loops %broadcast + : (!transform.any_op) -> (!transform.any_op) transform.yield } } _______________________________________________ 
llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits