Author: Michele Scuttari Date: 2025-05-22T09:12:32+02:00 New Revision: 06eb7d7fe399ab4a5fbd289afa35f12ba008685f
URL: https://github.com/llvm/llvm-project/commit/06eb7d7fe399ab4a5fbd289afa35f12ba008685f DIFF: https://github.com/llvm/llvm-project/commit/06eb7d7fe399ab4a5fbd289afa35f12ba008685f.diff LOG: Revert "[MLIR] Add bufferization state class to OneShotBufferization pass (#1…" This reverts commit 67fc1660d987d145a68a3c3dfbccfbe4b91fba59. Added: Modified: mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp Removed: ################################################################################ diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index 43c97d57e1834..cb6ef8bc17220 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -578,20 +578,6 @@ class AnalysisState { insideMutuallyExclusiveRegionsCache; }; -/// BufferizationState provides information about the state of the IR during the -/// bufferization process. -class BufferizationState { -public: - /// Get a reference to the collection of cached symbol tables. - SymbolTableCollection &getSymbolTables(); - -private: - /// The cached symbol tables. - /// The user is expected to update / invalidate the cached symbol tables if - /// the bufferized operation has the Symbol or SymbolTable traits. - SymbolTableCollection symbolTables; -}; - /// Create an AllocTensorOp for the given shaped value (memref or tensor). /// If `copy` is set, the shaped value is copied. Otherwise, a tensor with /// undefined contents is allocated. diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td index b599a9f053215..95022d7d665d2 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -426,8 +426,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> { /*retType=*/"::llvm::LogicalResult", /*methodName=*/"bufferize", /*args=*/(ins "::mlir::RewriterBase &":$rewriter, - "const ::mlir::bufferization::BufferizationOptions &":$options, - "::mlir::bufferization::BufferizationState &":$state), + "const ::mlir::bufferization::BufferizationOptions &":$options), /*methodBody=*/"", /*defaultImplementation=*/[{ llvm_unreachable("bufferize not implemented"); diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td index dafa4b9b183f2..7a1a701bea6dc 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -93,8 +93,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor", let extraClassDeclaration = [{ LogicalResult bufferize(RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state); + const BufferizationOptions &options); bool resultBufferizesToMemoryWrite(OpResult opResult, const AnalysisState &state); @@ -283,8 +282,7 @@ def Bufferization_MaterializeInDestinationOp let extraClassDeclaration = [{ LogicalResult bufferize(RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state); + const BufferizationOptions &options); bool bufferizesToMemoryRead(OpOperand &opOperand, const AnalysisState &state); @@ -377,8 +375,7 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor", } LogicalResult bufferize(RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state); + const BufferizationOptions &options); }]; } @@ -461,8 +458,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [ //===------------------------------------------------------------------===// LogicalResult bufferize(RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { // to_tensor/to_buffer pairs fold away after bufferization. return success(); } @@ -554,8 +550,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [ } LogicalResult bufferize(RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state); + const BufferizationOptions &options); }]; let assemblyFormat = [{ diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h index c08bd6c436133..e5f3b6d571f43 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h @@ -29,7 +29,6 @@ class GlobalOp; } // namespace memref namespace bufferization { -class BufferizationState; /// A simple analysis that detects allocation operations. class BufferPlacementAllocs { @@ -123,14 +122,9 @@ class BufferPlacementTransformationBase { // Globals are created lazily at the top of the enclosing ModuleOp with pretty // names. Duplicates are avoided. FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp, - SymbolTableCollection &symbolTables, uint64_t alignment, Attribute memorySpace = {}); -void removeSymbol(Operation *op, BufferizationState &state); - -void insertSymbol(Operation *op, BufferizationState &state); - } // namespace bufferization } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h index 70e3defee0867..d5cb8d8eb673c 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -45,7 +45,6 @@ struct BufferizationStatistics { /// additional buffer copies or set "options.copyBeforeWrite = true". The /// general bufferization entry point is `runOneShotBufferize`. LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, - BufferizationState &bufferizationState, BufferizationStatistics *statistics = nullptr); /// Bufferize the signature of `block` and its callers (i.e., ops that have the diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h index 15189d2c1cb87..673027f76190d 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h @@ -270,7 +270,6 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, /// Run One-Shot Bufferize on the given op: Analysis + Bufferization LogicalResult runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options, - BufferizationState &state, BufferizationStatistics *statistics = nullptr); } // namespace bufferization diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h index 2cf801dd1d951..4e5f5e9c730fa 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h @@ -20,7 +20,6 @@ namespace bufferization { struct BufferizationStatistics; class OneShotAnalysisState; struct OneShotBufferizationOptions; -class BufferizationState; /// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in /// `state`. @@ -39,7 +38,6 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, /// will be inserted only to these FuncOps. llvm::LogicalResult bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options, - BufferizationState &state, BufferizationStatistics *statistics = nullptr); /// Remove bufferization attributes on every FuncOp arguments in the ModuleOp. @@ -52,7 +50,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp); llvm::LogicalResult runOneShotModuleBufferize( ModuleOp moduleOp, const bufferization::OneShotBufferizationOptions &options, - BufferizationState &state, BufferizationStatistics *statistics = nullptr); + BufferizationStatistics *statistics = nullptr); } // namespace bufferization } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 2eef0a06d0eb4..4f90fc8831bc6 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -30,7 +30,6 @@ namespace mlir { namespace bufferization { class AllocTensorOp; class OneShotAnalysisState; -class BufferizationState; } // namespace bufferization namespace linalg { diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp index f646326ffc58f..5e69a98db8f1e 100644 --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -24,8 +24,7 @@ struct ConstantOpInterface : public BufferizableOpInterface::ExternalModel<ConstantOpInterface, arith::ConstantOp> { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto constantOp = cast<arith::ConstantOp>(op); auto type = dyn_cast<RankedTensorType>(constantOp.getType()); @@ -47,8 +46,7 @@ struct ConstantOpInterface // Create global memory segment and replace tensor with memref pointing to // that memory segment. FailureOr<memref::GlobalOp> globalOp = - getGlobalFor(constantOp, state.getSymbolTables(), - options.bufferAlignment, memorySpace); + getGlobalFor(constantOp, options.bufferAlignment, memorySpace); if (failed(globalOp)) return failure(); memref::GlobalOp globalMemref = *globalOp; @@ -85,8 +83,7 @@ struct IndexCastOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto castOp = cast<arith::IndexCastOp>(op); auto resultTensorType = cast<TensorType>(castOp.getType()); @@ -134,8 +131,7 @@ struct SelectOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto selectOp = cast<arith::SelectOp>(op); Location loc = selectOp.getLoc(); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 14fa4c1ed8159..1fc34051680f1 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -125,10 +125,6 @@ void AnalysisState::resetCache() { insideMutuallyExclusiveRegionsCache.clear(); } -SymbolTableCollection &BufferizationState::getSymbolTables() { - return symbolTables; -} - Region *bufferization::getNextEnclosingRepetitiveRegion( Region *region, const BufferizationOptions &options) { assert(isRepetitiveRegion(region, options) && "expected repetitive region"); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index 91eccb0ab7430..ecd2ef15546a4 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -149,8 +149,7 @@ void mlir::bufferization::populateDynamicDimSizes( //===----------------------------------------------------------------------===// LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) { + const BufferizationOptions &options) { OpBuilder::InsertionGuard g(rewriter); Location loc = getLoc(); @@ -530,8 +529,7 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results, //===----------------------------------------------------------------------===// LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) { + const BufferizationOptions &options) { FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options); if (failed(buffer)) return failure(); @@ -578,8 +576,7 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand, LogicalResult MaterializeInDestinationOp::bufferize(RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) { + const BufferizationOptions &options) { bool tensorDest = isa<TensorType>(getDest().getType()); Value buffer; if (tensorDest) { @@ -864,8 +861,7 @@ void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results, } LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) { + const BufferizationOptions &options) { // Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary. (void)foldToBufferToTensorPair(rewriter, *this, options); // Note: The return value of `bufferize` indicates whether there was an error diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp index db1eb20512033..a1d7bb995fc73 100644 --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -83,8 +83,6 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, } auto payloadOps = state.getPayloadOps(getTarget()); - BufferizationState bufferizationState; - for (Operation *target : payloadOps) { if (!isa<ModuleOp, FunctionOpInterface>(target)) return emitSilenceableError() << "expected module or function target"; @@ -92,12 +90,10 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, if (options.bufferizeFunctionBoundaries) { if (!moduleOp) return emitSilenceableError() << "expected module target"; - if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options, - bufferizationState))) + if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options))) return emitSilenceableError() << "bufferization failed"; } else { - if (failed(bufferization::runOneShotBufferize(target, options, - bufferizationState))) + if (failed(bufferization::runOneShotBufferize(target, options))) return emitSilenceableError() << "bufferization failed"; } } @@ -166,7 +162,6 @@ class BufferizationTransformDialectExtension registerTransformOps< #define GET_OP_LIST #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" - >(); } }; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp index ff2c83d228dbb..c2e90764b1335 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp @@ -103,9 +103,8 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase( //===----------------------------------------------------------------------===// FailureOr<memref::GlobalOp> -bufferization::getGlobalFor(arith::ConstantOp constantOp, - SymbolTableCollection &symbolTables, - uint64_t alignment, Attribute memorySpace) { +bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, + Attribute memorySpace) { auto type = cast<RankedTensorType>(constantOp.getType()); auto moduleOp = constantOp->getParentOfType<ModuleOp>(); if (!moduleOp) @@ -128,7 +127,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, // Create a builder without an insertion point. We will insert using the // symbol table to guarantee unique names. OpBuilder globalBuilder(moduleOp.getContext()); - SymbolTable &symbolTable = symbolTables.getSymbolTable(moduleOp); + SymbolTable symbolTable(moduleOp); // Create a pretty name. SmallString<64> buf; @@ -159,19 +158,3 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, global->moveBefore(&moduleOp.front()); return global; } - -namespace mlir::bufferization { -void removeSymbol(Operation *op, BufferizationState &state) { - SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable( - op->getParentWithTrait<OpTrait::SymbolTable>()); - - symbolTable.remove(op); -} - -void insertSymbol(Operation *op, BufferizationState &state) { - SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable( - op->getParentWithTrait<OpTrait::SymbolTable>()); - - symbolTable.insert(op); -} -} // namespace mlir::bufferization diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 38de525316f7a..824b505517119 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -161,13 +161,10 @@ struct OneShotBufferizePass return signalPassFailure(); } - BufferizationState state; - BufferizationStatistics statistics; ModuleOp moduleOp = getOperation(); if (opt.bufferizeFunctionBoundaries) { - if (failed( - runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) { + if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) { signalPassFailure(); return; } @@ -178,7 +175,7 @@ struct OneShotBufferizePass "'bufferize-function-boundaries'"); return signalPassFailure(); } - if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) { + if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) { signalPassFailure(); return; } @@ -278,7 +275,6 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener { LogicalResult bufferization::bufferizeOp(Operation *op, const BufferizationOptions &options, - BufferizationState &bufferizationState, BufferizationStatistics *statistics) { if (options.copyBeforeWrite) { AnalysisState state(options); @@ -335,8 +331,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op, << "//===-------------------------------------------===//\n" << "IR after bufferizing: " << nextOp->getName() << "\n"); rewriter.setInsertionPoint(nextOp); - if (failed( - bufferizableOp.bufferize(rewriter, options, bufferizationState))) { + if (failed(bufferizableOp.bufferize(rewriter, options))) { LLVM_DEBUG(llvm::dbgs() << "failed to bufferize\n" << "//===-------------------------------------------===//\n"); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index 080796208bfc1..755477713668e 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -239,8 +239,7 @@ struct CallOpInterface /// All function arguments are writable. It is the responsibility of the /// CallOp to insert buffer copies where necessary. LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { func::CallOp callOp = cast<func::CallOp>(op); // 1. Compute the result types of the new CallOp. @@ -350,8 +349,7 @@ struct ReturnOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { #ifndef NDEBUG auto returnOp = cast<func::ReturnOp>(op); assert(isa<FuncOp>(returnOp->getParentOp()) && @@ -420,8 +418,7 @@ struct FuncOpInterface /// All function bbArgs are writable unless they are explicitly marked as /// read-only. Callers must insert copies when needed. LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto funcOp = cast<FuncOp>(op); FunctionType funcType = funcOp.getFunctionType(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index de820e9c8f8af..6e93b36d2d5a2 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -1365,9 +1365,10 @@ LogicalResult bufferization::analyzeOp(Operation *op, return success(!failedAnalysis); } -LogicalResult bufferization::runOneShotBufferize( - Operation *op, const OneShotBufferizationOptions &options, - BufferizationState &state, BufferizationStatistics *statistics) { +LogicalResult +bufferization::runOneShotBufferize(Operation *op, + const OneShotBufferizationOptions &options, + BufferizationStatistics *statistics) { // copy-before-write deactivates the analysis. It cannot be used together with // test-analysis-only. assert(!(options.copyBeforeWrite && options.testAnalysisOnly) && @@ -1390,5 +1391,5 @@ LogicalResult bufferization::runOneShotBufferize( // Bufferize the op and its nested ops. If options.copyBeforeWrite is set, // a new buffer copy is allocated every time a buffer is written to. - return bufferizeOp(op, options, state, statistics); + return bufferizeOp(op, options, statistics); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index 90ceea4d69680..a025da8635135 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -512,7 +512,7 @@ void mlir::bufferization::removeBufferizationAttributesInModule( LogicalResult mlir::bufferization::bufferizeModuleOp( ModuleOp moduleOp, const OneShotBufferizationOptions &options, - BufferizationState &state, BufferizationStatistics *statistics) { + BufferizationStatistics *statistics) { assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); IRRewriter rewriter(moduleOp.getContext()); @@ -548,10 +548,10 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( // Buffer copies must be inserted before every write. OneShotBufferizationOptions updatedOptions = options; updatedOptions.copyBeforeWrite = true; - if (failed(bufferizeOp(funcOp, updatedOptions, state, statistics))) + if (failed(bufferizeOp(funcOp, updatedOptions, statistics))) return failure(); } else { - if (failed(bufferizeOp(funcOp, options, state, statistics))) + if (failed(bufferizeOp(funcOp, options, statistics))) return failure(); } @@ -565,7 +565,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( // Functions were already bufferized. if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>()) continue; - if (failed(bufferizeOp(&op, options, state, statistics))) + if (failed(bufferizeOp(&op, options, statistics))) return failure(); } @@ -577,7 +577,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( LogicalResult mlir::bufferization::runOneShotModuleBufferize( ModuleOp moduleOp, const OneShotBufferizationOptions &options, - BufferizationState &state, BufferizationStatistics *statistics) { + BufferizationStatistics *statistics) { assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); assert(!(options.copyBeforeWrite && options.testAnalysisOnly) && @@ -606,7 +606,7 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize( } if (options.testAnalysisOnly) return success(); - if (failed(bufferizeModuleOp(moduleOp, options, state, statistics))) + if (failed(bufferizeModuleOp(moduleOp, options, statistics))) return failure(); return success(); } diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp index 6a1546fb48683..72f4a1a4f4c66 100644 --- a/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp @@ -43,8 +43,7 @@ struct BranchLikeOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { // The operands of this op are bufferized together with the block signature. return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp index b6a498a57c036..be158af09d398 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -148,8 +148,7 @@ struct LinalgOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { return bufferizeDestinationStyleOpInterface( rewriter, cast<DestinationStyleOpInterface>(op), options); } @@ -175,8 +174,7 @@ struct SoftmaxOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto softmaxOp = cast<linalg::SoftmaxOp>(op); FailureOr<Value> inputBuffer = getBuffer(rewriter, softmaxOp.getInput(), options); @@ -204,7 +202,6 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels( LinalgOpInterfaceHelper< #define GET_OP_LIST #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" - >::registerOpInterface(ctx); SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp index 94a4b9011c16b..a62510deefc4a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -263,11 +263,7 @@ Value linalg::bufferizeToAllocation( assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 && "expected single masked op"); OpBuilder::InsertionGuard g(rewriter); - - // Should the bufferization options and state be function arguments? bufferization::BufferizationOptions bufferizationOptions; - bufferization::BufferizationState bufferizationState; - Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator(); assert(isa<vector::YieldOp>(yieldOp) && "expected yield op terminator"); @@ -283,7 +279,7 @@ Value linalg::bufferizeToAllocation( // Bufferize terminator. rewriter.setInsertionPoint(yieldOp); if (failed(cast<bufferization::BufferizableOpInterface>(yieldOp).bufferize( - rewriter, bufferizationOptions, bufferizationState))) + rewriter, bufferizationOptions))) return nullptr; // Erase dead to_tensor ops inside of the mask op. This is necessary because @@ -304,9 +300,8 @@ Value linalg::bufferizeToAllocation( for (OpOperand &use : result.getUses()) resultUses.push_back(&use); rewriter.setInsertionPoint(maskOp); - if (failed( - cast<bufferization::BufferizableOpInterface>(maskOp.getOperation()) - .bufferize(rewriter, bufferizationOptions, bufferizationState))) + if (failed(cast<bufferization::BufferizableOpInterface>(maskOp.getOperation()) + .bufferize(rewriter, bufferizationOptions))) return nullptr; // Set "restrict" attribute, indicating that no other tensor aliases with @@ -489,11 +484,8 @@ Value linalg::bufferizeToAllocation( auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op); if (!bufferizableOp) return nullptr; - - // Should the bufferization options and states be function arguments? BufferizationOptions bufferizationOptions; - AnalysisState analysisState(bufferizationOptions); - BufferizationState bufferizationState; + AnalysisState state(bufferizationOptions); #ifndef NDEBUG if (!options.bufferizeDestinationOnly) { @@ -535,7 +527,7 @@ Value linalg::bufferizeToAllocation( }; for (OpResult result : tensorResults) { AliasingOpOperandList aliasingOperands = - analysisState.getAliasingOpOperands(result); + state.getAliasingOpOperands(result); for (const AliasingOpOperand &operand : aliasingOperands) { addOutOfPlaceOperand(operand.opOperand); for (OpOperand &resultUse : result.getUses()) @@ -543,7 +535,7 @@ Value linalg::bufferizeToAllocation( } } for (OpOperand &operand : op->getOpOperands()) { - if (!analysisState.bufferizesToMemoryWrite(operand)) + if (!state.bufferizesToMemoryWrite(operand)) continue; if (!isa<RankedTensorType>(operand.get().getType())) continue; @@ -561,7 +553,7 @@ Value linalg::bufferizeToAllocation( Value alloc = createAllocationForTensor( rewriter, op->getLoc(), operand->get(), options, memorySpace); allocs.push_back(alloc); - if (!analysisState.findDefinitions(operand).empty()) { + if (!state.findDefinitions(operand).empty()) { // Initialize buffer with a copy of the operand data. Not needed if the // tensor is uninitialized. createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options); @@ -583,8 +575,7 @@ Value linalg::bufferizeToAllocation( // Bufferize the op. rewriter.setInsertionPoint(op); - if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions, - bufferizationState))) + if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions))) return nullptr; // Set "restrict" attribute, indicating that no other tensor aliases with diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp index a69bc9e5088ae..926d580ac7852 100644 --- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp @@ -9,7 +9,6 @@ #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -53,18 +52,15 @@ struct GlobalOpInterface bool hasTensorSemantics(Operation *) const { return true; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &, - BufferizationState &state) const { + const BufferizationOptions &) const { auto globalOp = cast<GlobalOp>(op); if (!globalOp.getValue().has_value()) return globalOp.emitError("global op must have a value"); - bufferization::removeSymbol(globalOp, state); - auto tensorType = cast<TensorType>(globalOp.getType()); auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); - auto replacement = replaceOpWithNewBufferizedOp<memref::GlobalOp>( + replaceOpWithNewBufferizedOp<memref::GlobalOp>( rewriter, globalOp, globalOp.getSymName(), /*sym_visibility=*/globalOp.getSymVisibilityAttr(), /*type=*/cast<MemRefType>(memrefType), @@ -72,7 +68,6 @@ struct GlobalOpInterface /*constant=*/!globalOp.getIsMutable(), /*alignment=*/nullptr); - bufferization::insertSymbol(replacement, state); return success(); } }; @@ -96,8 +91,7 @@ struct GlobalLoadOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &, - BufferizationState &state) const { + const BufferizationOptions &) const { auto globalLoadOp = cast<GlobalLoadOp>(op); auto tensorType = cast<TensorType>(globalLoadOp.getType()); @@ -127,8 +121,7 @@ struct GlobalStoreOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto globalStoreOp = cast<GlobalStoreOp>(op); auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType()); diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index 3ff1f5c49aece..d6a9d8f6401f1 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -95,8 +95,7 @@ struct ConditionOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto conditionOp = cast<scf::ConditionOp>(op); auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp()); @@ -182,8 +181,7 @@ struct ExecuteRegionOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto executeRegionOp = cast<scf::ExecuteRegionOp>(op); auto yieldOp = getUniqueYieldOp(executeRegionOp); TypeRange newResultTypes(yieldOp.getResults()); @@ -239,8 +237,7 @@ struct IfOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { OpBuilder::InsertionGuard g(rewriter); auto ifOp = cast<scf::IfOp>(op); @@ -350,8 +347,7 @@ struct IndexSwitchOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { OpBuilder::InsertionGuard g(rewriter); auto switchOp = cast<scf::IndexSwitchOp>(op); @@ -726,8 +722,7 @@ struct ForOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto forOp = cast<scf::ForOp>(op); Block *oldLoopBody = forOp.getBody(); @@ -944,8 +939,7 @@ struct WhileOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto whileOp = cast<scf::WhileOp>(op); // Indices of all bbArgs that have tensor type. These are the ones that @@ -1150,8 +1144,7 @@ struct YieldOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto yieldOp = cast<scf::YieldOp>(op); if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp, scf::WhileOp>(yieldOp->getParentOp())) @@ -1227,8 +1220,7 @@ struct ForallOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { OpBuilder::InsertionGuard guard(rewriter); auto forallOp = cast<ForallOp>(op); int64_t rank = forallOp.getRank(); @@ -1335,8 +1327,7 @@ struct InParallelOpInterface : public BufferizableOpInterface::ExternalModel<InParallelOpInterface, InParallelOp> { LogicalResult bufferize(Operation *op, RewriterBase &b, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { llvm_unreachable("op does not have any tensor OpOperands / OpResults"); return failure(); } diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp index e8cab76d3c753..6c3b23937f98f 100644 --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -47,8 +47,7 @@ struct AssumingOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto assumingOp = cast<shape::AssumingOp>(op); assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) && "only 1 block supported"); @@ -113,8 +112,7 @@ struct AssumingYieldOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto yieldOp = cast<shape::AssumingYieldOp>(op); SmallVector<Value> newResults; for (Value value : yieldOp.getOperands()) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp index f952b68ba7e67..7734d1d258453 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -30,8 +30,7 @@ template <typename ConcreteModel, typename ConcreteOp> struct SparseBufferizableOpInterfaceExternalModel : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { return op->emitError( "sparse_tensor ops must be bufferized with the sparsifier"); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp index 7c7c64f2aef01..6e882a8d0ff30 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -114,11 +114,8 @@ class SparsificationAndBufferizationPass return false; }); - bufferization::BufferizationState bufferizationState; - if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()), - updatedOptions, - bufferizationState))) + updatedOptions))) return failure(); bufferization::removeBufferizationAttributesInModule(getOperation()); diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 64443d210e163..b6843e560a899 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -83,8 +83,7 @@ struct CastOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto castOp = cast<tensor::CastOp>(op); // The result buffer still has the old (pre-cast) type. @@ -163,8 +162,7 @@ struct CollapseShapeOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op); RankedTensorType tensorResultType = collapseShapeOp.getResultType(); FailureOr<Value> maybeBuffer = @@ -249,8 +247,7 @@ struct DimOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto dimOp = cast<tensor::DimOp>(op); FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options); if (failed(v)) @@ -274,8 +271,7 @@ struct EmptyOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto emptyOp = cast<tensor::EmptyOp>(op); // Optimization: Fold away the op if it has no uses. @@ -333,8 +329,7 @@ struct ExpandShapeOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto expandShapeOp = cast<tensor::ExpandShapeOp>(op); auto tensorResultType = expandShapeOp.getResultType(); FailureOr<Value> buffer = @@ -372,8 +367,7 @@ struct ExtractSliceOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto extractSliceOp = cast<tensor::ExtractSliceOp>(op); SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets(); SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes(); @@ -438,8 +432,7 @@ struct ExtractOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto extractOp = cast<tensor::ExtractOp>(op); FailureOr<Value> srcMemref = getBuffer(rewriter, extractOp.getTensor(), options); @@ -481,8 +474,7 @@ struct FromElementsOpInterface bool bufferizesToAllocation(Operation *op, Value value) const { return true; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto fromElementsOp = cast<tensor::FromElementsOp>(op); auto tensorType = cast<RankedTensorType>(fromElementsOp.getType()); @@ -594,8 +586,7 @@ struct GenerateOpInterface bool bufferizesToAllocation(Operation *op, Value value) const { return true; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto generateOp = cast<tensor::GenerateOp>(op); auto type = generateOp.getResult().getType(); @@ -629,8 +620,7 @@ struct InsertOpInterface : public DstBufferizableOpInterfaceExternalModel<InsertOpInterface, tensor::InsertOp> { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto insertOp = cast<tensor::InsertOp>(op); FailureOr<Value> destMemref = getBuffer(rewriter, insertOp.getDest(), options); @@ -680,8 +670,7 @@ struct InsertSliceOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { // insert_slice ops arise from tiling and bufferizing them out-of-place is // generally a deal breaker. When used with loops, this ends up cloning the // whole tensor on every single iteration and is a symptom of a @@ -763,8 +752,7 @@ struct PadOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto padOp = cast<tensor::PadOp>(op); Location loc = padOp.getLoc(); RankedTensorType resultType = padOp.getResultType(); @@ -843,8 +831,7 @@ struct RankOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto rankOp = cast<tensor::RankOp>(op); FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options); if (failed(v)) @@ -881,8 +868,7 @@ struct ReshapeOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto reshapeOp = cast<tensor::ReshapeOp>(op); FailureOr<Value> srcBuffer = getBuffer(rewriter, reshapeOp.getSource(), options); @@ -954,8 +940,7 @@ struct ParallelInsertSliceOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { OpBuilder::InsertionGuard g(rewriter); auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op); ParallelCombiningOpInterface parallelCombiningParent = @@ -1030,8 +1015,7 @@ struct SplatOpInterface bool bufferizesToAllocation(Operation *op, Value value) const { return true; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { OpBuilder::InsertionGuard g(rewriter); auto splatOp = cast<tensor::SplatOp>(op); diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp index 45b6e7c512947..b2272c5fda876 100644 --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -48,8 +48,7 @@ struct TransferReadOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto readOp = cast<vector::TransferReadOp>(op); assert(isa<TensorType>(readOp.getShapedType()) && "only tensor types expected"); @@ -104,8 +103,7 @@ struct TransferWriteOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto writeOp = cast<vector::TransferWriteOp>(op); assert(isa<TensorType>(writeOp.getShapedType()) && "only tensor types expected"); @@ -150,8 +148,7 @@ struct GatherOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto gatherOp = cast<vector::GatherOp>(op); assert(isa<TensorType>(gatherOp.getBaseType()) && "only tensor types expected"); @@ -205,8 +202,7 @@ struct MaskOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto maskOp = cast<vector::MaskOp>(op); // Do not bufferize if the masked op is not bufferizable. @@ -283,8 +279,7 @@ struct YieldOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto yieldOp = cast<vector::YieldOp>(op); // Only supported as a vector.mask terminator. _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits