Author: Abhishek Varma Date: 2024-09-30T14:51:23+05:30 New Revision: b3cdd66549a17e8ab83b23117d0a1fc9feb50534
URL: https://github.com/llvm/llvm-project/commit/b3cdd66549a17e8ab83b23117d0a1fc9feb50534 DIFF: https://github.com/llvm/llvm-project/commit/b3cdd66549a17e8ab83b23117d0a1fc9feb50534.diff LOG: Revert "[MLIR][TilingInterface] Extend consumer fusion for multi-use of produ…" This reverts commit b8c974f09391d78035928c599a911009bbe49e85. Added: Modified: mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir Removed: ################################################################################ diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 50cfd29e6bf907..7cfd772a72b175 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1481,29 +1481,21 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) { /// failure otherwise. static FailureOr<OpOperand *> getConsumerFromUses(Value val, Block *containingOpBlock) { - // Check that the value has exactly one use which isn't a scf.yield or a - // tensor.parallel_insert_slice op. - OpOperand *operand = nullptr; - for (OpOperand &opOperand : val.getUses()) { - Operation *consumerOp = opOperand.getOwner(); - if (isa<scf::YieldOp, tensor::ParallelInsertSliceOp>(consumerOp)) - continue; - if (operand) - return failure(); - // TODO: We have to init result of consumer before scf.for, use - // DestinationStyleOpInterface to get result shape from init for now. - // Add support for other op such as op has InferTypeOpInterface. - if (!isa<TilingInterface>(consumerOp) || - !isa<DestinationStyleOpInterface>(consumerOp)) - return failure(); - if (containingOpBlock != consumerOp->getBlock()) - return failure(); - operand = &opOperand; - } - - if (operand) - return operand; - return failure(); + // Step 1. Check that the value has exactly one use. + if (!llvm::hasSingleElement(val.getUses())) + return failure(); + // Step 2. Get uses. + OpOperand &operand = (*val.getUses().begin()); + Operation *consumerOp = operand.getOwner(); + // TODO: We have to init result of consumer before scf.for, use + // DestinationStyleOpInterface to get result shape from init for now. + // Add support for other op such as op has InferTypeOpInterface. + if (!isa<TilingInterface>(consumerOp) || + !isa<DestinationStyleOpInterface>(consumerOp)) + return failure(); + if (containingOpBlock != consumerOp->getBlock()) + return failure(); + return &operand; } /// Find the perfectly nested loops outside of given loop(included) sorted from diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index f5f703d95e2d5b..fdefdcc453ae7a 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -437,74 +437,3 @@ module attributes {transform.with_named_sequence} { // CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 : // CHECK: } // CHECK: return %[[LOOP_RESULT1]]#1 : - -// ----- - -// This test case checks fusion of consumer even if the producer has multiple uses. -// The multiple uses of the producer essentially means that besides the consumer -// op in concern, the only other uses of the producer are allowed in :- -// 1. scf.yield -// 2. tensor.parallel_insert_slice - -module { - module { - func.func @fuse_consumer_for_multi_use_producer(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) { - %c0 = arith.constant 0 : index - %c64 = arith.constant 64 : index - %c256 = arith.constant 256 : index - %cst = arith.constant 0.000000e+00 : f32 - %0 = tensor.empty() : tensor<256x256xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> - %2:2 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %1, %arg5 = %arg2) -> (tensor<256x256xf32>, tensor<256x256xf32>) { - %3 = scf.for %arg6 = %c0 to %c256 step %c64 iter_args(%arg7 = %arg4) -> (tensor<256x256xf32>) { - %extracted_slice = tensor.extract_slice %arg7[%arg3, %arg6] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32> - %extracted_slice_0 = tensor.extract_slice %arg0[%arg3, 0] [64, 512] [1, 1] : tensor<256x512xf32> to tensor<64x512xf32> - %extracted_slice_1 = tensor.extract_slice %arg1[0, %arg6] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32> - %5 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice : tensor<64x64xf32>) -> tensor<64x64xf32> - %inserted_slice = tensor.insert_slice %5 into %arg7[%arg3, %arg6] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<256x256xf32> - scf.yield %inserted_slice : tensor<256x256xf32> - } - %4 = linalg.add ins(%3, %arg5 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> - scf.yield %3, %4 : tensor<256x256xf32>, tensor<256x256xf32> - } - return %2#0, %2#1 : tensor<256x256xf32>, tensor<256x256xf32> - } - } - module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } - } -} -// CHECK: func.func @fuse_consumer_for_multi_use_producer( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32> -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32> -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32> -// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32> -// CHECK: %[[dest1:.*]] = linalg.fill -// CHECK-SAME: outs(%[[dest0]] : -// CHECK: %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV1:.*]] = %[[C0]] -// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[dest1]], %[[SECOND_OUT_ARG1:.*]] = %[[ARG2]]) -// CHECK-SAME: { -// CHECK: %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]] -// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[dest0]]) -// CHECK-SAME: { -// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1] -// CHECK: %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1] -// CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul -// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : -// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG1]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add -// CHECK-SAME: ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] : -// CHECK-SAME: outs(%[[ADD_OUT_SLICE]] : -// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] : -// CHECK: } -// CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 : -// CHECK: } -// CHECK: return %[[LOOP_RESULT1]]#0, %[[LOOP_RESULT1]]#1 : _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits