https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/114017
Draft. Do not review yet. Depends on #113999.

From e88daefdd87df823bbfbe34cb44cb9ef00bd4e62 Mon Sep 17 00:00:00 2001
From: Matthias Springer <msprin...@nvidia.com>
Date: Tue, 29 Oct 2024 09:51:11 +0100
Subject: [PATCH] [mlir][bufferization] Add support for non-unique `func.return`

---
 .../FuncBufferizableOpInterfaceImpl.cpp       |  75 +++-----
 .../Transforms/OneShotModuleBufferize.cpp     | 179 +++++++++++++-----
 .../one-shot-module-bufferize-invalid.mlir    |  22 +--
 .../Transforms/one-shot-module-bufferize.mlir |  24 +++
 4 files changed, 190 insertions(+), 110 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index a372e87d8335f1..2d9cca0c5816e7 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -41,18 +41,13 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
 #endif // NDEBUG
 }
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
-  func::ReturnOp returnOp;
-  for (Block &b : funcOp.getBody()) {
-    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
-    }
-  }
-  return returnOp;
+/// Return all top-level func.return ops in the given function.
+static SmallVector<func::ReturnOp> getReturnOps(FuncOp funcOp) {
+  SmallVector<func::ReturnOp> result;
+  for (Block &b : funcOp.getBody())
+    if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
+      result.push_back(returnOp);
+  return result;
 }
 
 /// Return the index-th bufferized function argument type. This assumes that the
@@ -372,15 +367,6 @@ struct FuncOpInterface
         getBufferType(op, value, options, invocationStack);
   }
 
-  LogicalResult verifyAnalysis(Operation *op,
-                               const AnalysisState &state) const {
-    auto funcOp = cast<func::FuncOp>(op);
-    // TODO: func.func with multiple returns are not supported.
-    if (!getAssumedUniqueReturnOp(funcOp) && !funcOp.isExternal())
-      return op->emitOpError("op without unique func.return is not supported");
-    return success();
-  }
-
   /// Rewrite function bbArgs and return values into buffer form. This function
   /// bufferizes the function signature and the ReturnOp. When the entire
   /// function body has been bufferized, function return types can be switched
@@ -427,41 +413,38 @@ struct FuncOpInterface
       return success();
     }
 
-    // TODO: Support functions with multiple returns.
-    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-    assert(returnOp && "expected func with single return op");
-    assert(returnOp->getNumOperands() == retTypes.size() &&
-           "incorrect number of return values");
-    Location loc = returnOp.getLoc();
-
     // 1. Bufferize every block.
     for (Block &block : funcOp.getBody())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options)))
        return failure();
 
-    // 2. Bufferize all operands of the return op.
-    SmallVector<Value> returnValues;
-    for (auto [returnVal, bufferizedType] :
-         llvm::zip_equal(returnOp->getOperands(), retTypes)) {
-      auto tensorType = dyn_cast<TensorType>(returnVal.getType());
-      rewriter.setInsertionPoint(returnOp);
-
-      // If not a tensor type just forward it.
-      if (!tensorType) {
-        returnValues.push_back(returnVal);
-        continue;
+    // 2. Bufferize the operands of all func.return ops.
+    for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
+      assert(returnOp->getNumOperands() == retTypes.size() &&
+             "incorrect number of return values");
+      SmallVector<Value> returnValues;
+      for (auto [returnVal, bufferizedType] :
+           llvm::zip_equal(returnOp->getOperands(), retTypes)) {
+        auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+        rewriter.setInsertionPoint(returnOp);
+
+        // If not a tensor type just forward it.
+        if (!tensorType) {
+          returnValues.push_back(returnVal);
+          continue;
+        }
+
+        // Note: If `inferFunctionResultLayout = true`, casts are later folded
+        // away.
+        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+            returnOp.getLoc(), bufferizedType, returnVal);
+        returnValues.push_back(toMemrefOp);
       }
-
-      // Note: If `inferFunctionResultLayout = true`, casts are later folded
-      // away.
-      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
-          loc, bufferizedType, returnVal);
-      returnValues.push_back(toMemrefOp);
+      returnOp.getOperandsMutable().assign(returnValues);
    }
-    returnOp.getOperandsMutable().assign(returnValues);
-
    // 3. Set the new function type.
    funcOp.setType(newFuncType);
    return success();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 0a4072605c265f..e4635ebd78d8f8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -86,18 +86,13 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
   return state.addExtension<FuncAnalysisState>();
 }
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
-  func::ReturnOp returnOp;
-  for (Block &b : funcOp.getBody()) {
-    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
-    }
-  }
-  return returnOp;
+/// Return all top-level func.return ops in the given function.
+static SmallVector<func::ReturnOp> getReturnOps(FuncOp funcOp) {
+  SmallVector<func::ReturnOp> result;
+  for (Block &b : funcOp.getBody())
+    if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
+      result.push_back(returnOp);
+  return result;
 }
 
 namespace {
@@ -146,24 +141,80 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
     return success();
   }
 
-  // Support only single return-terminated block in the function.
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  assert(returnOp && "expected func with single return op");
-
-  for (OpOperand &returnVal : returnOp->getOpOperands())
-    if (isa<RankedTensorType>(returnVal.get().getType()))
-      for (BlockArgument bbArg : funcOp.getArguments())
-        if (isa<RankedTensorType>(bbArg.getType())) {
-          int64_t returnIdx = returnVal.getOperandNumber();
-          int64_t bbArgIdx = bbArg.getArgNumber();
-          if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
-            funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
-            if (state.getOptions().testAnalysisOnly)
-              annotateEquivalentReturnBbArg(returnVal, bbArg);
+  // Find all func.return ops.
+  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  assert(!returnOps.empty() && "expected at least one ReturnOp");
+
+  // Build alias sets. Merge all aliases from all func.return ops.
+  for (BlockArgument bbArg : funcOp.getArguments()) {
+    if (isa<RankedTensorType>(bbArg.getType())) {
+      int64_t bbArgIdx = bbArg.getArgNumber();
+      // Store aliases in a set, so that we don't add the same alias twice.
+      SetVector<int64_t> aliases;
+      for (func::ReturnOp returnOp : returnOps) {
+        for (OpOperand &returnVal : returnOp->getOpOperands()) {
+          if (isa<RankedTensorType>(returnVal.get().getType())) {
+            int64_t returnIdx = returnVal.getOperandNumber();
+            if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
+              aliases.insert(returnIdx);
          }
-          if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
-            funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
        }
+      }
+      for (int64_t alias : aliases)
+        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
+    }
+  }
+
+  // Build equivalence sets.
+  // Helper function that finds an equivalent block argument index for the
+  // given OpOperand. Return std::nullopt if no equivalent block argument could
+  // be found.
+  auto findEquivalentBlockArgIdx =
+      [&](OpOperand &opOperand) -> std::optional<int64_t> {
+    Value v = opOperand.get();
+    if (!isa<TensorType>(v.getType()))
+      return std::nullopt;
+    for (BlockArgument bbArg : funcOp.getArguments()) {
+      if (isa<RankedTensorType>(bbArg.getType())) {
+        if (state.areEquivalentBufferizedValues(v, bbArg)) {
+          if (state.getOptions().testAnalysisOnly)
+            annotateEquivalentReturnBbArg(opOperand, bbArg);
+          return bbArg.getArgNumber();
+        }
+      }
+    }
+    return std::nullopt;
+  };
+
+  int64_t numResults = returnOps.front()->getNumOperands();
+  for (int64_t i = 0; i < numResults; ++i) {
+    // Find the equivalent block argument index for the i-th operand of the
+    // first func.return op.
+    std::optional<int64_t> maybeEquiv =
+        findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
+    if (!maybeEquiv.has_value())
+      continue;
+    int64_t bbArgIdx = *maybeEquiv;
+    bool allEquiv = true;
+
+    // Check if all other func.return ops have the same equivalent block
+    // argument for the i-th operand. In contrast to aliasing information,
+    // which is just "merged", equivalence information must match across all
+    // func.return ops.
+    for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
+      std::optional<int64_t> maybeEquiv =
+          findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
+      if (maybeEquiv != bbArgIdx) {
+        allEquiv = false;
+        break;
+      }
+    }
+
+    // All func.return ops have the same equivalent block argument for the
+    // i-th operand.
+    if (allEquiv)
+      funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
+  }
 
   return success();
 }
@@ -299,14 +350,6 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
   WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
-    if (!funcOp.getBody().empty()) {
-      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-      if (!returnOp)
-        return funcOp->emitError()
-               << "cannot bufferize a FuncOp with tensors and "
-                  "without a unique ReturnOp";
-    }
-
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
     return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
@@ -342,6 +385,42 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   return success();
 }
 
+/// Helper function that extracts the source from a memref.cast. If the given
+/// value is not a memref.cast result, simply returns the given value.
+static Value unpackCast(Value v) {
+  auto castOp = v.getDefiningOp<memref::CastOp>();
+  if (!castOp)
+    return v;
+  return castOp.getSource();
+}
+
+/// Helper function that returns the return types (skipping casts) of the given
+/// func.return ops. This function returns as many types as the return ops have
+/// operands. If the type of the i-th operand is not the same for all
+/// func.return ops, then the i-th returned type is an "empty" type.
+static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
+  assert(!returnOps.empty() && "expected at least one ReturnOp");
+  int numOperands = returnOps.front()->getNumOperands();
+
+  // Helper function that unpacks memref.cast ops and returns the type.
+  auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };
+
+  SmallVector<Type> result;
+  for (int i = 0; i < numOperands; ++i) {
+    // Get the type of the i-th operand of the first func.return op.
+    Type t = getSourceType(returnOps.front()->getOperand(i));
+
+    // Check if all other func.return ops have a matching operand type.
+    for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
+      if (getSourceType(returnOps[j]->getOperand(i)) != t)
+        t = Type();
+
+    result.push_back(t);
+  }
+
+  return result;
+}
+
 /// Fold return values that are memref casts and update function return types.
 ///
 /// During FuncOp bufferization, the exact type of the returned memrefs (if any)
@@ -350,21 +429,33 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
 /// entire function body, a more concise memref type can potentially be used for
 /// the return type of the function.
 static void foldMemRefCasts(func::FuncOp funcOp) {
+  // There is nothing to do for bodiless ops.
   if (funcOp.getBody().empty())
     return;
 
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  SmallVector<Type> resultTypes;
+  // Compute the common result types of all func.return ops.
+  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  SmallVector<Type> resultTypes = getReturnTypes(returnOps);
 
-  for (OpOperand &operand : returnOp->getOpOperands()) {
-    if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
-      operand.set(castOp.getSource());
-      resultTypes.push_back(castOp.getSource().getType());
-    } else {
-      resultTypes.push_back(operand.get().getType());
+  // Remove direct casts.
+  for (func::ReturnOp returnOp : returnOps) {
+    for (OpOperand &operand : returnOp->getOpOperands()) {
+      // Only fold the cast if a common result type was found.
+      if (resultTypes[operand.getOperandNumber()]) {
+        operand.set(unpackCast(operand.get()));
+      }
    }
  }
 
+  // Fill in the missing result types that were not the same among all
+  // func.return ops.
+  for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
+    if (resultTypes[i])
+      continue;
+    resultTypes[i] = funcOp.getFunctionType().getResult(i);
+  }
+
+  // Update the function type.
   auto newFuncType = FunctionType::get(
       funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
   funcOp.setType(newFuncType);
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index 2829eafb7c1c59..f3da82cc0064d4 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -1,24 +1,5 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics
 
-// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
-func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
-    -> (tensor<f32>, tensor<f32>)
-{
-  cf.cond_br %cond1, ^bb1, ^bb2
-
-  ^bb1:
-    %T:2 = scf.if %cond2 -> (tensor<f32>, tensor<f32>) {
-      scf.yield %t1, %t2 : tensor<f32>, tensor<f32>
-    } else {
-      scf.yield %t2, %t1 : tensor<f32>, tensor<f32>
-    }
-    return %T#0, %T#1 : tensor<f32>, tensor<f32>
-  ^bb2:
-    return %t2, %t1 : tensor<f32>, tensor<f32>
-}
-
-// -----
-
 // expected-error @-3 {{expected callgraph to be free of circular dependencies}}
 
 func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
@@ -160,7 +141,8 @@ func.func @regression_scf_while() {
 
 // -----
 
-// expected-error @below{{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
+// expected-error @below{{could not infer buffer type of block argument}}
+// expected-error @below{{failed to bufferize op}}
 func.func @func_multiple_yields(%t: tensor<5xf32>) -> tensor<5xf32> {
   func.return %t : tensor<5xf32>
 ^bb1(%arg1 : tensor<5xf32>):
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index d31b43477beb9f..4f10ffea561aa8 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -722,3 +722,27 @@ func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> {
   %0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>)
   return %0 : memref<5xf32>
 }
+
+// -----
+
+// The two func.return operands have different types after bufferization. Make
+// sure that memref.cast ops are inserted.
+
+// CHECK-LABEL: func @result_type_mismatch({{.*}}) -> memref<5xf32, strided<[?], offset: ?>>
+func.func @result_type_mismatch(%c: i1) -> tensor<5xf32> {
+  // CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<10xf32>
+  %t = tensor.empty() : tensor<10xf32>
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  // CHECK: %[[m0:.*]] = memref.subview %[[alloc]][0] [5] [2] : memref<10xf32> to memref<5xf32, strided<[2]>>
+  // CHECK: %[[cast0:.*]] = memref.cast %[[m0]] : memref<5xf32, strided<[2]>> to memref<5xf32, strided<[?], offset: ?>>
+  %0 = tensor.extract_slice %t[0][5][2] : tensor<10xf32> to tensor<5xf32>
+  // CHECK: return %[[cast0]] : memref<5xf32, strided<[?], offset: ?>>
+  return %0 : tensor<5xf32>
+^bb2:
+  // CHECK: %[[m1:.*]] = memref.subview %[[alloc]][2] [5] [1] : memref<10xf32> to memref<5xf32, strided<[1], offset: 2>>
+  // CHECK: %[[cast1:.*]] = memref.cast %[[m1]] : memref<5xf32, strided<[1], offset: 2>> to memref<5xf32, strided<[?], offset: ?>>
+  %1 = tensor.extract_slice %t[2][5][1] : tensor<10xf32> to tensor<5xf32>
+  // CHECK: return %[[cast1]] : memref<5xf32, strided<[?], offset: ?>>
+  return %1 : tensor<5xf32>
+}
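For context (an editorial illustration, not part of the patch): the pattern that was previously rejected with "op without unique func.return is not supported" is a function body with more than one return-terminated block. A minimal sketch of such IR, with illustrative names:

  func.func @select_tensor(%c: i1, %a: tensor<f32>, %b: tensor<f32>) -> tensor<f32> {
    cf.cond_br %c, ^bb1, ^bb2
  ^bb1:
    return %a : tensor<f32>
  ^bb2:
    return %b : tensor<f32>
  }

With this patch, the analysis merges aliasing information across both returns (result 0 may alias %a or %b), but it records no equivalent function argument for result 0, because the two returns yield different block arguments and equivalence must match across all func.return ops.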