llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir-linalg Author: Matthias Springer (matthias-springer) <details> <summary>Changes</summary> The pass used to access erased operations and block arguments in the type converter. That is no longer supported in the new conversion driver. --- Full diff: https://github.com/llvm/llvm-project/pull/152912.diff 2 Files Affected: - (modified) mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp (+32-2) - (modified) mlir/test/Dialect/Linalg/detensorize_0d.mlir (+4-3) ``````````diff diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index 830905495e759..221f95a8d8f33 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -458,6 +458,22 @@ struct LinalgDetensorize } }; + /// A listener that forwards notifyBlockErased and notifyOperationErased to + /// the given callbacks. + struct CallbackListener : public RewriterBase::Listener { + CallbackListener(std::function<void(Operation *op)> onOperationErased, + std::function<void(Block *block)> onBlockErased) + : onOperationErased(onOperationErased), onBlockErased(onBlockErased) {} + + void notifyBlockErased(Block *block) override { onBlockErased(block); } + void notifyOperationErased(Operation *op) override { + onOperationErased(op); + } + + std::function<void(Operation *op)> onOperationErased; + std::function<void(Block *block)> onBlockErased; + }; + void runOnOperation() override { MLIRContext *context = &getContext(); DetensorizeTypeConverter typeConverter; @@ -551,8 +567,22 @@ struct LinalgDetensorize populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter, shouldConvertBranchOperand); - if (failed( - applyFullConversion(getOperation(), target, std::move(patterns)))) + ConversionConfig config; + auto onOperationErased = [&](Operation *op) { + opsToDetensor.erase(op); + detensorableBranchOps.erase(op); + }; + auto onBlockErased = [&](Block *block) { + for (BlockArgument arg : block->getArguments()) { + blockArgsToDetensor.erase(arg); + } + }; + CallbackListener listener(onOperationErased, onBlockErased); + + config.listener = &listener; + config.allowPatternRollback = false; + if (failed(applyFullConversion(getOperation(), target, std::move(patterns), + config))) signalPassFailure(); RewritePatternSet canonPatterns(context); diff --git a/mlir/test/Dialect/Linalg/detensorize_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir index 74931cb0830bc..5c29b04630cad 100644 --- a/mlir/test/Dialect/Linalg/detensorize_0d.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_0d.mlir @@ -53,10 +53,11 @@ func.func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tenso } // CHECK-LABEL: func @detensor_op_sequence // CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>) -// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]] +// CHECK-DAG: %[[arg1_val_1:.*]] = tensor.extract %[[arg1]] // CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]] -// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]] -// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val]], %[[detensored_res]] +// CHECK-DAG: %[[arg1_val_2:.*]] = tensor.extract %[[arg1]] +// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val_2]], %[[arg2_val]] +// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val_1]], %[[detensored_res]] // CHECK: %[[detensored_res3:.*]] = arith.divf %[[detensored_res]], %[[detensored_res2]] // CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]] // CHECK: return %[[new_tensor_res]] `````````` </details> https://github.com/llvm/llvm-project/pull/152912 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits