https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116524
>From e3946a5496cdf64ff6a8a5c7e1b117f4904ac9e5 Mon Sep 17 00:00:00 2001 From: Matthias Springer <msprin...@nvidia.com> Date: Sun, 17 Nov 2024 04:38:09 +0100 Subject: [PATCH] [mlir][Transforms] Support 1:N mappings in `ConversionValueMapping` --- .../Conversion/LLVMCommon/TypeConverter.cpp | 68 ++- .../Bufferization/Transforms/Bufferize.cpp | 1 - .../EmitC/Transforms/TypeConversions.cpp | 1 - .../Dialect/Linalg/Transforms/Detensorize.cpp | 1 - .../Quant/Transforms/StripFuncQuantTypes.cpp | 1 - .../Utils/SparseTensorDescriptor.cpp | 3 - .../Vector/Transforms/VectorLinearize.cpp | 1 - .../Transforms/Utils/DialectConversion.cpp | 527 ++++++++++-------- mlir/test/Transforms/test-legalizer.mlir | 3 - .../Func/TestDecomposeCallGraphTypes.cpp | 2 +- mlir/test/lib/Dialect/Test/TestPatterns.cpp | 11 +- .../lib/Transforms/TestDialectConversion.cpp | 1 - 12 files changed, 335 insertions(+), 285 deletions(-) diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index 59b0f5c9b09bcd..fbf1c20d0baa32 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -153,20 +153,31 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, type.isVarArg()); }); + // Add generic source and target materializations to handle cases where + // non-LLVM types persist after an LLVM conversion. + addSourceMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) { + return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) + .getResult(0); + }); + addTargetMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) { + return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) + .getResult(0); + }); + // Helper function that checks if the given value range is a bare pointer. 
auto isBarePointer = [](ValueRange values) { return values.size() == 1 && isa<LLVM::LLVMPointerType>(values.front().getType()); }; - // Argument materializations convert from the new block argument types - // (multiple SSA values that make up a memref descriptor) back to the - // original block argument type. The dialect conversion framework will then - // insert a target materialization from the original block argument type to - // a legal type. - addArgumentMaterialization([&](OpBuilder &builder, - UnrankedMemRefType resultType, - ValueRange inputs, Location loc) { + // Source materializations convert the MemrRef descriptor elements + // (multiple SSA values that make up a MemrRef descriptor) back to the + // original MemRef type. + addSourceMaterialization([&](OpBuilder &builder, + UnrankedMemRefType resultType, ValueRange inputs, + Location loc) { // Note: Bare pointers are not supported for unranked memrefs because a // memref descriptor cannot be built just from a bare pointer. if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields()) @@ -179,8 +190,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc) .getResult(0); }); - addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType, - ValueRange inputs, Location loc) { + addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType, + ValueRange inputs, Location loc) { Value desc; if (isBarePointer(inputs)) { desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType, @@ -200,23 +211,30 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc) .getResult(0); }); - // Add generic source and target materializations to handle cases where - // non-LLVM types persist after an LLVM conversion. 
- addSourceMaterialization([&](OpBuilder &builder, Type resultType, - ValueRange inputs, Location loc) { - if (inputs.size() != 1) - return Value(); + addTargetMaterialization([&](OpBuilder &builder, + LLVM::LLVMStructType resultType, + ValueRange inputs, Location loc, + Type originalType) -> Value { + if (auto memrefType = dyn_cast_or_null<MemRefType>(originalType)) { + if (isBarePointer(inputs)) { + return MemRefDescriptor::fromStaticShape(builder, loc, *this, + memrefType, inputs[0]); + } else if (TypeRange(inputs) == + getMemRefDescriptorFields(memrefType, + /*unpackAggregates=*/true)) { + return MemRefDescriptor::pack(builder, loc, *this, memrefType, inputs); + } + } - return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) - .getResult(0); - }); - addTargetMaterialization([&](OpBuilder &builder, Type resultType, - ValueRange inputs, Location loc) { - if (inputs.size() != 1) - return Value(); + if (auto memrefType = dyn_cast_or_null<UnrankedMemRefType>(originalType)) { + // Note: Bare pointers are not supported for unranked memrefs because a + // memref descriptor cannot be built just from a bare pointer. + if (TypeRange(inputs) == getUnrankedMemRefDescriptorFields()) + return UnrankedMemRefDescriptor::pack(builder, loc, *this, memrefType, + inputs); + } - return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) - .getResult(0); + return Value(); }); // Integer memory spaces map to themselves. 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 1d009b03754c52..81a11e27c26178 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -61,7 +61,6 @@ BufferizeTypeConverter::BufferizeTypeConverter() { addConversion([](UnrankedTensorType type) -> Type { return UnrankedMemRefType::get(type.getElementType(), 0); }); - addArgumentMaterialization(materializeToTensor); addSourceMaterialization(materializeToTensor); addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type, ValueRange inputs, Location loc) -> Value { diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp index 0b3a494794f3f5..72c8fd0f324850 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp @@ -33,7 +33,6 @@ void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) { converter.addSourceMaterialization(materializeAsUnrealizedCast); converter.addTargetMaterialization(materializeAsUnrealizedCast); - converter.addArgumentMaterialization(materializeAsUnrealizedCast); } /// Get an unsigned integer or size data type corresponding to \p ty. 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index af38485291182f..61bc5022893741 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -154,7 +154,6 @@ class DetensorizeTypeConverter : public TypeConverter { }); addSourceMaterialization(sourceMaterializationCallback); - addArgumentMaterialization(sourceMaterializationCallback); } }; diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp index 61912722662830..71b88d1be1b05b 100644 --- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp @@ -56,7 +56,6 @@ class QuantizedTypeConverter : public TypeConverter { addConversion(convertQuantizedType); addConversion(convertTensorType); - addArgumentMaterialization(materializeConversion); addSourceMaterialization(materializeConversion); addTargetMaterialization(materializeConversion); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp index 834e3634cc130d..8bbb2cac5efdf3 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp @@ -69,9 +69,6 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { // Required by scf.for 1:N type conversion. addSourceMaterialization(materializeTuple); - - // Required as a workaround until we have full 1:N support. 
- addArgumentMaterialization(materializeTuple); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 757631944f224f..68535ae5a7a5c6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -481,7 +481,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( return builder.create<vector::ShapeCastOp>(loc, type, inputs.front()); }; - typeConverter.addArgumentMaterialization(materializeCast); typeConverter.addSourceMaterialization(materializeCast); typeConverter.addTargetMaterialization(materializeCast); target.markUnknownOpDynamicallyLegal( diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index d4879c1bc333c3..22891046c78260 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -63,94 +63,266 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) { return OpBuilder::InsertPoint(insertBlock, insertPt); } +/// Helper function that computes an insertion point where the given value is +/// defined and can be used without a dominance violation. 
+static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) { + assert(!vals.empty() && "expected at least one value"); + OpBuilder::InsertPoint pt = computeInsertPoint(vals.front()); + for (Value v : vals.drop_front()) { + OpBuilder::InsertPoint pt2 = computeInsertPoint(v); + assert(pt.getBlock() == pt2.getBlock()); + if (pt.getPoint() == pt.getBlock()->begin()) { + pt = pt2; + continue; + } + if (pt2.getPoint() == pt2.getBlock()->begin()) + continue; + // Compare the defining ops: an insertion point may be the block's end(). + if (std::prev(pt.getPoint())->isBeforeInBlock(&*std::prev(pt2.getPoint()))) + pt = pt2; + } + return pt; +} + //===----------------------------------------------------------------------===// // ConversionValueMapping //===----------------------------------------------------------------------===// namespace { +/// DenseMapInfo for value vectors. Keys never contain null values, so vectors +/// of null values are safe sentinels. Empty and tombstone keys must differ: +/// erased buckets hold the tombstone and must not end a lookup's probe chain. +struct SmallVectorMapInfo { + using Key = SmallVector<Value, 1>; + static Key getEmptyKey() { return Key{Value()}; } + static Key getTombstoneKey() { return Key{Value(), Value()}; } + static ::llvm::hash_code getHashValue(const Key &val) { + return ::llvm::hash_combine_range(val.begin(), val.end()); + } + static bool isEqual(const Key &LHS, const Key &RHS) { return LHS == RHS; } +}; + /// This class wraps a IRMapping to provide recursive lookup /// functionality, i.e. we will traverse if the mapped value also has a mapping. struct ConversionValueMapping { - /// Lookup the most recently mapped value with the desired type in the - /// mapping. + /// Find the most recently mapped values for the given value. If the value is + /// not mapped at all, return the given value. 
+ SmallVector<Value, 1> lookupOrDefault(Value from) const; + + /// Like `lookupOrNull(from, desiredTypes)`, but if `from` is unmapped, + /// return `from` itself or a materialization of it with the desired types.
+ SmallVector<Value, 1> + lookupOrDefault(Value from, SmallVector<Type, 1> desiredTypes) const; + + Value lookupDirectSingleReplacement(Value from) const { + auto it = mapping.find(from); + if (it == mapping.end()) + return Value(); + const SmallVector<Value, 1> &repl = it->second; + if (repl.size() != 1) + return Value(); + return repl.front(); + /* + if (!mapping.contains(from)) return Value(); + auto it = llvm::find(mapping, from); + const SmallVector<Value, 1> &repl = it->second; + if (repl.size() != 1) return Value(); + return repl.front(); + */ + } + + /// Find the most recently mapped values for the given value. If the value is + /// not mapped at all, return an empty vector. + SmallVector<Value, 1> lookupOrNull(Value from) const; + + /// Find the most recently mapped values for the given value. If those values + /// have the desired types, return them. Otherwise, try to find a + /// materialization to the desired types. /// - /// Special cases: - /// - If the desired type is "null", simply return the most recently mapped - /// value. - /// - If there is no mapping to the desired type, also return the most - /// recently mapped value. - /// - If there is no mapping for the given value at all, return the given - /// value. - Value lookupOrDefault(Value from, Type desiredType = nullptr) const; - - /// Lookup a mapped value within the map, or return null if a mapping does not - /// exist. If a mapping exists, this follows the same behavior of - /// `lookupOrDefault`. - Value lookupOrNull(Value from, Type desiredType = nullptr) const; - - /// Map a value to the one provided. - void map(Value oldVal, Value newVal) { - LLVM_DEBUG({ - for (Value it = newVal; it; it = mapping.lookupOrNull(it)) - assert(it != oldVal && "inserting cyclic mapping"); - }); - mapping.map(oldVal, newVal); + /// If the given value is not mapped at all or if there are no mapped values/ + /// materialization results with the desired types, return an empty vector. 
+ SmallVector<Value, 1> lookupOrNull(Value from, + SmallVector<Type, 1> desiredTypes) const; + + Value lookupOrNull(Value from, Type desiredType) { + SmallVector<Value, 1> vals = + lookupOrNull(from, SmallVector<Type, 1>{desiredType}); + if (vals.empty()) + return Value(); + assert(vals.size() == 1 && "expected single value"); + return vals.front(); } - /// Try to map a value to the one provided. Returns false if a transitive - /// mapping from the new value to the old value already exists, true if the - /// map was updated. - bool tryMap(Value oldVal, Value newVal); + void erase(Value from) { mapping.erase(from); } + + void map(Value from, ValueRange to) { +#ifndef NDEBUG + assert(from && "expected non-null value"); + assert(!to.empty() && "cannot map to zero values"); + for (Value v : to) + assert(v && "expected non-null value"); +#endif + // assert(from != to && "cannot map value to itself"); + // TODO: Check for cyclic mapping. + assert(!mapping.contains(from) && "value is already mapped"); + mapping[from].assign(to.begin(), to.end()); + } + + void map(Value from, ArrayRef<BlockArgument> to) { + SmallVector<Value> vals; + for (Value v : to) + vals.push_back(v); + map(from, vals); + } + /* + void map(Value from, ArrayRef<Value> to) { + #ifndef NDEBUG + assert(from && "expected non-null value"); + assert(!to.empty() && "cannot map to zero values"); + for (Value v : to) + assert(v && "expected non-null value"); + #endif + // assert(from != to && "cannot map value to itself"); + // TODO: Check for cyclic mapping. 
+ assert(!mapping.contains(from) && "value is already mapped"); + mapping[from].assign(to.begin(), to.end()); + } + */ + + void mapMaterialization(SmallVector<Value, 1> from, + SmallVector<Value, 1> to) { +#ifndef NDEBUG + assert(!from.empty() && "from cannot be empty"); + assert(!to.empty() && "to cannot be empty"); + for (Value v : from) { + assert(v && "expected non-null value"); + assert(!mapping.contains(v) && + "cannot add materialization for mapped value"); + } + for (Value v : to) { + assert(v && "expected non-null value"); + } + assert(TypeRange(from) != TypeRange(to) && + "cannot add materialization for identical type"); + for (const SmallVector<Value, 1> &mat : materializations[from]) + assert(TypeRange(mat) != TypeRange(to) && + "cannot register duplicate materialization"); +#endif // NDEBUG + materializations[from].push_back(to); + } - /// Drop the last mapping for the given value. - void erase(Value value) { mapping.erase(value); } + void eraseMaterialization(SmallVector<Value, 1> from, + SmallVector<Value, 1> to) { + if (!materializations.count(from)) + return; + auto it = llvm::find(materializations[from], to); + if (it == materializations[from].end()) + return; + if (materializations[from].size() == 1) + materializations.erase(from); + else + materializations[from].erase(it); + } /// Returns the inverse raw value mapping (without recursive query support). DenseMap<Value, SmallVector<Value>> getInverse() const { DenseMap<Value, SmallVector<Value>> inverse; - for (auto &it : mapping.getValueMap()) - inverse[it.second].push_back(it.first); + + for (auto &it : mapping) + for (Value v : it.second) + inverse[v].push_back(it.first); + + for (auto &it : materializations) + for (const SmallVector<Value, 1> &mat : it.second) + for (Value v : mat) + for (Value v2 : it.first) + inverse[v].push_back(v2); + return inverse; } private: - /// Current value mappings. 
- IRMapping mapping; + /// Replacement mapping: Value -> ValueRange + DenseMap<Value, SmallVector<Value, 1>> mapping; + + /// Materializations: ValueRange -> ValueRange* + DenseMap<SmallVector<Value, 1>, SmallVector<SmallVector<Value, 1>>, + SmallVectorMapInfo> + materializations; }; } // namespace -Value ConversionValueMapping::lookupOrDefault(Value from, - Type desiredType) const { - // Try to find the deepest value that has the desired type. If there is no - // such value, simply return the deepest value. - Value desiredValue; - do { - if (!desiredType || from.getType() == desiredType) - desiredValue = from; - - Value mappedValue = mapping.lookupOrNull(from); - if (!mappedValue) - break; - from = mappedValue; - } while (true); +SmallVector<Value, 1> +ConversionValueMapping::lookupOrDefault(Value from) const { + SmallVector<Value, 1> to = lookupOrNull(from); + return to.empty() ? SmallVector<Value, 1>{from} : to; +} - // If the desired value was found use it, otherwise default to the leaf value. - return desiredValue ? desiredValue : from; +SmallVector<Value, 1> ConversionValueMapping::lookupOrDefault( + Value from, SmallVector<Type, 1> desiredTypes) const { +#ifndef NDEBUG + assert(desiredTypes.size() > 0 && "expected non-empty types"); + for (Type t : desiredTypes) + assert(t && "expected non-null type"); +#endif // NDEBUG + + SmallVector<Value, 1> vals = lookupOrNull(from); + if (vals.empty()) { + // Value is not mapped. Return if the type matches. + if (TypeRange(from) == desiredTypes) + return {from}; + // Check materializations. 
+ auto it = materializations.find({from}); + if (it == materializations.end()) + return {}; + for (const SmallVector<Value, 1> &mat : it->second) + if (TypeRange(mat) == desiredTypes) + return mat; + return {}; + } + + return lookupOrNull(from, desiredTypes); } -Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const { - Value result = lookupOrDefault(from, desiredType); - if (result == from || (desiredType && result.getType() != desiredType)) - return nullptr; +SmallVector<Value, 1> ConversionValueMapping::lookupOrNull(Value from) const { + auto it = mapping.find(from); + if (it == mapping.end()) + return {}; + SmallVector<Value, 1> result; + for (Value v : it->second) { + llvm::append_range(result, lookupOrDefault(v)); + } return result; } -bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) { - for (Value it = newVal; it; it = mapping.lookupOrNull(it)) - if (it == oldVal) - return false; - map(oldVal, newVal); - return true; +SmallVector<Value, 1> +ConversionValueMapping::lookupOrNull(Value from, + SmallVector<Type, 1> desiredTypes) const { +#ifndef NDEBUG + assert(desiredTypes.size() > 0 && "expected non-empty types"); + for (Type t : desiredTypes) + assert(t && "expected non-null type"); +#endif // NDEBUG + + SmallVector<Value, 1> vals = lookupOrNull(from); + if (vals.empty()) + return {}; + + // There is a mapping and the types match. + if (TypeRange(vals) == desiredTypes) + return vals; + + // There is a mapping, but the types do not match. Try to find a matching + // materialization. + auto it = materializations.find(vals); + if (it == materializations.end()) + return {}; + for (const SmallVector<Value, 1> &mat : it->second) + if (TypeRange(mat) == desiredTypes) + return mat; + + // No materialization found. Return an empty vector. 
+ return {}; } //===----------------------------------------------------------------------===// @@ -776,7 +948,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { LogicalResult remapValues(StringRef valueDiagTag, std::optional<Location> inputLoc, PatternRewriter &rewriter, ValueRange values, - SmallVector<SmallVector<Value>> &remapped); + SmallVector<SmallVector<Value, 1>> &remapped); /// Return "true" if the given operation is ignored, and does not need to be /// converted. @@ -832,32 +1004,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { .front(); } - /// Build an N:1 materialization for the given original value that was - /// replaced with the given replacement values. - /// - /// This is a workaround around incomplete 1:N support in the dialect - /// conversion driver. The conversion mapping can store only 1:1 replacements - /// and the conversion patterns only support single Value replacements in the - /// adaptor, so N values must be converted back to a single value. This - /// function will be deleted when full 1:N support has been added. - /// - /// This function inserts an argument materialization back to the original - /// type, followed by a target materialization to the legalized type (if - /// applicable). - void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc, - ValueRange replacements, Value originalValue, - const TypeConverter *converter); - - /// Unpack an N:1 materialization and return the inputs of the - /// materialization. This function unpacks only those materializations that - /// were built with `insertNTo1Materialization`. - /// - /// This is a workaround around incomplete 1:N support in the dialect - /// conversion driver. It allows us to write 1:N conversion patterns while - /// 1:N support is still missing in the conversion value mapping. This - /// function will be deleted when full 1:N support has been added. 
- SmallVector<Value> unpackNTo1Materialization(Value value); - //===--------------------------------------------------------------------===// // Rewriter Notification Hooks //===--------------------------------------------------------------------===// @@ -1095,12 +1241,9 @@ UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite( } void UnresolvedMaterializationRewrite::rollback() { - if (getMaterializationKind() == MaterializationKind::Target) { - for (Value input : op->getOperands()) - rewriterImpl.mapping.erase(input); - } + rewriterImpl.mapping.eraseMaterialization(op->getOperands(), + op->getResults()); rewriterImpl.unresolvedMaterializations.erase(getOperation()); - rewriterImpl.nTo1TempMaterializations.erase(getOperation()); op->erase(); } @@ -1144,7 +1287,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { LogicalResult ConversionPatternRewriterImpl::remapValues( StringRef valueDiagTag, std::optional<Location> inputLoc, PatternRewriter &rewriter, ValueRange values, - SmallVector<SmallVector<Value>> &remapped) { + SmallVector<SmallVector<Value, 1>> &remapped) { remapped.reserve(llvm::size(values)); for (const auto &it : llvm::enumerate(values)) { @@ -1152,18 +1295,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( Type origType = operand.getType(); Location operandLoc = inputLoc ? *inputLoc : operand.getLoc(); - // Find the most recently mapped value. Unpack all temporary N:1 - // materializations. Such conversions are a workaround around missing - // 1:N support in the ConversionValueMapping. (The conversion patterns - // already support 1:N replacements.) - Value repl = mapping.lookupOrDefault(operand); - SmallVector<Value> unpacked = unpackNTo1Materialization(repl); - if (!currentTypeConverter) { // The current pattern does not have a type converter. I.e., it does not // distinguish between legal and illegal types. For each operand, simply // pass through the most recently mapped value. 
- remapped.push_back(std::move(unpacked)); + SmallVector<Value, 1> vals = mapping.lookupOrDefault(operand); + remapped.push_back(vals); continue; } @@ -1177,49 +1314,30 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( return failure(); } - // If a type is converted to 0 types, there is nothing to do. + // Try to find a mapped value with the desired type. if (legalTypes.empty()) { remapped.push_back({}); continue; } - if (legalTypes.size() != 1) { - // TODO: This is a 1:N conversion. The conversion value mapping does not - // store such materializations yet. If the types of the most recently - // mapped values do not match, build a target materialization. - if (TypeRange(unpacked) == legalTypes) { - remapped.push_back(std::move(unpacked)); - continue; - } - - // Insert a target materialization if the current pattern expects - // different legalized types. - ValueRange targetMat = buildUnresolvedMaterialization( - MaterializationKind::Target, computeInsertPoint(repl), operandLoc, - /*inputs=*/unpacked, /*outputType=*/legalTypes, - /*originalType=*/origType, currentTypeConverter); - remapped.push_back(targetMat); + SmallVector<Value, 1> mat = mapping.lookupOrDefault(operand, legalTypes); + if (!mat.empty()) { + // Mapped value has the correct type or there is an existing + // materialization. Or the value is not mapped at all and has the + // correct type. + remapped.push_back(mat); continue; } - // Handle 1->1 type conversions. - Type desiredType = legalTypes.front(); - // Try to find a mapped value with the desired type. (Or the operand itself - // if the value is not mapped at all.) - Value newOperand = mapping.lookupOrDefault(operand, desiredType); - if (newOperand.getType() != desiredType) { - // If the looked up value's type does not have the desired type, it means - // that the value was replaced with a value of different type and no - // target materialization was created yet. 
- Value castValue = buildUnresolvedMaterialization( - MaterializationKind::Target, computeInsertPoint(newOperand), - operandLoc, - /*inputs=*/unpacked, /*outputType=*/desiredType, - /*originalType=*/origType, currentTypeConverter); - mapping.map(newOperand, castValue); - newOperand = castValue; - } - remapped.push_back({newOperand}); + // Create a materialization for the most recently mapped value. + SmallVector<Value, 1> vals = mapping.lookupOrDefault(operand); + ValueRange castValues = buildUnresolvedMaterialization( + MaterializationKind::Target, computeInsertPoint(vals), operandLoc, + /*inputs=*/vals, /*outputTypes=*/legalTypes, /*originalType=*/origType, + currentTypeConverter); + + mapping.mapMaterialization(vals, castValues); + remapped.push_back(castValues); } return success(); } @@ -1347,15 +1465,10 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( continue; } - // This is a 1->1+ mapping. 1->N mappings are not fully supported in the - // dialect conversion. Therefore, we need an argument materialization to - // turn the replacement block arguments into a single SSA value that can be - // used as a replacement. + // Map to replacement arguments. auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); - insertNTo1Materialization( - OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), - /*replacements=*/replArgs, /*outputValue=*/origArg, converter); + mapping.map(origArg, replArgs); appendRewrite<ReplaceBlockArgRewrite>(block, origArg); } @@ -1398,67 +1511,6 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( return convertOp.getResults(); } -void ConversionPatternRewriterImpl::insertNTo1Materialization( - OpBuilder::InsertPoint ip, Location loc, ValueRange replacements, - Value originalValue, const TypeConverter *converter) { - // Insert argument materialization back to the original type. 
- Type originalType = originalValue.getType(); - UnrealizedConversionCastOp argCastOp; - Value argMat = buildUnresolvedMaterialization( - MaterializationKind::Argument, ip, loc, - /*inputs=*/replacements, originalType, - /*originalType=*/Type(), converter, &argCastOp); - if (argCastOp) - nTo1TempMaterializations.insert(argCastOp); - mapping.map(originalValue, argMat); - - // Insert target materialization to the legalized type. - Type legalOutputType; - if (converter) { - legalOutputType = converter->convertType(originalType); - } else if (replacements.size() == 1) { - // When there is no type converter, assume that the replacement value - // types are legal. This is reasonable to assume because they were - // specified by the user. - // FIXME: This won't work for 1->N conversions because multiple output - // types are not supported in parts of the dialect conversion. In such a - // case, we currently use the original value type. - legalOutputType = replacements[0].getType(); - } - if (legalOutputType && legalOutputType != originalType) { - UnrealizedConversionCastOp targetCastOp; - Value targetMat = buildUnresolvedMaterialization( - MaterializationKind::Target, computeInsertPoint(argMat), loc, - /*inputs=*/argMat, /*outputType=*/legalOutputType, - /*originalType=*/originalType, converter, &targetCastOp); - if (targetCastOp) - nTo1TempMaterializations.insert(targetCastOp); - mapping.map(argMat, targetMat); - } -} - -SmallVector<Value> -ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) { - // Unpack unrealized_conversion_cast ops that were inserted as a N:1 - // workaround. - auto castOp = value.getDefiningOp<UnrealizedConversionCastOp>(); - if (!castOp) - return {value}; - if (!nTo1TempMaterializations.contains(castOp)) - return {value}; - assert(castOp->getNumResults() == 1 && "expected single result"); - - SmallVector<Value> result; - for (Value v : castOp.getOperands()) { - // Keep unpacking if possible. 
This is needed because during block - // signature conversions and 1:N op replacements, the driver may have - // inserted two materializations back-to-back: first an argument - // materialization, then a target materialization. - llvm::append_range(result, unpackNTo1Materialization(v)); - } - return result; -} - //===----------------------------------------------------------------------===// // Rewriter Notification Hooks @@ -1510,7 +1562,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( result.getLoc(), /*inputs=*/ValueRange(), /*outputType=*/result.getType(), /*originalType=*/Type(), currentTypeConverter); - mapping.map(result, sourceMat); + mapping.map(result, {sourceMat}); continue; } else { // Make sure that the user does not mess with unresolved materializations @@ -1524,18 +1576,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( } // Remap result to replacement value. - if (repl.empty()) - continue; - - if (repl.size() == 1) { - // Single replacement value: replace directly. - mapping.map(result, repl.front()); - } else { - // Multiple replacement values: insert N:1 materialization. 
- insertNTo1Materialization(computeInsertPoint(result), result.getLoc(), - /*replacements=*/repl, /*outputValue=*/result, - currentTypeConverter); - } + if (!repl.empty()) + mapping.map(result, repl); } appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter); @@ -1614,8 +1656,13 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); SmallVector<ValueRange> newVals; - for (int i = 0; i < newValues.size(); ++i) - newVals.push_back(newValues.slice(i, 1)); + for (int i = 0; i < newValues.size(); ++i) { + if (newValues[i]) { + newVals.push_back(newValues.slice(i, 1)); + } else { + newVals.push_back(ValueRange()); + } + } impl->notifyOpReplaced(op, newVals); } @@ -1682,11 +1729,14 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, << "'(" << from.getOwner()->getParentOp() << ")\n"; }); impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from); - impl->mapping.map(impl->mapping.lookupOrDefault(from), to); + SmallVector<Value, 1> mapped = impl->mapping.lookupOrDefault(from); + assert(mapped.size() == 1 && + "replaceUsesOfBlockArgument is not supported for 1:N replacements"); + impl->mapping.map(mapped.front(), to); } Value ConversionPatternRewriter::getRemappedValue(Value key) { - SmallVector<SmallVector<Value>> remappedValues; + SmallVector<SmallVector<Value, 1>> remappedValues; if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key, remappedValues))) return nullptr; @@ -1699,7 +1749,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys, SmallVectorImpl<Value> &results) { if (keys.empty()) return success(); - SmallVector<SmallVector<Value>> remapped; + SmallVector<SmallVector<Value, 1>> remapped; if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, remapped))) return failure(); @@ -1825,7 +1875,7 @@ ConversionPattern::matchAndRewrite(Operation *op, getTypeConverter()); // 
Remap the operands of the operation. - SmallVector<SmallVector<Value>> remapped; + SmallVector<SmallVector<Value, 1>> remapped; if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, op->getOperands(), remapped))) { return failure(); @@ -2582,19 +2632,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter, rewriter.setInsertionPoint(op); SmallVector<Value> newMaterialization; switch (rewrite->getMaterializationKind()) { - case MaterializationKind::Argument: { - // Try to materialize an argument conversion. - assert(op->getNumResults() == 1 && "expected single result"); - Value argMat = converter->materializeArgumentConversion( - rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands); - if (argMat) { - newMaterialization.push_back(argMat); - break; - } - } - // If an argument materialization failed, fallback to trying a target - // materialization. - [[fallthrough]]; case MaterializationKind::Target: newMaterialization = converter->materializeTargetConversion( rewriter, op->getLoc(), op.getResultTypes(), inputOperands, @@ -2742,6 +2779,12 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) { std::tie(replacedValues, converter) = getReplacedValues(rewriterImpl.rewrites[i].get()); for (Value originalValue : replacedValues) { + // If this value is directly replaced with a value of the same type, + // there is nothing to do. + Value repl = + rewriterImpl.mapping.lookupDirectSingleReplacement(originalValue); + if (repl && repl.getType() == originalValue.getType()) + continue; // If the type of this value changed and the value is still live, we need // to materialize a conversion. if (rewriterImpl.mapping.lookupOrNull(originalValue, @@ -2753,16 +2796,16 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) { continue; // Legalize this value replacement. 
- Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue); - assert(newValue && "replacement value not found"); + SmallVector<Value, 1> newValues = + rewriterImpl.mapping.lookupOrNull(originalValue); + assert(!newValues.empty() && "replacement value not found"); Value castValue = rewriterImpl.buildUnresolvedMaterialization( - MaterializationKind::Source, computeInsertPoint(newValue), + MaterializationKind::Source, computeInsertPoint(newValues), originalValue.getLoc(), - /*inputs=*/newValue, /*outputType=*/originalValue.getType(), + /*inputs=*/newValues, /*outputType=*/originalValue.getType(), /*originalType=*/Type(), converter); - rewriterImpl.mapping.map(originalValue, castValue); - inverseMapping[castValue].push_back(originalValue); - llvm::erase(inverseMapping[newValue], originalValue); + rewriterImpl.mapping.mapMaterialization(newValues, {castValue}); + llvm::append_range(inverseMapping[castValue], newValues); } } } diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 3ebee795a251df..3cdb9f80de6d8b 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -64,9 +64,6 @@ func.func @remap_call_1_to_1(%arg0: i64) { // Contents of the old block are moved to the new block. // CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact position unknown -// The new block arguments are used in "test.return". -// CHECK-NEXT: notifyOperationModified: test.return - // The old block is erased. 
// CHECK-NEXT: notifyBlockErased diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp index de511c58ae6ee0..0b8d4c0ee3bb0b 100644 --- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp +++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp @@ -139,7 +139,7 @@ struct TestDecomposeCallGraphTypes tupleType.getFlattenedTypes(types); return success(); }); - typeConverter.addArgumentMaterialization(buildMakeTupleOp); + typeConverter.addSourceMaterialization(buildMakeTupleOp); typeConverter.addTargetMaterialization(buildDecomposeTuple); populateDecomposeCallGraphTypesPatterns(context, typeConverter, patterns); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 0b5239168efc43..b719519b8f529e 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1235,7 +1235,6 @@ struct TestTypeConverter : public TypeConverter { using TypeConverter::TypeConverter; TestTypeConverter() { addConversion(convertType); - addArgumentMaterialization(materializeCast); addSourceMaterialization(materializeCast); } @@ -1267,8 +1266,8 @@ struct TestTypeConverter : public TypeConverter { return success(); } - /// Hook for materializing a conversion. This is necessary because we generate - /// 1->N type mappings. + /// Hook for materializing a conversion. This is necessary because we + /// generate 1->N type mappings. static Value materializeCast(OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); @@ -1340,7 +1339,8 @@ struct TestLegalizePatternDriver // correct error code from conversion driver. target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); - // Expect the type_producer/type_consumer operations to only operate on f64. 
+ // Expect the type_producer/type_consumer operations to only operate on + // f64. target.addDynamicallyLegalOp<TestTypeProducerOp>( [](TestTypeProducerOp op) { return op.getType().isF64(); }); target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { @@ -1357,7 +1357,8 @@ struct TestLegalizePatternDriver target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); - // Create a dynamically legal rule that can only be legalized by folding it. + // Create a dynamically legal rule that can only be legalized by folding + // it. target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>( [](TestOpInPlaceSelfFold op) { return op.getFolded(); }); diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp index 2cc1fb5d39d788..a03bf0a1023d57 100644 --- a/mlir/test/lib/Transforms/TestDialectConversion.cpp +++ b/mlir/test/lib/Transforms/TestDialectConversion.cpp @@ -28,7 +28,6 @@ namespace { struct PDLLTypeConverter : public TypeConverter { PDLLTypeConverter() { addConversion(convertType); - addArgumentMaterialization(materializeCast); addSourceMaterialization(materializeCast); } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits