Author: Sean Silva Date: 2020-12-15T12:50:56-08:00 New Revision: caf4f2e372a7a4d5d8b5a8733e44f002c6dee0d5
URL: https://github.com/llvm/llvm-project/commit/caf4f2e372a7a4d5d8b5a8733e44f002c6dee0d5 DIFF: https://github.com/llvm/llvm-project/commit/caf4f2e372a7a4d5d8b5a8733e44f002c6dee0d5.diff LOG: [mlir] Handle unknown ops in dynamic_tensor_from_elements bufferization Due to how the conversion infra works, the "clone" call that this pattern was using required all the cloned ops to be immediately legalized as part of this dialect conversion invocation. That was previously working due to a couple of factors: - In the test case, there was scf.if, which we happen to mark as legal as part of marking the entire SCF dialect as legal for the scf.parallel we generate here. - Originally, this test case had std.extract_element in the body, which we happened to have a pattern for in this pass. After I migrated that to `tensor.extract` (which removed the tensor.extract bufferization from here), I hacked this up to use `std.dim`, which we still have patterns for in this pass. This patch updates the test case to use a truly opaque op, `test.source`, that properly stresses this aspect of the pattern. 
(this also removes a stray dependency on the `tensor` dialect that I must have left behind as part of my hacking this pass up when migrating to `tensor.extract`) Differential Revision: https://reviews.llvm.org/D93262 Added: Modified: mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp mlir/test/Dialect/Standard/bufferize.mlir Removed: ################################################################################ diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp index 6691355d232c..a84934b0ebb8 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp @@ -15,7 +15,6 @@ #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/Transforms/DialectConversion.h" @@ -70,18 +69,29 @@ class BufferizeDynamicTensorFromElementsOp upperBounds.push_back(upperBound); } - // Generate tensor elements with a parallel loop. - rewriter.create<scf::ParallelOp>( - loc, lowerBounds, upperBounds, steps, - [&](OpBuilder &b, Location loc, ValueRange ivs) { - BlockAndValueMapping mapping; - mapping.map(op.body().getArguments(), ivs); - for (auto &nestedOp : op.getBody()->without_terminator()) - b.clone(nestedOp, mapping); - auto yieldOp = cast<YieldOp>(op.getBody()->getTerminator()); - b.create<StoreOp>(loc, mapping.lookup(yieldOp.value()), result, ivs); - b.create<scf::YieldOp>(loc); - }); + // Generate tensor elements with a parallel loop that stores into + // each element of the resulting memref. + // + // This is a bit tricky. We cannot simply clone the ops because when an op + // is cloned, it must be legalized. 
However, we want to allow arbitrary ops + // in the body that we don't necessarily have legalization patterns for as + // part of this dialect conversion invocation. + // + // To accomplish this, we use mergeBlockBefore to "move" this op's body + // into the scf.parallel's body. + auto parallel = + rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps); + Block *parallelBody = parallel.getBody(); + rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(), + parallelBody->getArguments()); + // Replace the inlined yield op with a store op. The scf.parallel's builder + // already populated an scf.yield at the end, so we don't need to worry + // about creating that. + Operation *elementYield = parallelBody->getTerminator()->getPrevNode(); + rewriter.setInsertionPointAfter(elementYield); + rewriter.replaceOpWithNewOp<StoreOp>(elementYield, + elementYield->getOperands()[0], result, + parallelBody->getArguments()); rewriter.replaceOp(op, {result}); return success(); @@ -168,7 +178,6 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> { target.addLegalDialect<StandardOpsDialect>(); target.addLegalDialect<scf::SCFDialect>(); - target.addLegalDialect<tensor::TensorDialect>(); populateStdBufferizePatterns(context, typeConverter, patterns); target.addIllegalOp<DynamicTensorFromElementsOp, TensorCastOp, diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir index 75ff2a9d78f0..8ae10ccf0f3b 100644 --- a/mlir/test/Dialect/Standard/bufferize.mlir +++ b/mlir/test/Dialect/Standard/bufferize.mlir @@ -123,20 +123,20 @@ func @tensor_from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> { return %0 : tensor<2xindex> } -// The dynamic_tensor_from_elements op clones each op in its body. -// Make sure that regions nested within such ops are recursively converted. 
-// CHECK-LABEL: func @recursively_convert_cloned_regions -func @recursively_convert_cloned_regions(%arg0: tensor<*xf32>, %arg1: index, %arg2: i1) -> tensor<?xindex> { - %tensor = dynamic_tensor_from_elements %arg1 { +// The dynamic_tensor_from_elements op needs to put its body into the +// resulting scf.parallel. To handle unknown ops in the body, it cannot clone +// the body because that would require the cloned ops to be legalized +// immediately, which is usually not possible since they might be from various +// other dialects. +// +// CHECK-LABEL: func @unknown_ops_in_body +func @unknown_ops_in_body(%arg0: index) -> tensor<?xindex> { + // CHECK-NOT: dynamic_tensor_from_elements + %tensor = dynamic_tensor_from_elements %arg0 { ^bb0(%iv: index): - %48 = scf.if %arg2 -> (index) { - scf.yield %iv : index - } else { - // CHECK-NOT: dim{{.*}}tensor - %50 = dim %arg0, %iv : tensor<*xf32> - scf.yield %50 : index - } - yield %48 : index + // CHECK: test.source + %0 = "test.source"() : () -> index + yield %0 : index } : tensor<?xindex> return %tensor : tensor<?xindex> } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits