https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/115816
This commit adds a new function `ConversionPatternRewriter::replaceOpWithMultiple`. This function is similar to `replaceOp`, but it accepts multiple `ValueRange` replacements, one per op result. Note: This function is not an overload of `replaceOp` because of ambiguous overload resolution that would make the API difficult to use. This commit aligns "block signature conversions" with "op replacements": both support 1:N replacements now. Due to incomplete 1:N support in the dialect conversion driver, an argument materialization is inserted when an SSA value is replaced with multiple values; block signature conversions already work around the problem in the same way. These argument materializations are going to be removed in a subsequent commit that adds full 1:N support. The purpose of this PR is to add missing features gradually in small increments. This commit also updates two MLIR transformations that have their own custom workarounds for missing 1:N support. These can already start using `replaceOpWithMultiple`. Depends on #114940. >From b425caab826e5d9ad2f078d6f548f3215005bf7f Mon Sep 17 00:00:00 2001 From: Matthias Springer <msprin...@nvidia.com> Date: Tue, 12 Nov 2024 05:14:43 +0100 Subject: [PATCH] replace with multiple --- mlir/include/mlir/IR/Builders.h | 2 +- .../mlir/Transforms/DialectConversion.h | 24 ++- .../Transforms/DecomposeCallGraphTypes.cpp | 40 ++--- .../Transforms/SparseTensorCodegen.cpp | 48 +++--- .../Utils/SparseTensorDescriptor.cpp | 21 ++- mlir/lib/IR/Builders.cpp | 16 +- .../Transforms/Utils/DialectConversion.cpp | 153 ++++++++++++------ 7 files changed, 186 insertions(+), 118 deletions(-) diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 7ef03b87179523..78729376507208 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -353,7 +353,7 @@ class OpBuilder : public Builder { /// selected insertion point. 
(E.g., because they are defined in a nested /// region or because they are not visible in an IsolatedFromAbove region.) static InsertPoint after(ArrayRef<Value> values, - const PostDominanceInfo &domInfo); + const PostDominanceInfo *domInfo = nullptr); /// Returns true if this insert point is set. bool isSet() const { return (block != nullptr); } diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 5e5957170e646c..e461b7d11602a0 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -795,12 +795,32 @@ class ConversionPatternRewriter final : public PatternRewriter { /// patterns even if a failure is encountered during the rewrite step. bool canRecoverFromRewriteFailure() const override { return true; } - /// PatternRewriter hook for replacing an operation. + /// Replace the given operation with the new values. The number of op results + /// and replacement values must match. The types may differ: the dialect + /// conversion driver will reconcile any surviving type mismatches at the end + /// of the conversion process with source materializations. The given + /// operation is erased. void replaceOp(Operation *op, ValueRange newValues) override; - /// PatternRewriter hook for replacing an operation. + /// Replace the given operation with the results of the new op. The number of + /// op results must match. The types may differ: the dialect conversion + /// driver will reconcile any surviving type mismatches at the end of the + /// conversion process with source materializations. The original operation + /// is erased. void replaceOp(Operation *op, Operation *newOp) override; + /// Replace the given operation with the new value range. The number of op + /// results and value ranges must match. 
If an original SSA value is replaced + /// by multiple SSA values (i.e., value range has more than 1 element), the + /// conversion driver will insert an argument materialization to convert the + /// N SSA values back into 1 SSA value of the original type. The given + /// operation is erased. + /// + /// Note: The argument materialization is a workaround until we have full 1:N + /// support in the dialect conversion. (It is going to disappear from both + /// `replaceOpWithMultiple` and `applySignatureConversion`.) + void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues); + /// PatternRewriter hook for erasing a dead operation. The uses of this /// operation *must* be made dead by the end of the conversion process, /// otherwise an assert will be issued. diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp index de4aba2ed327db..a08764326a80b6 100644 --- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp +++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp @@ -141,47 +141,31 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> { getTypeConverter())); } - // Create the new result types for the new `CallOp` and track the indices in - // the new call op's results that correspond to the old call op's results. - // - // expandedResultIndices[i] = "list of new result indices that old result i - // expanded to". + // Create the new result types for the new `CallOp` and track the number of + // replacement types for each original op result. 
SmallVector<Type, 2> newResultTypes; - SmallVector<SmallVector<unsigned, 2>, 4> expandedResultIndices; + SmallVector<unsigned> expandedResultSizes; for (Type resultType : op.getResultTypes()) { unsigned oldSize = newResultTypes.size(); if (failed(typeConverter->convertType(resultType, newResultTypes))) return failure(); - auto &resultMapping = expandedResultIndices.emplace_back(); - for (unsigned i = oldSize, e = newResultTypes.size(); i < e; i++) - resultMapping.push_back(i); + expandedResultSizes.push_back(newResultTypes.size() - oldSize); } CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(), newResultTypes, newOperands); - // Build a replacement value for each result to replace its uses. If a - // result has multiple mapping values, it needs to be materialized as a - // single value. - SmallVector<Value, 2> replacedValues; + // Build a replacement value for each result to replace its uses. + SmallVector<ValueRange> replacedValues; replacedValues.reserve(op.getNumResults()); + unsigned startIdx = 0; for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) { - auto decomposedValues = llvm::to_vector<6>( - llvm::map_range(expandedResultIndices[i], - [&](unsigned i) { return newCallOp.getResult(i); })); - if (decomposedValues.empty()) { - // No replacement is required. - replacedValues.push_back(nullptr); - } else if (decomposedValues.size() == 1) { - replacedValues.push_back(decomposedValues.front()); - } else { - // Materialize a single Value to replace the original Value. 
- Value materialized = getTypeConverter()->materializeArgumentConversion( - rewriter, op.getLoc(), op.getType(i), decomposedValues); - replacedValues.push_back(materialized); - } + ValueRange repl = + newCallOp.getResults().slice(startIdx, expandedResultSizes[i]); + replacedValues.push_back(repl); + startIdx += expandedResultSizes[i]; } - rewriter.replaceOp(op, replacedValues); + rewriter.replaceOpWithMultiple(op, replacedValues); return success(); } }; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 062a0ea6cc47cb..09509278d7749a 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -600,8 +600,8 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> { flattenOperands(adaptor.getOperands(), flattened); auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(), finalRetTy, flattened); - // (2) Create cast operation for sparse tensor returns. - SmallVector<Value> castedRet; + // (2) Gather sparse tensor returns. 
+ SmallVector<SmallVector<Value>> packedResultVals; // Tracks the offset of current return value (of the original call) // relative to the new call (after sparse tensor flattening); unsigned retOffset = 0; @@ -618,21 +618,27 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> { assert(!sparseFlat.empty()); if (sparseFlat.size() > 1) { auto flatSize = sparseFlat.size(); - ValueRange fields(iterator_range<ResultRange::iterator>( - newCall.result_begin() + retOffset, - newCall.result_begin() + retOffset + flatSize)); - castedRet.push_back(genTuple(rewriter, loc, retType, fields)); + packedResultVals.push_back(SmallVector<Value>()); + llvm::append_range(packedResultVals.back(), + iterator_range<ResultRange::iterator>( + newCall.result_begin() + retOffset, + newCall.result_begin() + retOffset + flatSize)); retOffset += flatSize; } else { // If this is an 1:1 conversion, no need for casting. - castedRet.push_back(newCall.getResult(retOffset)); + packedResultVals.push_back(SmallVector<Value>()); + packedResultVals.back().push_back(newCall.getResult(retOffset)); retOffset++; } sparseFlat.clear(); } - assert(castedRet.size() == op.getNumResults()); - rewriter.replaceOp(op, castedRet); + assert(packedResultVals.size() == op.getNumResults()); + SmallVector<ValueRange> ranges; + ranges.reserve(packedResultVals.size()); + for (const SmallVector<Value> &vec : packedResultVals) + ranges.push_back(ValueRange(vec)); + rewriter.replaceOpWithMultiple(op, ranges); return success(); } }; @@ -776,7 +782,7 @@ class SparseTensorAllocConverter // Reuses specifier. fields.push_back(desc.getSpecifier()); assert(fields.size() == desc.getNumFields()); - rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields)); + rewriter.replaceOpWithMultiple(op, {fields}); return success(); } @@ -796,7 +802,7 @@ class SparseTensorAllocConverter sizeHint, lvlSizesValues, fields); // Replace operation with resulting memrefs. 
- rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields)); + rewriter.replaceOpWithMultiple(op, {fields}); return success(); } @@ -837,7 +843,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> { sizeHint, lvlSizesValues, fields); // Replace operation with resulting memrefs. - rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields)); + rewriter.replaceOpWithMultiple(op, {fields}); return success(); } @@ -893,7 +899,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> { if (op.getHasInserts()) genEndInsert(rewriter, op.getLoc(), desc); // Replace operation with resulting memrefs. - rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc)); + rewriter.replaceOpWithMultiple(op, {desc.getFields()}); return success(); } }; @@ -1006,7 +1012,6 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> { rewriter.create<scf::YieldOp>(loc, insertRet); rewriter.setInsertionPointAfter(loop); - Value result = genTuple(rewriter, loc, dstType, loop->getResults()); // Deallocate the buffers on exit of the full loop nest. Operation *parent = getTop(op); rewriter.setInsertionPointAfter(parent); @@ -1014,7 +1019,7 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> { rewriter.create<memref::DeallocOp>(loc, filled); rewriter.create<memref::DeallocOp>(loc, added); // Replace operation with resulting memrefs. - rewriter.replaceOp(op, result); + rewriter.replaceOpWithMultiple(op, {loop->getResults()}); return success(); } }; @@ -1041,8 +1046,7 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> { params, /*genCall=*/true); SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc); // Replace operation with resulting memrefs. 
- rewriter.replaceOp(op, - genTuple(rewriter, loc, op.getDest().getType(), ret)); + rewriter.replaceOpWithMultiple(op, {ret}); return success(); } }; @@ -1215,8 +1219,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> { return true; }); - rewriter.replaceOp( - op, genTuple(rewriter, loc, op.getResult().getType(), fields)); + rewriter.replaceOpWithMultiple(op, {fields}); return success(); } }; @@ -1271,8 +1274,7 @@ class SparseExtractSliceConverter // NOTE: we can not generate tuples directly from descriptor here, as the // descriptor is holding the original type, yet we want the slice type // here (they shared every memref but with an updated specifier). - rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(), - desc.getFields())); + rewriter.replaceOpWithMultiple(op, {desc.getFields()}); return success(); } }; @@ -1403,7 +1405,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> { } desc.setValMemSize(rewriter, loc, memSize); - rewriter.replaceOp(op, genTuple(rewriter, loc, desc)); + rewriter.replaceOpWithMultiple(op, {desc.getFields()}); return success(); } }; @@ -1577,7 +1579,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> { EmitCInterface::Off); // Replace operation with resulting memrefs. - rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields)); + rewriter.replaceOpWithMultiple(op, {fields}); return success(); } }; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp index a3db50573c2720..834e3634cc130d 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp @@ -54,19 +54,24 @@ convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) { // The sparse tensor type converter (defined in Passes.h). 
//===----------------------------------------------------------------------===// +static Value materializeTuple(OpBuilder &builder, RankedTensorType tp, + ValueRange inputs, Location loc) { + if (!getSparseTensorEncoding(tp)) + // Not a sparse tensor. + return Value(); + // Sparsifier knows how to cancel out these casts. + return genTuple(builder, loc, tp, inputs); +} + SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { addConversion([](Type type) { return type; }); addConversion(convertSparseTensorType); // Required by scf.for 1:N type conversion. - addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp, - ValueRange inputs, Location loc) -> Value { - if (!getSparseTensorEncoding(tp)) - // Not a sparse tensor. - return Value(); - // Sparsifier knows how to cancel out these casts. - return genTuple(builder, loc, tp, inputs); - }); + addSourceMaterialization(materializeTuple); + + // Required as a workaround until we have full 1:N support. + addArgumentMaterialization(materializeTuple); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 4714c3cace6c78..e85a86e94282ec 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -645,7 +645,7 @@ void OpBuilder::cloneRegionBefore(Region ®ion, Block *before) { OpBuilder::InsertPoint OpBuilder::InsertPoint::after(ArrayRef<Value> values, - const PostDominanceInfo &domInfo) { + const PostDominanceInfo *domInfo) { // Helper function that computes the point after v's definition. auto computeAfterIp = [](Value v) -> std::pair<Block *, Block::iterator> { if (auto blockArg = dyn_cast<BlockArgument>(v)) @@ -658,12 +658,18 @@ OpBuilder::InsertPoint::after(ArrayRef<Value> values, assert(!values.empty() && "expected at least one Value"); auto [block, blockIt] = computeAfterIp(values.front()); + if (values.size() == 1) { + // Fast path: There is only one value. 
+ return InsertPoint(block, blockIt); + } + // Check the other values one-by-one and update the insertion point if // needed. + assert(domInfo && "domInfo expected if >1 values"); for (Value v : values.drop_front()) { auto [candidateBlock, candidateBlockIt] = computeAfterIp(v); - if (domInfo.postDominantes(candidateBlock, candidateBlockIt, block, - blockIt)) { + if (domInfo->postDominantes(candidateBlock, candidateBlockIt, block, + blockIt)) { // The point after v's definition post-dominates the current (and all // previous) insertion points. Note: Post-dominance is transitive. block = candidateBlock; @@ -671,8 +677,8 @@ OpBuilder::InsertPoint::after(ArrayRef<Value> values, continue; } - if (!domInfo.postDominantes(block, blockIt, candidateBlock, - candidateBlockIt)) { + if (!domInfo->postDominantes(block, blockIt, candidateBlock, + candidateBlockIt)) { // The point after v's definition and the current insertion point do not // post-dominate each other. Therefore, there is no insertion point that // post-dominates all values. diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 0a62628b9ad240..2f6c0a1ab0bd3b 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -11,6 +11,7 @@ #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Iterators.h" #include "mlir/Interfaces/FunctionInterfaces.h" @@ -53,20 +54,14 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { }); } -/// Helper function that computes an insertion point where the given value is -/// defined and can be used without a dominance violation. 
-static OpBuilder::InsertPoint computeInsertPoint(Value value) { - Block *insertBlock = value.getParentBlock(); - Block::iterator insertPt = insertBlock->begin(); - if (OpResult inputRes = dyn_cast<OpResult>(value)) - insertPt = ++inputRes.getOwner()->getIterator(); - return OpBuilder::InsertPoint(insertBlock, insertPt); -} - //===----------------------------------------------------------------------===// // ConversionValueMapping //===----------------------------------------------------------------------===// +/// A list of replacement SSA values. Optimized for the common case of a single +/// SSA value. +using ReplacementValues = SmallVector<Value, 1>; + namespace { /// This class wraps a IRMapping to provide recursive lookup /// functionality, i.e. we will traverse if the mapped value also has a mapping. @@ -818,6 +813,22 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { Type originalType, const TypeConverter *converter); + /// Build an N:1 materialization for the given original value that was + /// replaced with the given replacement values. + /// + /// This is a workaround around incomplete 1:N support in the dialect + /// conversion driver. The conversion mapping can store only 1:1 replacements + /// and the conversion patterns only support single Value replacements in the + /// adaptor, so N values must be converted back to a single value. This + /// function will be deleted when full 1:N support has been added. + /// + /// This function inserts an argument materialization back to the original + /// type, followed by a target materialization to the legalized type (if + /// applicable). 
+ void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc, + ValueRange replacements, Value originalValue, + const TypeConverter *converter); + //===--------------------------------------------------------------------===// // Rewriter Notification Hooks //===--------------------------------------------------------------------===// @@ -827,7 +838,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { OpBuilder::InsertPoint previous) override; /// Notifies that an op is about to be replaced with the given values. - void notifyOpReplaced(Operation *op, ValueRange newValues); + void notifyOpReplaced(Operation *op, ArrayRef<ReplacementValues> newValues); /// Notifies that a block is about to be erased. void notifyBlockIsBeingErased(Block *block); @@ -1147,8 +1158,9 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( // that the value was replaced with a value of different type and no // source materialization was created yet. Value castValue = buildUnresolvedMaterialization( - MaterializationKind::Target, computeInsertPoint(newOperand), - operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType, + MaterializationKind::Target, + OpBuilder::InsertPoint::after(newOperand), operandLoc, + /*inputs=*/newOperand, /*outputType=*/desiredType, /*originalType=*/origType, currentTypeConverter); mapping.map(newOperand, castValue); newOperand = castValue; @@ -1287,33 +1299,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( // used as a replacement. 
auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); - Value argMat = buildUnresolvedMaterialization( - MaterializationKind::Argument, + insertNTo1Materialization( OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), - /*inputs=*/replArgs, /*outputType=*/origArgType, - /*originalType=*/Type(), converter); - mapping.map(origArg, argMat); - - Type legalOutputType; - if (converter) { - legalOutputType = converter->convertType(origArgType); - } else if (replArgs.size() == 1) { - // When there is no type converter, assume that the new block argument - // types are legal. This is reasonable to assume because they were - // specified by the user. - // FIXME: This won't work for 1->N conversions because multiple output - // types are not supported in parts of the dialect conversion. In such a - // case, we currently use the original block argument type (produced by - // the argument materialization). - legalOutputType = replArgs[0].getType(); - } - if (legalOutputType && legalOutputType != origArgType) { - Value targetMat = buildUnresolvedMaterialization( - MaterializationKind::Target, computeInsertPoint(argMat), - origArg.getLoc(), /*inputs=*/argMat, /*outputType=*/legalOutputType, - /*originalType=*/origArgType, converter); - mapping.map(argMat, targetMat); - } + /*replacements=*/replArgs, /*outputValue=*/origArg, converter); appendRewrite<ReplaceBlockArgRewrite>(block, origArg); } @@ -1354,6 +1342,39 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( return convertOp.getResult(0); } +void ConversionPatternRewriterImpl::insertNTo1Materialization( + OpBuilder::InsertPoint ip, Location loc, ValueRange replacements, + Value originalValue, const TypeConverter *converter) { + // Insert argument materialization back to the original type. 
+ Type originalType = originalValue.getType(); + Value argMat = + buildUnresolvedMaterialization(MaterializationKind::Argument, ip, loc, + /*inputs=*/replacements, originalType, + /*originalType=*/Type(), converter); + mapping.map(originalValue, argMat); + + // Insert target materialization to the legalized type. + Type legalOutputType; + if (converter) { + legalOutputType = converter->convertType(originalType); + } else if (replacements.size() == 1) { + // When there is no type converter, assume that the replacement value + // types are legal. This is reasonable to assume because they were + // specified by the user. + // FIXME: This won't work for 1->N conversions because multiple output + // types are not supported in parts of the dialect conversion. In such a + // case, we currently use the original value type. + legalOutputType = replacements[0].getType(); + } + if (legalOutputType && legalOutputType != originalType) { + Value targetMat = buildUnresolvedMaterialization( + MaterializationKind::Target, OpBuilder::InsertPoint::after(argMat), loc, + /*inputs=*/argMat, /*outputType=*/legalOutputType, + /*originalType=*/originalType, converter); + mapping.map(argMat, targetMat); + } +} + //===----------------------------------------------------------------------===// // Rewriter Notification Hooks @@ -1377,10 +1398,11 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp); } -void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, - ValueRange newValues) { +void ConversionPatternRewriterImpl::notifyOpReplaced( + Operation *op, ArrayRef<ReplacementValues> newValues) { assert(newValues.size() == op->getNumResults()); assert(!ignoredOps.contains(op) && "operation was already replaced"); + PostDominanceInfo domInfo; // Check if replaced op is an unresolved materialization, i.e., an // unrealized_conversion_cast op that was created by the conversion driver. 
@@ -1390,8 +1412,9 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, isUnresolvedMaterialization = true; // Create mappings for each of the new result values. - for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) { - if (!newValue) { + for (auto [n, result] : llvm::zip(newValues, op->getResults())) { + ReplacementValues repl = n; + if (repl.empty()) { // This result was dropped and no replacement value was provided. if (isUnresolvedMaterialization) { // Do not create another materializations if we are erasing a @@ -1400,11 +1423,12 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, } // Materialize a replacement value "out of thin air". - newValue = buildUnresolvedMaterialization( - MaterializationKind::Source, computeInsertPoint(result), + Value sourceMat = buildUnresolvedMaterialization( + MaterializationKind::Source, OpBuilder::InsertPoint::after(result), result.getLoc(), /*inputs=*/ValueRange(), /*outputType=*/result.getType(), /*originalType=*/Type(), currentTypeConverter); + repl.push_back(sourceMat); } else { // Make sure that the user does not mess with unresolved materializations // that were inserted by the conversion driver. We keep track of these @@ -1417,12 +1441,21 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, } // Remap result to replacement value. - if (newValue) - mapping.map(result, newValue); + if (!repl.empty()) { + if (repl.size() == 1) { + // Single replacement value: replace directly. + mapping.map(result, repl.front()); + } else { + // Multiple replacement values: insert N:1 materialization. + insertNTo1Materialization(OpBuilder::InsertPoint::after(repl, &domInfo), + result.getLoc(), + /*replacements=*/repl, /*outputValue=*/result, + currentTypeConverter); + } + } } appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter); - // Mark this operation and all nested ops as replaced. 
op->walk([&](Operation *op) { replacedOps.insert(op); }); } @@ -1497,7 +1530,25 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); - impl->notifyOpReplaced(op, newValues); + SmallVector<ReplacementValues> newVals(newValues.size(), {}); + for (auto it : llvm::enumerate(newValues)) + if (Value val = it.value()) + newVals[it.index()].push_back(val); + impl->notifyOpReplaced(op, newVals); +} + +void ConversionPatternRewriter::replaceOpWithMultiple( + Operation *op, ArrayRef<ValueRange> newValues) { + assert(op->getNumResults() == newValues.size() && + "incorrect # of replacement values"); + LLVM_DEBUG({ + impl->logger.startLine() + << "** Replace : '" << op->getName() << "'(" << op << ")\n"; + }); + SmallVector<ReplacementValues> newVals(newValues.size(), {}); + for (auto it : llvm::enumerate(newValues)) + llvm::append_range(newVals[it.index()], it.value()); + impl->notifyOpReplaced(op, newVals); } void ConversionPatternRewriter::eraseOp(Operation *op) { @@ -1505,7 +1556,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { impl->logger.startLine() << "** Erase : '" << op->getName() << "'(" << op << ")\n"; }); - SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr); + SmallVector<ReplacementValues> nullRepls(op->getNumResults(), {}); impl->notifyOpReplaced(op, nullRepls); } @@ -2596,7 +2647,7 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) { Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue); assert(newValue && "replacement value not found"); Value castValue = rewriterImpl.buildUnresolvedMaterialization( - MaterializationKind::Source, computeInsertPoint(newValue), + MaterializationKind::Source, OpBuilder::InsertPoint::after(newValue), originalValue.getLoc(), /*inputs=*/newValue, /*outputType=*/originalValue.getType(), /*originalType=*/Type(), converter); 
_______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits