https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/121389
The 1:N dialect conversion driver has been deprecated. Use the regular dialect conversion driver instead. This commit deletes the 1:N dialect conversion driver. For details, see https://discourse.llvm.org/t/rfc-merging-1-1-and-1-n-dialect-conversions/82513. I plan to merge this PR in April 2025. Depends on #116524. >From c26bd82d581088e0780507016573c2c7f18f286a Mon Sep 17 00:00:00 2001 From: Matthias Springer <msprin...@nvidia.com> Date: Tue, 31 Dec 2024 13:32:25 +0100 Subject: [PATCH] [mlir][Transforms] Remove 1:N dialect conversion driver --- .../Func/Transforms/OneToNFuncConversions.h | 26 - .../mlir/Dialect/SCF/Transforms/Patterns.h | 10 - .../SPIRV/Transforms/SPIRVConversion.h | 1 - .../Dialect/SparseTensor/Transforms/Passes.h | 1 - .../mlir/Transforms/DialectConversion.h | 20 - .../mlir/Transforms/OneToNTypeConversion.h | 290 ----------- .../Dialect/Func/Transforms/CMakeLists.txt | 1 - .../Func/Transforms/OneToNFuncConversions.cpp | 87 ---- .../lib/Dialect/SCF/Transforms/CMakeLists.txt | 1 - .../SCF/Transforms/OneToNTypeConversion.cpp | 215 -------- .../SPIRV/Transforms/SPIRVConversion.cpp | 7 +- mlir/lib/Transforms/Utils/CMakeLists.txt | 1 - .../Transforms/Utils/DialectConversion.cpp | 11 - .../Transforms/Utils/OneToNTypeConversion.cpp | 458 ------------------ .../one-to-n-type-conversion.mlir | 140 ------ ...f-structural-one-to-n-type-conversion.mlir | 183 ------- .../decompose-call-graph-types.mlir | 53 -- mlir/test/lib/Conversion/CMakeLists.txt | 1 - .../OneToNTypeConversion/CMakeLists.txt | 21 - .../TestOneToNTypeConversionPass.cpp | 261 ---------- mlir/tools/mlir-opt/CMakeLists.txt | 1 - mlir/tools/mlir-opt/mlir-opt.cpp | 2 - .../llvm-project-overlay/mlir/BUILD.bazel | 2 - .../mlir/test/BUILD.bazel | 17 - 24 files changed, 4 insertions(+), 1806 deletions(-) delete mode 100644 mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h delete mode 100644 mlir/include/mlir/Transforms/OneToNTypeConversion.h delete mode 100644 mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp delete mode 100644 mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp delete mode 100644 mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp delete mode 100644 mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir delete mode 100644 mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir delete mode 100644 mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt delete mode 100644 mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp diff --git a/mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h b/mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h deleted file mode 100644 index c9e407daf9bf8c..00000000000000 --- a/mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h +++ /dev/null @@ -1,26 +0,0 @@ -//===- OneToNTypeFuncConversions.h - 1:N type conv. for Func ----*- C++ -*-===// -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H -#define MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H - -namespace mlir { -class TypeConverter; -class RewritePatternSet; -} // namespace mlir - -namespace mlir { - -// Populates the provided pattern set with patterns that do 1:N type conversions -// on func ops. This is intended to be used with `applyPartialOneToNConversion`. -void populateFuncTypeConversionPatterns(const TypeConverter &typeConverter, - RewritePatternSet &patterns); - -} // namespace mlir - -#endif // MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h index 9c1479d28c305f..00c8a5c0c517b7 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h @@ -63,16 +63,6 @@ void populateSCFStructuralTypeConversions(const TypeConverter &typeConverter, void populateSCFStructuralTypeConversionTarget( const TypeConverter &typeConverter, ConversionTarget &target); -/// Populates the provided pattern set with patterns that do 1:N type -/// conversions on (some) SCF ops. This is intended to be used with -/// applyPartialOneToNConversion. -/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon. -/// 1:N support has been added to the regular dialect conversion driver. -LLVM_DEPRECATED("Use populateSCFStructuralTypeConversions() instead", - "populateSCFStructuralTypeConversions") -void populateSCFStructuralOneToNTypeConversions( - const TypeConverter &typeConverter, RewritePatternSet &patterns); - /// Populate patterns for SCF software pipelining transformation. See the /// ForLoopPipeliningPattern for the transformation details. void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns, diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h index bed4d66ccd6cbe..3d22ec918f4c5f 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -20,7 +20,6 @@ #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/OneToNTypeConversion.h" #include "llvm/ADT/SmallSet.h" #include "llvm/Support/LogicalResult.h" diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h index 2e9c297f20182a..acd347c530d58b 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -16,7 +16,6 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/OneToNTypeConversion.h" //===----------------------------------------------------------------------===// // Include the generated pass header (which needs some early definitions). diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 9a6975dcf8dfae..7e5389a83855a5 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -45,13 +45,11 @@ class TypeConverter { // Copy the registered conversions, but not the caches TypeConverter(const TypeConverter &other) : conversions(other.conversions), - argumentMaterializations(other.argumentMaterializations), sourceMaterializations(other.sourceMaterializations), targetMaterializations(other.targetMaterializations), typeAttributeConversions(other.typeAttributeConversions) {} TypeConverter &operator=(const TypeConverter &other) { conversions = other.conversions; - argumentMaterializations = other.argumentMaterializations; sourceMaterializations = other.sourceMaterializations; targetMaterializations = other.targetMaterializations; typeAttributeConversions = other.typeAttributeConversions; @@ -177,21 +175,6 @@ class TypeConverter { /// can be a TypeRange; in that case, the function must return a /// SmallVector<Value>. - /// This method registers a materialization that will be called when - /// converting (potentially multiple) block arguments that were the result of - /// a signature conversion of a single block argument, to a single SSA value - /// with the old block argument type. - /// - /// Note: Argument materializations are used only with the 1:N dialect - /// conversion driver. The 1:N dialect conversion driver will be removed soon - /// and so will be argument materializations. - template <typename FnT, typename T = typename llvm::function_traits< - std::decay_t<FnT>>::template arg_t<1>> - void addArgumentMaterialization(FnT &&callback) { - argumentMaterializations.emplace_back( - wrapMaterialization<T>(std::forward<FnT>(callback))); - } - /// This method registers a materialization that will be called when /// converting a replacement value back to its original source type. /// This is used when some uses of the original value persist beyond the main @@ -319,8 +302,6 @@ class TypeConverter { /// generating a cast sequence of some kind. See the respective /// `add*Materialization` for more information on the context for these /// methods. - Value materializeArgumentConversion(OpBuilder &builder, Location loc, - Type resultType, ValueRange inputs) const; Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const; Value materializeTargetConversion(OpBuilder &builder, Location loc, @@ -507,7 +488,6 @@ class TypeConverter { SmallVector<ConversionCallbackFn, 4> conversions; /// The list of registered materialization functions. - SmallVector<MaterializationCallbackFn, 2> argumentMaterializations; SmallVector<MaterializationCallbackFn, 2> sourceMaterializations; SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations; diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h deleted file mode 100644 index 9c74bf916d971b..00000000000000 --- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h +++ /dev/null @@ -1,290 +0,0 @@ -//===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===// -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Note: The 1:N dialect conversion is deprecated and will be removed soon. -// 1:N support has been added to the regular dialect conversion driver. -// -// This file provides utils for implementing (poor-man's) dialect conversion -// passes with 1:N type conversions. -// -// The main function, `applyPartialOneToNConversion`, first applies a set of -// `RewritePattern`s, which produce unrealized casts to convert the operands and -// results from and to the source types, and then replaces all newly added -// unrealized casts by user-provided materializations. For this to work, the -// main function requires a special `TypeConverter`, a special -// `PatternRewriter`, and special RewritePattern`s, which extend their -// respective base classes for 1:N type converions. -// -// Note that this is much more simple-minded than the "real" dialect conversion, -// which checks for legality before applying patterns and does probably many -// other additional things. Ideally, some of the extensions here could be -// integrated there. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TRANSFORMS_ONETONTYPECONVERSION_H -#define MLIR_TRANSFORMS_ONETONTYPECONVERSION_H - -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/SmallVector.h" - -namespace mlir { - -/// Stores a 1:N mapping of types and provides several useful accessors. This -/// class extends `SignatureConversion`, which already supports 1:N type -/// mappings but lacks some accessors into the mapping as well as access to the -/// original types. -class OneToNTypeMapping : public TypeConverter::SignatureConversion { -public: - OneToNTypeMapping(TypeRange originalTypes) - : TypeConverter::SignatureConversion(originalTypes.size()), - originalTypes(originalTypes) {} - - using TypeConverter::SignatureConversion::getConvertedTypes; - - /// Returns the list of types that corresponds to the original type at the - /// given index. - TypeRange getConvertedTypes(unsigned originalTypeNo) const; - - /// Returns the list of original types. - TypeRange getOriginalTypes() const { return originalTypes; } - - /// Returns the slice of converted values that corresponds the original value - /// at the given index. - ValueRange getConvertedValues(ValueRange convertedValues, - unsigned originalValueNo) const; - - /// Fills the given result vector with as many copies of the location of the - /// original value as the number of values it is converted to. - void convertLocation(Value originalValue, unsigned originalValueNo, - llvm::SmallVectorImpl<Location> &result) const; - - /// Fills the given result vector with as many copies of the lociation of each - /// original value as the number of values they are respectively converted to. - void convertLocations(ValueRange originalValues, - llvm::SmallVectorImpl<Location> &result) const; - - /// Returns true iff at least one type conversion maps an input type to a type - /// that is different from itself. - bool hasNonIdentityConversion() const; - -private: - llvm::SmallVector<Type> originalTypes; -}; - -/// Extends the basic `RewritePattern` class with a type converter member and -/// some accessors to it. This is useful for patterns that are not -/// `ConversionPattern`s but still require access to a type converter. -class RewritePatternWithConverter : public mlir::RewritePattern { -public: - /// Construct a conversion pattern with the given converter, and forward the - /// remaining arguments to RewritePattern. - template <typename... Args> - RewritePatternWithConverter(const TypeConverter &typeConverter, - Args &&...args) - : RewritePattern(std::forward<Args>(args)...), - typeConverter(&typeConverter) {} - - /// Return the type converter held by this pattern, or nullptr if the pattern - /// does not require type conversion. - const TypeConverter *getTypeConverter() const { return typeConverter; } - - template <typename ConverterTy> - std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value, - const ConverterTy *> - getTypeConverter() const { - return static_cast<const ConverterTy *>(typeConverter); - } - -protected: - /// A type converter for use by this pattern. - const TypeConverter *const typeConverter; -}; - -/// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The -/// class provides additional rewrite methods that are specific to 1:N type -/// conversions. -class OneToNPatternRewriter : public PatternRewriter { -public: - OneToNPatternRewriter(MLIRContext *context, - OpBuilder::Listener *listener = nullptr) - : PatternRewriter(context, listener) {} - - /// Replaces the results of the operation with the specified list of values - /// mapped back to the original types as specified in the provided type - /// mapping. That type mapping must match the replaced op (i.e., the original - /// types must be the same as the result types of the op) and the new values - /// (i.e., the converted types must be the same as the types of the new - /// values). - /// FIXME: The 1:N dialect conversion is deprecated and will be removed soon. - /// 1:N support has been added to the regular dialect conversion driver. - LLVM_DEPRECATED("Use replaceOpWithMultiple() instead", - "replaceOpWithMultiple") - void replaceOp(Operation *op, ValueRange newValues, - const OneToNTypeMapping &resultMapping); - using PatternRewriter::replaceOp; - - /// Applies the given argument conversion to the given block. This consists of - /// replacing each original argument with N arguments as specified in the - /// argument conversion and inserting unrealized casts from the converted - /// values to the original types, which are then used in lieu of the original - /// ones. (Eventually, `applyPartialOneToNConversion` replaces these casts - /// with a user-provided argument materialization if necessary.) This is - /// similar to `ArgConverter::applySignatureConversion` but (1) handles 1:N - /// type conversion properly and probably (2) doesn't handle many other edge - /// cases. - Block *applySignatureConversion(Block *block, - OneToNTypeMapping &argumentConversion); -}; - -/// Base class for patterns with 1:N type conversions. Derived classes have to -/// overwrite the `matchAndRewrite` overlaod that provides additional -/// information for 1:N type conversions. -class OneToNConversionPattern : public RewritePatternWithConverter { -public: - using RewritePatternWithConverter::RewritePatternWithConverter; - - /// This function has to be implemented by derived classes and is called from - /// the usual overloads. Like in "normal" `DialectConversion`, the function is - /// provided with the converted operands (which thus have target types). Since - /// 1:N conversions are supported, there is usually no 1:1 relationship - /// between the original and the converted operands. Instead, the provided - /// `operandMapping` can be used to access the converted operands that - /// correspond to a particular original operand. Similarly, `resultMapping` - /// is provided to help with assembling the result values, which may have 1:N - /// correspondences as well. In that case, the original op should be replaced - /// with the overload of `replaceOp` that takes the provided `resultMapping` - /// in order to deal with the mapping of converted result values to their - /// usages in the original types correctly. - virtual LogicalResult matchAndRewrite(Operation *op, - OneToNPatternRewriter &rewriter, - const OneToNTypeMapping &operandMapping, - const OneToNTypeMapping &resultMapping, - ValueRange convertedOperands) const = 0; - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const final; -}; - -/// This class is a wrapper around `OneToNConversionPattern` for matching -/// against instances of a particular op class. -template <typename SourceOp> -class OneToNOpConversionPattern : public OneToNConversionPattern { -public: - OneToNOpConversionPattern(const TypeConverter &typeConverter, - MLIRContext *context, PatternBenefit benefit = 1, - ArrayRef<StringRef> generatedNames = {}) - : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(), - benefit, context, generatedNames) {} - /// Generic adaptor around the root op of this pattern using the converted - /// operands. Importantly, each operand is represented as a *range* of values, - /// namely the N values each original operand gets converted to. Concretely, - /// this makes the result type of the accessor functions of the adaptor class - /// be a `ValueRange`. - class OpAdaptor - : public SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> { - public: - using RangeT = ArrayRef<ValueRange>; - using BaseT = typename SourceOp::template GenericAdaptor<RangeT>; - using Properties = typename SourceOp::template InferredProperties<SourceOp>; - - OpAdaptor(const OneToNTypeMapping *operandMapping, - const OneToNTypeMapping *resultMapping, - const ValueRange *convertedOperands, RangeT values, SourceOp op) - : BaseT(values, op), operandMapping(operandMapping), - resultMapping(resultMapping), convertedOperands(convertedOperands) {} - - /// Get the type mapping of the original operands to the converted operands. - const OneToNTypeMapping &getOperandMapping() const { - return *operandMapping; - } - - /// Get the type mapping of the original results to the converted results. - const OneToNTypeMapping &getResultMapping() const { return *resultMapping; } - - /// Get a flat range of all converted operands. Unlike `getOperands`, which - /// returns an `ArrayRef` with one `ValueRange` for each original operand, - /// this function returns a `ValueRange` that contains all converted - /// operands irrespectively of which operand they originated from. - ValueRange getFlatOperands() const { return *convertedOperands; } - - private: - const OneToNTypeMapping *operandMapping; - const OneToNTypeMapping *resultMapping; - const ValueRange *convertedOperands; - }; - - using OneToNConversionPattern::matchAndRewrite; - - /// Overload that derived classes have to override for their op type. - virtual LogicalResult - matchAndRewrite(SourceOp op, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const = 0; - - LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter, - const OneToNTypeMapping &operandMapping, - const OneToNTypeMapping &resultMapping, - ValueRange convertedOperands) const final { - // Wrap converted operands and type mappings into an adaptor. - SmallVector<ValueRange> valueRanges; - for (int64_t i = 0; i < op->getNumOperands(); i++) { - auto values = operandMapping.getConvertedValues(convertedOperands, i); - valueRanges.push_back(values); - } - OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands, - valueRanges, cast<SourceOp>(op)); - - // Call overload implemented by the derived class. - return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter); - } -}; - -/// Applies the given set of patterns recursively on the given op and adds user -/// materializations where necessary. The patterns are expected to be -/// `OneToNConversionPattern`, which help converting the types of the operands -/// and results of the matched ops. The provided type converter is used to -/// convert the operands of matched ops from their original types to operands -/// with different types. Unlike in `DialectConversion`, this supports 1:N type -/// conversions. Those conversions at the "boundary" of the pattern application, -/// where converted results are not consumed by replaced ops that expect the -/// converted operands or vice versa, the function inserts user materializations -/// from the type converter. Also unlike `DialectConversion`, there are no legal -/// or illegal types; the function simply applies the given patterns and does -/// not fail if some ops or types remain unconverted (i.e., the conversion is -/// only "partial"). -/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon. -/// 1:N support has been added to the regular dialect conversion driver. -LLVM_DEPRECATED("Use applyPartialConversion() instead", - "applyPartialConversion") -LogicalResult -applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, - const FrozenRewritePatternSet &patterns); - -/// Add a pattern to the given pattern list to convert the signature of a -/// FunctionOpInterface op with the given type converter. This only supports -/// ops which use FunctionType to represent their type. This is intended to be -/// used with the 1:N dialect conversion. -/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon. -/// 1:N support has been added to the regular dialect conversion driver. -LLVM_DEPRECATED( - "Use populateFunctionOpInterfaceTypeConversionPattern() instead", - "populateFunctionOpInterfaceTypeConversionPattern") -void populateOneToNFunctionOpInterfaceTypeConversionPattern( - StringRef functionLikeOpName, const TypeConverter &converter, - RewritePatternSet &patterns); -template <typename FuncOpT> -void populateOneToNFunctionOpInterfaceTypeConversionPattern( - const TypeConverter &converter, RewritePatternSet &patterns) { - populateOneToNFunctionOpInterfaceTypeConversionPattern( - FuncOpT::getOperationName(), converter, patterns); -} - -} // namespace mlir - -#endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H diff --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt index 6384d25ee70273..0bed59e109503f 100644 --- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt @@ -1,7 +1,6 @@ add_mlir_dialect_library(MLIRFuncTransforms DuplicateFunctionElimination.cpp FuncConversions.cpp - OneToNFuncConversions.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Transforms diff --git a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp deleted file mode 100644 index 3b8982257a9c95..00000000000000 --- a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp +++ /dev/null @@ -1,87 +0,0 @@ -//===-- OneToNTypeFuncConversions.cpp - Func 1:N type conversion-*- C++ -*-===// -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// The patterns in this file are heavily inspired (and copied from) -// convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the -// patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp but work for 1:N -// type conversions. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Transforms/OneToNTypeConversion.h" - -using namespace mlir; -using namespace mlir::func; - -namespace { - -class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> { -public: - using OneToNOpConversionPattern<CallOp>::OneToNOpConversionPattern; - - LogicalResult - matchAndRewrite(CallOp op, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); - - // Nothing to do if the op doesn't have any non-identity conversions for its - // operands or results. - if (!adaptor.getOperandMapping().hasNonIdentityConversion() && - !resultMapping.hasNonIdentityConversion()) - return failure(); - - // Create new CallOp. - auto newOp = - rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(), - adaptor.getFlatOperands(), op->getAttrs()); - - rewriter.replaceOp(op, newOp->getResults(), resultMapping); - return success(); - } -}; - -class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> { -public: - using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern; - - LogicalResult - matchAndRewrite(ReturnOp op, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { - // Nothing to do if there is no non-identity conversion. - if (!adaptor.getOperandMapping().hasNonIdentityConversion()) - return failure(); - - // Convert operands. - rewriter.modifyOpInPlace( - op, [&] { op->setOperands(adaptor.getFlatOperands()); }); - - return success(); - } -}; - -} // namespace - -namespace mlir { - -void populateFuncTypeConversionPatterns(const TypeConverter &typeConverter, - RewritePatternSet &patterns) { - patterns.add< - // clang-format off - ConvertTypesInFuncCallOp, - ConvertTypesInFuncReturnOp - // clang-format on - >(typeConverter, patterns.getContext()); - populateOneToNFunctionOpInterfaceTypeConversionPattern<func::FuncOp>( - typeConverter, patterns); -} - -} // namespace mlir diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt index e99b5d0cc26fc7..84dd992bec53a7 100644 --- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -8,7 +8,6 @@ add_mlir_dialect_library(MLIRSCFTransforms LoopPipelining.cpp LoopRangeFolding.cpp LoopSpecialization.cpp - OneToNTypeConversion.cpp ParallelLoopCollapsing.cpp ParallelLoopFusion.cpp ParallelLoopTiling.cpp diff --git a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp deleted file mode 100644 index 4cd17f77dfb941..00000000000000 --- a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp +++ /dev/null @@ -1,215 +0,0 @@ -//===-- OneToNTypeConversion.cpp - SCF 1:N type conversion ------*- C++ -*-===// -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// The patterns in this file are heavily inspired (and copied from) -// lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp but work for 1:N -// type conversions. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/SCF/Transforms/Transforms.h" - -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Transforms/OneToNTypeConversion.h" - -using namespace mlir; -using namespace mlir::scf; - -class ConvertTypesInSCFIfOp : public OneToNOpConversionPattern<IfOp> { -public: - using OneToNOpConversionPattern<IfOp>::OneToNOpConversionPattern; - - LogicalResult - matchAndRewrite(IfOp op, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); - - // Nothing to do if there is no non-identity conversion. - if (!resultMapping.hasNonIdentityConversion()) - return failure(); - - // Create new IfOp. - TypeRange convertedResultTypes = resultMapping.getConvertedTypes(); - auto newOp = rewriter.create<IfOp>(loc, convertedResultTypes, - op.getCondition(), true); - newOp->setAttrs(op->getAttrs()); - - // We do not need the empty blocks created by rewriter. - rewriter.eraseBlock(newOp.elseBlock()); - rewriter.eraseBlock(newOp.thenBlock()); - - // Inlines block from the original operation. - rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(), - newOp.getThenRegion().end()); - rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(), - newOp.getElseRegion().end()); - - rewriter.replaceOp(op, newOp->getResults(), resultMapping); - return success(); - } -}; - -class ConvertTypesInSCFWhileOp : public OneToNOpConversionPattern<WhileOp> { -public: - using OneToNOpConversionPattern<WhileOp>::OneToNOpConversionPattern; - - LogicalResult - matchAndRewrite(WhileOp op, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - - const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping(); - const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); - - // Nothing to do if the op doesn't have any non-identity conversions for its - // operands or results. - if (!operandMapping.hasNonIdentityConversion() && - !resultMapping.hasNonIdentityConversion()) - return failure(); - - // Create new WhileOp. - TypeRange convertedResultTypes = resultMapping.getConvertedTypes(); - - auto newOp = rewriter.create<WhileOp>(loc, convertedResultTypes, - adaptor.getFlatOperands()); - newOp->setAttrs(op->getAttrs()); - - // Update block signatures. - std::array<OneToNTypeMapping, 2> blockMappings = {operandMapping, - resultMapping}; - for (unsigned int i : {0u, 1u}) { - Region *region = &op.getRegion(i); - Block *block = ®ion->front(); - - rewriter.applySignatureConversion(block, blockMappings[i]); - - // Move updated region to new WhileOp. - Region &dstRegion = newOp.getRegion(i); - rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); - } - - rewriter.replaceOp(op, newOp->getResults(), resultMapping); - return success(); - } -}; - -class ConvertTypesInSCFYieldOp : public OneToNOpConversionPattern<YieldOp> { -public: - using OneToNOpConversionPattern<YieldOp>::OneToNOpConversionPattern; - - LogicalResult - matchAndRewrite(YieldOp op, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { - // Nothing to do if there is no non-identity conversion. - if (!adaptor.getOperandMapping().hasNonIdentityConversion()) - return failure(); - - // Convert operands. - rewriter.modifyOpInPlace( - op, [&] { op->setOperands(adaptor.getFlatOperands()); }); - - return success(); - } -}; - -class ConvertTypesInSCFConditionOp - : public OneToNOpConversionPattern<ConditionOp> { -public: - using OneToNOpConversionPattern<ConditionOp>::OneToNOpConversionPattern; - - LogicalResult - matchAndRewrite(ConditionOp op, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { - // Nothing to do if there is no non-identity conversion. - if (!adaptor.getOperandMapping().hasNonIdentityConversion()) - return failure(); - - // Convert operands. - rewriter.modifyOpInPlace( - op, [&] { op->setOperands(adaptor.getFlatOperands()); }); - - return success(); - } -}; - -class ConvertTypesInSCFForOp final : public OneToNOpConversionPattern<ForOp> { -public: - using OneToNOpConversionPattern<ForOp>::OneToNOpConversionPattern; - - LogicalResult - matchAndRewrite(ForOp forOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { - const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping(); - const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); - - // Nothing to do if there is no non-identity conversion. - if (!operandMapping.hasNonIdentityConversion() && - !resultMapping.hasNonIdentityConversion()) - return failure(); - - // If the lower-bound, upper-bound, or step were expanded, abort the - // conversion. This conversion does not know what to do in such cases. - ValueRange lbs = adaptor.getLowerBound(); - ValueRange ubs = adaptor.getUpperBound(); - ValueRange steps = adaptor.getStep(); - if (lbs.size() != 1 || ubs.size() != 1 || steps.size() != 1) - return rewriter.notifyMatchFailure( - forOp, "index operands converted to multiple values"); - - Location loc = forOp.getLoc(); - - Region *region = &forOp.getRegion(); - Block *block = ®ion->front(); - - // Construct the new for-op with an empty body. - ValueRange newInits = adaptor.getFlatOperands().drop_front(3); - auto newOp = - rewriter.create<ForOp>(loc, lbs[0], ubs[0], steps[0], newInits); - newOp->setAttrs(forOp->getAttrs()); - - // We do not need the empty blocks created by rewriter. - rewriter.eraseBlock(newOp.getBody()); - - // Convert the signature of the body region. - OneToNTypeMapping bodyTypeMapping(block->getArgumentTypes()); - if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), - bodyTypeMapping))) - return failure(); - - // Perform signature conversion on the body block. - rewriter.applySignatureConversion(block, bodyTypeMapping); - - // Splice the old body region into the new for-op. - Region &dstRegion = newOp.getBodyRegion(); - rewriter.inlineRegionBefore(forOp.getRegion(), dstRegion, dstRegion.end()); - - rewriter.replaceOp(forOp, newOp.getResults(), resultMapping); - - return success(); - } -}; - -namespace mlir { -namespace scf { - -void populateSCFStructuralOneToNTypeConversions( - const TypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add< - // clang-format off - ConvertTypesInSCFConditionOp, - ConvertTypesInSCFForOp, - ConvertTypesInSCFIfOp, - ConvertTypesInSCFWhileOp, - ConvertTypesInSCFYieldOp - // clang-format on - >(typeConverter, patterns.getContext()); -} - -} // namespace scf -} // namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 29f7e8afe0773b..d837e305c4c34e 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -29,7 +29,6 @@ #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/OneToNTypeConversion.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" @@ -933,7 +932,8 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&entryBlock); - OneToNTypeMapping oneToNTypeMapping(fnType.getInputs()); + TypeConverter::SignatureConversion oneToNTypeMapping( + fnType.getInputs().size()); // For arguments that are of illegal types and require unrolling. // `unrolledInputNums` stores the indices of arguments that result from @@ -1073,7 +1073,8 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> { return failure(); FunctionType fnType = funcOp.getFunctionType(); - OneToNTypeMapping oneToNTypeMapping(fnType.getResults()); + TypeConverter::SignatureConversion oneToNTypeMapping( + fnType.getResults().size()); Location loc = returnOp.getLoc(); // For the new return op. diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt index 72eb34f36cf5f6..3ca16239ba33c0 100644 --- a/mlir/lib/Transforms/Utils/CMakeLists.txt +++ b/mlir/lib/Transforms/Utils/CMakeLists.txt @@ -8,7 +8,6 @@ add_mlir_library(MLIRTransformUtils Inliner.cpp InliningUtils.cpp LoopInvariantCodeMotionUtils.cpp - OneToNTypeConversion.cpp RegionUtils.cpp WalkPatternRewriteDriver.cpp diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 3571e017158be9..616921b39cd8cd 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2951,17 +2951,6 @@ TypeConverter::convertSignatureArgs(TypeRange types, return success(); } -Value TypeConverter::materializeArgumentConversion(OpBuilder &builder, - Location loc, - Type resultType, - ValueRange inputs) const { - for (const MaterializationCallbackFn &fn : - llvm::reverse(argumentMaterializations)) - if (Value result = fn(builder, resultType, inputs, loc)) - return result; - return nullptr; -} - Value TypeConverter::materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const { diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp deleted file mode 100644 index 6474c59595eb43..00000000000000 --- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp +++ /dev/null @@ -1,458 +0,0 @@ -//===-- OneToNTypeConversion.cpp - Utils for 1:N type conversion-*- C++ -*-===// -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Transforms/OneToNTypeConversion.h" - -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/SmallSet.h" - -#include <unordered_map> - -using namespace llvm; -using namespace mlir; - -TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const { - TypeRange convertedTypes = getConvertedTypes(); - if (auto mapping = getInputMapping(originalTypeNo)) - return convertedTypes.slice(mapping->inputNo, mapping->size); - return {}; -} - -ValueRange -OneToNTypeMapping::getConvertedValues(ValueRange convertedValues, - unsigned originalValueNo) const { - if (auto mapping = getInputMapping(originalValueNo)) - return convertedValues.slice(mapping->inputNo, mapping->size); - return {}; -} - -void OneToNTypeMapping::convertLocation( - Value originalValue, unsigned originalValueNo, - llvm::SmallVectorImpl<Location> &result) const { - if (auto mapping = getInputMapping(originalValueNo)) - result.append(mapping->size, originalValue.getLoc()); -} - -void OneToNTypeMapping::convertLocations( - ValueRange originalValues, llvm::SmallVectorImpl<Location> &result) const { - assert(originalValues.size() == getOriginalTypes().size()); - for (auto [i, value] : llvm::enumerate(originalValues)) - convertLocation(value, i, result); -} - -static bool isIdentityConversion(Type originalType, TypeRange convertedTypes) { - return convertedTypes.size() == 1 && convertedTypes[0] == originalType; -} - -bool OneToNTypeMapping::hasNonIdentityConversion() const { - // XXX: I think that the original types and the converted types are the same - // iff there was no non-identity type conversion. If that is true, the - // patterns could actually test whether there is anything useful to do - // without having access to the signature conversion. - for (auto [i, originalType] : llvm::enumerate(originalTypes)) { - TypeRange types = getConvertedTypes(i); - if (!isIdentityConversion(originalType, types)) { - assert(TypeRange(originalTypes) != getConvertedTypes()); - return true; - } - } - assert(TypeRange(originalTypes) == getConvertedTypes()); - return false; -} - -namespace { -enum class CastKind { - // Casts block arguments in the target type back to the source type. (If - // necessary, this cast becomes an argument materialization.) - Argument, - - // Casts other values in the target type back to the source type. (If - // necessary, this cast becomes a source materialization.) - Source, - - // Casts values in the source type to the target type. (If necessary, this - // cast becomes a target materialization.) - Target -}; -} // namespace - -/// Mapping of enum values to string values. -StringRef getCastKindName(CastKind kind) { - static const std::unordered_map<CastKind, StringRef> castKindNames = { - {CastKind::Argument, "argument"}, - {CastKind::Source, "source"}, - {CastKind::Target, "target"}}; - return castKindNames.at(kind); -} - -/// Attribute name that is used to annotate inserted unrealized casts with their -/// kind (source, argument, or target). -static const char *const castKindAttrName = - "__one-to-n-type-conversion_cast-kind__"; - -/// Builds an `UnrealizedConversionCastOp` from the given inputs to the given -/// result types. Returns the result values of the cast. -static ValueRange buildUnrealizedCast(OpBuilder &builder, TypeRange resultTypes, - ValueRange inputs, CastKind kind) { - // Special case: 1-to-N conversion with N = 0. No need to build an - // UnrealizedConversionCastOp because the op will always be dead. - if (resultTypes.empty()) - return ValueRange(); - - // Create cast. - Location loc = builder.getUnknownLoc(); - if (!inputs.empty()) - loc = inputs.front().getLoc(); - auto castOp = - builder.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs); - - // Store cast kind as attribute. - auto kindAttr = StringAttr::get(builder.getContext(), getCastKindName(kind)); - castOp->setAttr(castKindAttrName, kindAttr); - - return castOp->getResults(); -} - -/// Builds one `UnrealizedConversionCastOp` for each of the given original -/// values using the respective target types given in the provided conversion -/// mapping and returns the results of these casts. If the conversion mapping of -/// a value maps a type to itself (i.e., is an identity conversion), then no -/// cast is inserted and the original value is returned instead. -/// Note that these unrealized casts are different from target materializations -/// in that they are *always* inserted, even if they immediately fold away, such -/// that patterns always see valid intermediate IR, whereas materializations are -/// only used in the places where the unrealized casts *don't* fold away. -static SmallVector<Value> -buildUnrealizedForwardCasts(ValueRange originalValues, - OneToNTypeMapping &conversion, - RewriterBase &rewriter, CastKind kind) { - - // Convert each operand one by one. - SmallVector<Value> convertedValues; - convertedValues.reserve(conversion.getConvertedTypes().size()); - for (auto [idx, originalValue] : llvm::enumerate(originalValues)) { - TypeRange convertedTypes = conversion.getConvertedTypes(idx); - - // Identity conversion: keep operand as is. - if (isIdentityConversion(originalValue.getType(), convertedTypes)) { - convertedValues.push_back(originalValue); - continue; - } - - // Non-identity conversion: materialize target types. - ValueRange castResult = - buildUnrealizedCast(rewriter, convertedTypes, originalValue, kind); - convertedValues.append(castResult.begin(), castResult.end()); - } - - return convertedValues; -} - -/// Builds one `UnrealizedConversionCastOp` for each sequence of the given -/// original values to one value of the type they originated from, i.e., a -/// "reverse" conversion from N converted values back to one value of the -/// original type, using the given (forward) type conversion. If a given value -/// was mapped to a value of the same type (i.e., the conversion in the mapping -/// is an identity conversion), then the "converted" value is returned without -/// cast. -/// Note that these unrealized casts are different from source materializations -/// in that they are *always* inserted, even if they immediately fold away, such -/// that patterns always see valid intermediate IR, whereas materializations are -/// only used in the places where the unrealized casts *don't* fold away. -static SmallVector<Value> -buildUnrealizedBackwardsCasts(ValueRange convertedValues, - const OneToNTypeMapping &typeConversion, - RewriterBase &rewriter) { - assert(typeConversion.getConvertedTypes() == convertedValues.getTypes()); - - // Create unrealized cast op for each converted result of the op. - SmallVector<Value> recastValues; - TypeRange originalTypes = typeConversion.getOriginalTypes(); - recastValues.reserve(originalTypes.size()); - auto convertedValueIt = convertedValues.begin(); - for (auto [idx, originalType] : llvm::enumerate(originalTypes)) { - TypeRange convertedTypes = typeConversion.getConvertedTypes(idx); - size_t numConvertedValues = convertedTypes.size(); - if (isIdentityConversion(originalType, convertedTypes)) { - // Identity conversion: take result as is. - recastValues.push_back(*convertedValueIt); - } else { - // Non-identity conversion: cast back to source type. - ValueRange recastValue = buildUnrealizedCast( - rewriter, originalType, - ValueRange{convertedValueIt, convertedValueIt + numConvertedValues}, - CastKind::Source); - assert(recastValue.size() == 1); - recastValues.push_back(recastValue.front()); - } - convertedValueIt += numConvertedValues; - } - - return recastValues; -} - -void OneToNPatternRewriter::replaceOp(Operation *op, ValueRange newValues, - const OneToNTypeMapping &resultMapping) { - // Create a cast back to the original types and replace the results of the - // original op with those. - assert(newValues.size() == resultMapping.getConvertedTypes().size()); - assert(op->getResultTypes() == resultMapping.getOriginalTypes()); - PatternRewriter::InsertionGuard g(*this); - setInsertionPointAfter(op); - SmallVector<Value> castResults = - buildUnrealizedBackwardsCasts(newValues, resultMapping, *this); - replaceOp(op, castResults); -} - -Block *OneToNPatternRewriter::applySignatureConversion( - Block *block, OneToNTypeMapping &argumentConversion) { - PatternRewriter::InsertionGuard g(*this); - - // Split the block at the beginning to get a new block to use for the - // updated signature. - SmallVector<Location> locs; - argumentConversion.convertLocations(block->getArguments(), locs); - Block *newBlock = - createBlock(block, argumentConversion.getConvertedTypes(), locs); - replaceAllUsesWith(block, newBlock); - - // Create necessary casts in new block. - SmallVector<Value> castResults; - for (auto [i, arg] : llvm::enumerate(block->getArguments())) { - TypeRange convertedTypes = argumentConversion.getConvertedTypes(i); - ValueRange newArgs = - argumentConversion.getConvertedValues(newBlock->getArguments(), i); - if (isIdentityConversion(arg.getType(), convertedTypes)) { - // Identity conversion: take argument as is. - assert(newArgs.size() == 1); - castResults.push_back(newArgs.front()); - } else { - // Non-identity conversion: cast the converted arguments to the original - // type. - PatternRewriter::InsertionGuard g(*this); - setInsertionPointToStart(newBlock); - ValueRange castResult = buildUnrealizedCast(*this, arg.getType(), newArgs, - CastKind::Argument); - assert(castResult.size() == 1); - castResults.push_back(castResult.front()); - } - } - - // Merge old block into new block such that we only have the latter with the - // new signature. - mergeBlocks(block, newBlock, castResults); - - return newBlock; -} - -LogicalResult -OneToNConversionPattern::matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const { - auto *typeConverter = getTypeConverter(); - - // Construct conversion mapping for results. - Operation::result_type_range originalResultTypes = op->getResultTypes(); - OneToNTypeMapping resultMapping(originalResultTypes); - if (failed(typeConverter->convertSignatureArgs(originalResultTypes, - resultMapping))) - return failure(); - - // Construct conversion mapping for operands. - Operation::operand_type_range originalOperandTypes = op->getOperandTypes(); - OneToNTypeMapping operandMapping(originalOperandTypes); - if (failed(typeConverter->convertSignatureArgs(originalOperandTypes, - operandMapping))) - return failure(); - - // Cast operands to target types. - SmallVector<Value> convertedOperands = buildUnrealizedForwardCasts( - op->getOperands(), operandMapping, rewriter, CastKind::Target); - - // Create a `OneToNPatternRewriter` for the pattern, which provides additional - // functionality. - // TODO(ingomueller): I guess it would be better to use only one rewriter - // throughout the whole pass, but that would require to - // drive the pattern application ourselves, which is a lot - // of additional boilerplate code. This seems to work fine, - // so I leave it like this for the time being. - OneToNPatternRewriter oneToNPatternRewriter(rewriter.getContext(), - rewriter.getListener()); - oneToNPatternRewriter.restoreInsertionPoint(rewriter.saveInsertionPoint()); - - // Apply actual pattern. - if (failed(matchAndRewrite(op, oneToNPatternRewriter, operandMapping, - resultMapping, convertedOperands))) - return failure(); - - return success(); -} - -namespace mlir { - -// This function applies the provided patterns using -// `applyPatternsGreedily` and then replaces all newly inserted -// `UnrealizedConversionCastOps` that haven't folded away. ("Backward" casts -// from target to source types inserted by a `OneToNConversionPattern` normally -// fold away with the "forward" casts from source to target types inserted by -// the next pattern.) To understand which casts are "newly inserted", all casts -// inserted by this pass are annotated with a string attribute that also -// documents which kind of the cast (source, argument, or target). -LogicalResult -applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, - const FrozenRewritePatternSet &patterns) { -#ifndef NDEBUG - // Remember existing unrealized casts. This data structure is only used in - // asserts; building it only for that purpose may be an overkill. - SmallSet<UnrealizedConversionCastOp, 4> existingCasts; - op->walk([&](UnrealizedConversionCastOp castOp) { - assert(!castOp->hasAttr(castKindAttrName)); - existingCasts.insert(castOp); - }); -#endif // NDEBUG - - // Apply provided conversion patterns. - if (failed(applyPatternsGreedily(op, patterns))) { - emitError(op->getLoc()) << "failed to apply conversion patterns"; - return failure(); - } - - // Find all unrealized casts inserted by the pass that haven't folded away. - SmallVector<UnrealizedConversionCastOp> worklist; - op->walk([&](UnrealizedConversionCastOp castOp) { - if (castOp->hasAttr(castKindAttrName)) { - assert(!existingCasts.contains(castOp)); - worklist.push_back(castOp); - } - }); - - // Replace new casts with user materializations. - IRRewriter rewriter(op->getContext()); - for (UnrealizedConversionCastOp castOp : worklist) { - TypeRange resultTypes = castOp->getResultTypes(); - ValueRange operands = castOp->getOperands(); - StringRef castKind = - castOp->getAttrOfType<StringAttr>(castKindAttrName).getValue(); - rewriter.setInsertionPoint(castOp); - -#ifndef NDEBUG - // Determine whether operands or results are already legal to test some - // assumptions for the different kind of materializations. These properties - // are only used it asserts and it may be overkill to compute them. - bool areOperandTypesLegal = llvm::all_of( - operands.getTypes(), [&](Type t) { return typeConverter.isLegal(t); }); - bool areResultsTypesLegal = llvm::all_of( - resultTypes, [&](Type t) { return typeConverter.isLegal(t); }); -#endif // NDEBUG - - // Add materialization and remember materialized results. - SmallVector<Value> materializedResults; - if (castKind == getCastKindName(CastKind::Target)) { - // Target materialization. - assert(!areOperandTypesLegal && areResultsTypesLegal && - operands.size() == 1 && "found unexpected target cast"); - materializedResults = typeConverter.materializeTargetConversion( - rewriter, castOp->getLoc(), resultTypes, operands.front()); - if (materializedResults.empty()) { - emitError(castOp->getLoc()) - << "failed to create target materialization"; - return failure(); - } - } else { - // Source and argument materializations. - assert(areOperandTypesLegal && !areResultsTypesLegal && - resultTypes.size() == 1 && "found unexpected cast"); - std::optional<Value> maybeResult; - if (castKind == getCastKindName(CastKind::Source)) { - // Source materialization. - maybeResult = typeConverter.materializeSourceConversion( - rewriter, castOp->getLoc(), resultTypes.front(), - castOp.getOperands()); - } else { - // Argument materialization. - assert(castKind == getCastKindName(CastKind::Argument) && - "unexpected value of cast kind attribute"); - assert(llvm::all_of(operands, llvm::IsaPred<BlockArgument>)); - maybeResult = typeConverter.materializeArgumentConversion( - rewriter, castOp->getLoc(), resultTypes.front(), - castOp.getOperands()); - } - if (!maybeResult.has_value() || !maybeResult.value()) { - emitError(castOp->getLoc()) - << "failed to create " << castKind << " materialization"; - return failure(); - } - materializedResults = {maybeResult.value()}; - } - - // Replace the cast with the result of the materialization. - rewriter.replaceOp(castOp, materializedResults); - } - - return success(); -} - -namespace { -class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern { -public: - FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName, - MLIRContext *ctx, - const TypeConverter &converter) - : OneToNConversionPattern(converter, functionLikeOpName, /*benefit=*/1, - ctx) {} - - LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter, - const OneToNTypeMapping &operandMapping, - const OneToNTypeMapping &resultMapping, - ValueRange convertedOperands) const override { - auto funcOp = cast<FunctionOpInterface>(op); - auto *typeConverter = getTypeConverter(); - - // Construct mapping for function arguments. - OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes()); - if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(), - argumentMapping))) - return failure(); - - // Construct mapping for function results. - OneToNTypeMapping funcResultMapping(funcOp.getResultTypes()); - if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(), - funcResultMapping))) - return failure(); - - // Nothing to do if the op doesn't have any non-identity conversions for its - // operands or results. - if (!argumentMapping.hasNonIdentityConversion() && - !funcResultMapping.hasNonIdentityConversion()) - return failure(); - - // Update the function signature in-place. - auto newType = FunctionType::get(rewriter.getContext(), - argumentMapping.getConvertedTypes(), - funcResultMapping.getConvertedTypes()); - rewriter.modifyOpInPlace(op, [&] { funcOp.setType(newType); }); - - // Update block signatures. - if (!funcOp.isExternal()) { - Region *region = &funcOp.getFunctionBody(); - Block *block = ®ion->front(); - rewriter.applySignatureConversion(block, argumentMapping); - } - - return success(); - } -}; -} // namespace - -void populateOneToNFunctionOpInterfaceTypeConversionPattern( - StringRef functionLikeOpName, const TypeConverter &converter, - RewritePatternSet &patterns) { - patterns.add<FunctionOpInterfaceSignatureConversion>( - functionLikeOpName, patterns.getContext(), converter); -} -} // namespace mlir diff --git a/mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir b/mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir deleted file mode 100644 index 611ec0265cd37b..00000000000000 --- a/mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir +++ /dev/null @@ -1,140 +0,0 @@ -// RUN: mlir-opt %s -split-input-file \ -// RUN: -test-one-to-n-type-conversion="convert-tuple-ops" \ -// RUN: | FileCheck --check-prefix=CHECK-TUP %s - -// RUN: mlir-opt %s -split-input-file \ -// RUN: -test-one-to-n-type-conversion="convert-func-ops" \ -// RUN: | FileCheck --check-prefix=CHECK-FUNC %s - -// RUN: mlir-opt %s -split-input-file \ -// RUN: -test-one-to-n-type-conversion="convert-func-ops convert-tuple-ops" \ -// RUN: | FileCheck --check-prefix=CHECK-BOTH %s - -// Test case: Matching nested packs and unpacks just disappear. - -// CHECK-TUP-LABEL: func.func @pack_unpack( -// CHECK-TUP-SAME: %[[ARG0:.*]]: i1, -// CHECK-TUP-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { -// CHECK-TUP-DAG: return %[[ARG0]], %[[ARG1]] : i1, i2 -func.func @pack_unpack(%arg0: i1, %arg1: i2) -> (i1, i2) { - %0 = "test.make_tuple"() : () -> tuple<> - %1 = "test.make_tuple"(%arg1) : (i2) -> tuple<i2> - %2 = "test.make_tuple"(%1) : (tuple<i2>) -> tuple<tuple<i2>> - %3 = "test.make_tuple"(%0, %arg0, %2) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>> - %4 = "test.get_tuple_element"(%3) {index = 0 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<> - %5 = "test.get_tuple_element"(%3) {index = 1 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1 - %6 = "test.get_tuple_element"(%3) {index = 2 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>> - %7 = "test.get_tuple_element"(%6) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2> - %8 = "test.get_tuple_element"(%7) {index = 0 : i32} : (tuple<i2>) -> i2 - return %5, %8 : i1, i2 -} - -// ----- - -// Test case: Appropriate materializations are created depending on which ops -// are converted. - -// If we only convert the tuple ops, the original `get_tuple_element` ops will -// disappear but one target materialization will be inserted from the -// unconverted function arguments to each of the return values (which have -// redundancy among themselves). -// -// CHECK-TUP-LABEL: func.func @materializations_tuple_args( -// CHECK-TUP-SAME: %[[ARG0:.*]]: tuple<tuple<>, i1, tuple<tuple<i2>>>) -> (i1, i2) { -// CHECK-TUP-DAG: %[[V0:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 0 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<> -// CHECK-TUP-DAG: %[[V1:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 1 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1 -// CHECK-TUP-DAG: %[[V2:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 2 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>> -// CHECK-TUP-DAG: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2> -// CHECK-TUP-DAG: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<i2>) -> i2 -// CHECK-TUP-DAG: %[[V5:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 0 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<> -// CHECK-TUP-DAG: %[[V6:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 1 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1 -// CHECK-TUP-DAG: %[[V7:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 2 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>> -// CHECK-TUP-DAG: %[[V8:.*]] = "test.get_tuple_element"(%[[V7]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2> -// CHECK-TUP-DAG: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<i2>) -> i2 -// CHECK-TUP-DAG: return %[[V1]], %[[V9]] : i1, i2 - -// If we only convert the func ops, argument materializations are created from -// the converted tuple elements back to the tuples that the `get_tuple_element` -// ops expect. -// -// CHECK-FUNC-LABEL: func.func @materializations_tuple_args( -// CHECK-FUNC-SAME: %[[ARG0:.*]]: i1, -// CHECK-FUNC-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { -// CHECK-FUNC-DAG: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<> -// CHECK-FUNC-DAG: %[[V1:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2> -// CHECK-FUNC-DAG: %[[V2:.*]] = "test.make_tuple"(%[[V1]]) : (tuple<i2>) -> tuple<tuple<i2>> -// CHECK-FUNC-DAG: %[[V3:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]], %[[V2]]) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>> -// CHECK-FUNC-DAG: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<> -// CHECK-FUNC-DAG: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1 -// CHECK-FUNC-DAG: %[[V6:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 2 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>> -// CHECK-FUNC-DAG: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2> -// CHECK-FUNC-DAG: %[[V8:.*]] = "test.get_tuple_element"(%[[V7]]) <{index = 0 : i32}> : (tuple<i2>) -> i2 -// CHECK-FUNC-DAG: return %[[V5]], %[[V8]] : i1, i2 - -// If we convert both tuple and func ops, basically everything disappears. -// -// CHECK-BOTH-LABEL: func.func @materializations_tuple_args( -// CHECK-BOTH-SAME: %[[ARG0:.*]]: i1, -// CHECK-BOTH-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { -// CHECK-BOTH-DAG: return %[[ARG0]], %[[ARG1]] : i1, i2 - -func.func @materializations_tuple_args(%arg0: tuple<tuple<>, i1, tuple<tuple<i2>>>) -> (i1, i2) { - %0 = "test.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<> - %1 = "test.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1 - %2 = "test.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>> - %3 = "test.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2> - %4 = "test.get_tuple_element"(%3) {index = 0 : i32} : (tuple<i2>) -> i2 - return %1, %4 : i1, i2 -} -// ----- - -// Test case: Appropriate materializations are created depending on which ops -// are converted. - -// If we only convert the tuple ops, the original `make_tuple` ops will -// disappear but a source materialization will be inserted from the result of -// conversion (which, for `make_tuple`, are the original ops that get forwarded) -// to the operands of the unconverted op with the original type (i.e., -// `return`). - -// CHECK-TUP-LABEL: func.func @materializations_tuple_return( -// CHECK-TUP-SAME: %[[ARG0:.*]]: i1, -// CHECK-TUP-SAME: %[[ARG1:.*]]: i2) -> tuple<tuple<>, i1, tuple<tuple<i2>>> { -// CHECK-TUP-DAG: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<> -// CHECK-TUP-DAG: %[[V1:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2> -// CHECK-TUP-DAG: %[[V2:.*]] = "test.make_tuple"(%[[V1]]) : (tuple<i2>) -> tuple<tuple<i2>> -// CHECK-TUP-DAG: %[[V3:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]], %[[V2]]) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>> -// CHECK-TUP-DAG: return %[[V3]] : tuple<tuple<>, i1, tuple<tuple<i2>>> - -// If we only convert the func ops, target materializations are created from -// original tuples produced by `make_tuple` to its constituent elements that the -// converted op (i.e., `return`) expect. -// -// CHECK-FUNC-LABEL: func.func @materializations_tuple_return( -// CHECK-FUNC-SAME: %[[ARG0:.*]]: i1, -// CHECK-FUNC-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { -// CHECK-FUNC-DAG: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<> -// CHECK-FUNC-DAG: %[[V1:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2> -// CHECK-FUNC-DAG: %[[V2:.*]] = "test.make_tuple"(%[[V1]]) : (tuple<i2>) -> tuple<tuple<i2>> -// CHECK-FUNC-DAG: %[[V3:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]], %[[V2]]) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>> -// CHECK-FUNC-DAG: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<> -// CHECK-FUNC-DAG: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1 -// CHECK-FUNC-DAG: %[[V6:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 2 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>> -// CHECK-FUNC-DAG: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2> -// CHECK-FUNC-DAG: %[[V8:.*]] = "test.get_tuple_element"(%[[V7]]) <{index = 0 : i32}> : (tuple<i2>) -> i2 -// CHECK-FUNC-DAG: return %[[V5]], %[[V8]] : i1, i2 - -// If we convert both tuple and func ops, basically everything disappears. -// -// CHECK-BOTH-LABEL: func.func @materializations_tuple_return( -// CHECK-BOTH-SAME: %[[ARG0:.*]]: i1, -// CHECK-BOTH-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { -// CHECK-BOTH-DAG: return %[[ARG0]], %[[ARG1]] : i1, i2 - -func.func @materializations_tuple_return(%arg0: i1, %arg1: i2) -> tuple<tuple<>, i1, tuple<tuple<i2>>> { - %0 = "test.make_tuple"() : () -> tuple<> - %1 = "test.make_tuple"(%arg1) : (i2) -> tuple<i2> - %2 = "test.make_tuple"(%1) : (tuple<i2>) -> tuple<tuple<i2>> - %3 = "test.make_tuple"(%0, %arg0, %2) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>> - return %3 : tuple<tuple<>, i1, tuple<tuple<i2>>> -} diff --git a/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir deleted file mode 100644 index 535ab68e8d893c..00000000000000 --- a/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir +++ /dev/null @@ -1,183 +0,0 @@ -// RUN: mlir-opt %s -split-input-file \ -// RUN: -test-one-to-n-type-conversion="convert-func-ops convert-scf-ops" \ -// RUN: | FileCheck %s - -// Test case: Nested 1:N type conversion is carried through scf.if and -// scf.yield. - -// CHECK-LABEL: func.func @if_result( -// CHECK-SAME: %[[ARG0:.*]]: i1, -// CHECK-SAME: %[[ARG1:.*]]: i2, -// CHECK-SAME: %[[ARG2:.*]]: i1) -> (i1, i2) { -// CHECK-NEXT: %[[V0:.*]]:2 = scf.if %[[ARG2]] -> (i1, i2) { -// CHECK-NEXT: scf.yield %[[ARG0]], %[[ARG1]] : i1, i2 -// CHECK-NEXT: } else { -// CHECK-NEXT: scf.yield %[[ARG0]], %[[ARG1]] : i1, i2 -// CHECK-NEXT: } -// CHECK-NEXT: return %[[V0]]#0, %[[V0]]#1 : i1, i2 -func.func @if_result(%arg0: tuple<tuple<>, i1, tuple<i2>>, %arg1: i1) -> tuple<tuple<>, i1, tuple<i2>> { - %0 = scf.if %arg1 -> (tuple<tuple<>, i1, tuple<i2>>) { - scf.yield %arg0 : tuple<tuple<>, i1, tuple<i2>> - } else { - scf.yield %arg0 : tuple<tuple<>, i1, tuple<i2>> - } - return %0 : tuple<tuple<>, i1, tuple<i2>> -} - -// ----- - -// Test case: Nested 1:N type conversion is carried through scf.if and -// scf.yield and unconverted ops inside have proper materializations. - -// CHECK-LABEL: func.func @if_tuple_ops( -// CHECK-SAME: %[[ARG0:.*]]: i1, -// CHECK-SAME: %[[ARG1:.*]]: i1) -> i1 { -// CHECK-NEXT: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<> -// CHECK-NEXT: %[[V1:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]]) : (tuple<>, i1) -> tuple<tuple<>, i1> -// CHECK-NEXT: %[[V2:.*]] = scf.if %[[ARG1]] -> (i1) { -// CHECK-NEXT: %[[V3:.*]] = "test.op"(%[[V1]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1> -// CHECK-NEXT: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<> -// CHECK-NEXT: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1 -// CHECK-NEXT: scf.yield %[[V5]] : i1 -// CHECK-NEXT: } else { -// CHECK-NEXT: %[[V6:.*]] = "test.source"() : () -> tuple<tuple<>, i1> -// CHECK-NEXT: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<> -// CHECK-NEXT: %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1 -// CHECK-NEXT: scf.yield %[[V8]] : i1 -// CHECK-NEXT: } -// CHECK-NEXT: return %[[V2]] : i1 -func.func @if_tuple_ops(%arg0: tuple<tuple<>, i1>, %arg1: i1) -> tuple<tuple<>, i1> { - %0 = scf.if %arg1 -> (tuple<tuple<>, i1>) { - %1 = "test.op"(%arg0) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1> - scf.yield %1 : tuple<tuple<>, i1> - } else { - %1 = "test.source"() : () -> tuple<tuple<>, i1> - scf.yield %1 : tuple<tuple<>, i1> - } - return %0 : tuple<tuple<>, i1> -} -// ----- - -// Test case: Nested 1:N type conversion is carried through scf.while, -// scf.condition, and scf.yield. - -// CHECK-LABEL: func.func @while_operands_results( -// CHECK-SAME: %[[ARG0:.*]]: i1, -// CHECK-SAME: %[[ARG1:.*]]: i2, -// CHECK-SAME: %[[ARG2:.*]]: i1) -> (i1, i2) { -// %[[V0:.*]]:2 = scf.while (%[[ARG3:.*]] = %[[ARG0]], %[[ARG4:.*]] = %[[ARG1]]) : (i1, i2) -> (i1, i2) { -// scf.condition(%arg2) %[[ARG3]], %[[ARG4]] : i1, i2 -// } do { -// ^bb0(%[[ARG5:.*]]: i1, %[[ARG6:.*]]: i2): -// scf.yield %[[ARG5]], %[[ARG4]] : i1, i2 -// } -// return %[[V0]]#0, %[[V0]]#1 : i1, i2 -func.func @while_operands_results(%arg0: tuple<tuple<>, i1, tuple<i2>>, %arg1: i1) -> tuple<tuple<>, i1, tuple<i2>> { - %0 = scf.while (%arg2 = %arg0) : (tuple<tuple<>, i1, tuple<i2>>) -> tuple<tuple<>, i1, tuple<i2>> { - scf.condition(%arg1) %arg2 : tuple<tuple<>, i1, tuple<i2>> - } do { - ^bb0(%arg2: tuple<tuple<>, i1, tuple<i2>>): - scf.yield %arg2 : tuple<tuple<>, i1, tuple<i2>> - } - return %0 : tuple<tuple<>, i1, tuple<i2>> -} - -// ----- - -// Test case: Nested 1:N type conversion is carried through scf.while, -// scf.condition, and unconverted ops inside have proper materializations. - -// CHECK-LABEL: func.func @while_tuple_ops( -// CHECK-SAME: %[[ARG0:.*]]: i1, -// CHECK-SAME: %[[ARG1:.*]]: i1) -> i1 { -// CHECK-NEXT: %[[V0:.*]] = scf.while (%[[ARG2:.*]] = %[[ARG0]]) : (i1) -> i1 { -// CHECK-NEXT: %[[V1:.*]] = "test.make_tuple"() : () -> tuple<> -// CHECK-NEXT: %[[V2:.*]] = "test.make_tuple"(%[[V1]], %[[ARG2]]) : (tuple<>, i1) -> tuple<tuple<>, i1> -// CHECK-NEXT: %[[V3:.*]] = "test.op"(%[[V2]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1> -// CHECK-NEXT: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<> -// CHECK-NEXT: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1 -// CHECK-NEXT: scf.condition(%[[ARG1]]) %[[V5]] : i1 -// CHECK-NEXT: } do { -// CHECK-NEXT: ^bb0(%[[ARG3:.*]]: i1): -// CHECK-NEXT: %[[V6:.*]] = "test.source"() : () -> tuple<tuple<>, i1> -// CHECK-NEXT: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<> -// CHECK-NEXT: %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1 -// CHECK-NEXT: scf.yield %[[V8]] : i1 -// CHECK-NEXT: } -// CHECK-NEXT: return %[[V0]] : i1 -func.func @while_tuple_ops(%arg0: tuple<tuple<>, i1>, %arg1: i1) -> tuple<tuple<>, i1> { - %0 = scf.while (%arg2 = %arg0) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1> { - %1 = "test.op"(%arg2) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1> - scf.condition(%arg1) %1 : tuple<tuple<>, i1> - } do { - ^bb0(%arg2: tuple<tuple<>, i1>): - %1 = "test.source"() : () -> tuple<tuple<>, i1> - scf.yield %1 : tuple<tuple<>, i1> - } - return %0 : tuple<tuple<>, i1> -} - -// ----- - -// Test case: Nested 1:N type conversion is carried through scf.for and scf.yield. - -// CHECK-LABEL: func.func @for_operands_results( -// CHECK-SAME: %[[ARG0:.*]]: i1, -// CHECK-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { -// CHECK-NEXT: %[[C0:.+]] = arith.constant 0 : index -// CHECK-NEXT: %[[C1:.+]] = arith.constant 1 : index -// CHECK-NEXT: %[[C10:.+]] = arith.constant 10 : index -// CHECK-NEXT: %[[OUT:.+]]:2 = scf.for %arg2 = %[[C0]] to %[[C10]] step %[[C1]] iter_args(%[[ITER0:.+]] = %[[ARG0]], %[[ITER1:.+]] = %[[ARG1]]) -> (i1, i2) { -// CHECK-NEXT: scf.yield %[[ITER0]], %[[ITER1]] : i1, i2 -// CHECK-NEXT: } -// CHECK-NEXT: return %[[OUT]]#0, %[[OUT]]#1 : i1, i2 - -func.func @for_operands_results(%arg0: tuple<tuple<>, i1, tuple<i2>>) -> tuple<tuple<>, i1, tuple<i2>> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c10 = arith.constant 10 : index - - %0 = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %arg0) -> tuple<tuple<>, i1, tuple<i2>> { - scf.yield %acc : tuple<tuple<>, i1, tuple<i2>> - } - - return %0 : tuple<tuple<>, i1, tuple<i2>> -} - -// ----- - -// Test case: Nested 1:N type conversion is carried through scf.for and scf.yield - -// CHECK-LABEL: func.func @for_tuple_ops( -// CHECK-SAME: %[[ARG0:.+]]: i1) -> i1 { -// CHECK-NEXT: %[[C0:.+]] = arith.constant 0 : index -// CHECK-NEXT: %[[C1:.+]] = arith.constant 1 : index -// CHECK-NEXT: %[[C10:.+]] = arith.constant 10 : index -// CHECK-NEXT: %[[FOR:.+]] = scf.for %arg1 = %[[C0]] to %[[C10]] step %[[C1]] iter_args(%[[ITER:.+]] = %[[ARG0]]) -> (i1) { -// CHECK-NEXT: %[[V1:.+]] = "test.make_tuple"() : () -> tuple<> -// CHECK-NEXT: %[[V2:.+]] = "test.make_tuple"(%[[V1]], %[[ITER]]) : (tuple<>, i1) -> tuple<tuple<>, i1> -// CHECK-NEXT: %[[V3:.+]] = "test.op"(%[[V2]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1> -// CHECK-NEXT: %[[V4:.+]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<> -// CHECK-NEXT: %[[V5:.+]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1 -// CHECK-NEXT: scf.yield %[[V5]] : i1 -// CHECK-NEXT: } -// CHECK-NEXT: %[[V6:.+]] = "test.make_tuple"() : () -> tuple<> -// CHECK-NEXT: %[[V7:.+]] = "test.make_tuple"(%[[V6]], %[[FOR]]) : (tuple<>, i1) -> tuple<tuple<>, i1> -// CHECK-NEXT: %[[V8:.+]] = "test.op"(%[[V7]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1> -// CHECK-NEXT: %[[V9:.+]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<> -// CHECK-NEXT: %[[V10:.+]] = "test.get_tuple_element"(%[[V8]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1 -// CHECK-NEXT: return %[[V10]] : i1 - -func.func @for_tuple_ops(%arg0: tuple<tuple<>, i1>) -> tuple<tuple<>, i1> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c10 = arith.constant 10 : index - - %0 = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %arg0) -> tuple<tuple<>, i1> { - %1 = "test.op"(%acc) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1> - scf.yield %1 : tuple<tuple<>, i1> - } - - %1 = "test.op"(%0) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1> - return %1 : tuple<tuple<>, i1> -} diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir index 4e641317ac2f3d..55d78d9fedebb3 100644 --- a/mlir/test/Transforms/decompose-call-graph-types.mlir +++ b/mlir/test/Transforms/decompose-call-graph-types.mlir @@ -1,19 +1,11 @@ // RUN: mlir-opt %s -split-input-file -test-decompose-call-graph-types | FileCheck %s -// RUN: mlir-opt %s -split-input-file \ -// RUN: -test-one-to-n-type-conversion="convert-func-ops" \ -// RUN: | FileCheck %s --check-prefix=CHECK-12N - // Test case: Most basic case of a 1:N decomposition, an identity function. // CHECK-LABEL: func @identity( // CHECK-SAME: %[[ARG0:.*]]: i1, // CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { // CHECK: return %[[ARG0]], %[[ARG1]] : i1, i32 -// CHECK-12N-LABEL: func @identity( -// CHECK-12N-SAME: %[[ARG0:.*]]: i1, -// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { -// CHECK-12N: return %[[ARG0]], %[[ARG1]] : i1, i32 func.func @identity(%arg0: tuple<i1, i32>) -> tuple<i1, i32> { return %arg0 : tuple<i1, i32> } @@ -25,9 +17,6 @@ func.func @identity(%arg0: tuple<i1, i32>) -> tuple<i1, i32> { // CHECK-LABEL: func @identity_1_to_1_no_materializations( // CHECK-SAME: %[[ARG0:.*]]: i1) -> i1 { // CHECK: return %[[ARG0]] : i1 -// CHECK-12N-LABEL: func @identity_1_to_1_no_materializations( -// CHECK-12N-SAME: %[[ARG0:.*]]: i1) -> i1 { -// CHECK-12N: return %[[ARG0]] : i1 func.func @identity_1_to_1_no_materializations(%arg0: tuple<i1>) -> tuple<i1> { return %arg0 : tuple<i1> } @@ -39,9 +28,6 @@ func.func @identity_1_to_1_no_materializations(%arg0: tuple<i1>) -> tuple<i1> { // CHECK-LABEL: func @recursive_decomposition( // CHECK-SAME: %[[ARG0:.*]]: i1) -> i1 { // CHECK: return %[[ARG0]] : i1 -// CHECK-12N-LABEL: func @recursive_decomposition( -// CHECK-12N-SAME: %[[ARG0:.*]]: i1) -> i1 { -// CHECK-12N: return %[[ARG0]] : i1 func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tuple<tuple<i1>>> { return %arg0 : tuple<tuple<tuple<i1>>> } @@ -54,10 +40,6 @@ func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tupl // CHECK-SAME: %[[ARG0:.*]]: i1, // CHECK-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { // CHECK: return %[[ARG0]], %[[ARG1]] : i1, i2 -// CHECK-12N-LABEL: func @mixed_recursive_decomposition( -// CHECK-12N-SAME: %[[ARG0:.*]]: i1, -// CHECK-12N-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { -// CHECK-12N: return %[[ARG0]], %[[ARG1]] : i1, i2 func.func @mixed_recursive_decomposition(%arg0: tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>> { return %arg0 : tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>> } @@ -67,7 +49,6 @@ func.func @mixed_recursive_decomposition(%arg0: tuple<tuple<>, tuple<i1>, tuple< // Test case: Check decomposition of calls. // CHECK-LABEL: func private @callee(i1, i32) -> (i1, i32) -// CHECK-12N-LABEL: func private @callee(i1, i32) -> (i1, i32) func.func private @callee(tuple<i1, i32>) -> tuple<i1, i32> // CHECK-LABEL: func @caller( @@ -75,11 +56,6 @@ func.func private @callee(tuple<i1, i32>) -> tuple<i1, i32> // CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { // CHECK: %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32) // CHECK: return %[[V0]]#0, %[[V0]]#1 : i1, i32 -// CHECK-12N-LABEL: func @caller( -// CHECK-12N-SAME: %[[ARG0:.*]]: i1, -// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { -// CHECK-12N: %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32) -// CHECK-12N: return %[[V0]]#0, %[[V0]]#1 : i1, i32 func.func @caller(%arg0: tuple<i1, i32>) -> tuple<i1, i32> { %0 = call @callee(%arg0) : (tuple<i1, i32>) -> tuple<i1, i32> return %0 : tuple<i1, i32> @@ -90,15 +66,11 @@ func.func @caller(%arg0: tuple<i1, i32>) -> tuple<i1, i32> { // Test case: Type that decomposes to nothing (that is, a 1:0 decomposition). // CHECK-LABEL: func private @callee() -// CHECK-12N-LABEL: func private @callee() func.func private @callee(tuple<>) -> tuple<> // CHECK-LABEL: func @caller() { // CHECK: call @callee() : () -> () // CHECK: return -// CHECK-12N-LABEL: func @caller() { -// CHECK-12N: call @callee() : () -> () -// CHECK-12N: return func.func @caller(%arg0: tuple<>) -> tuple<> { %0 = call @callee(%arg0) : (tuple<>) -> (tuple<>) return %0 : tuple<> @@ -114,11 +86,6 @@ func.func @caller(%arg0: tuple<>) -> tuple<> { // CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1 // CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32 // CHECK: return %[[RET0]], %[[RET1]] : i1, i32 -// CHECK-12N-LABEL: func @unconverted_op_result() -> (i1, i32) { -// CHECK-12N: %[[UNCONVERTED_VALUE:.*]] = "test.source"() : () -> tuple<i1, i32> -// CHECK-12N: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1 -// CHECK-12N: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32 -// CHECK-12N: return %[[RET0]], %[[RET1]] : i1, i32 func.func @unconverted_op_result() -> tuple<i1, i32> { %0 = "test.source"() : () -> (tuple<i1, i32>) return %0 : tuple<i1, i32> @@ -139,16 +106,6 @@ func.func @unconverted_op_result() -> tuple<i1, i32> { // CHECK: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) <{index = 1 : i32}> : (tuple<i1, tuple<i32>>) -> tuple<i32> // CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple<i32>) -> i32 // CHECK: return %[[V3]], %[[V5]] : i1, i32 -// CHECK-12N-LABEL: func @nested_unconverted_op_result( -// CHECK-12N-SAME: %[[ARG0:.*]]: i1, -// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { -// CHECK-12N: %[[V0:.*]] = "test.make_tuple"(%[[ARG1]]) : (i32) -> tuple<i32> -// CHECK-12N: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]], %[[V0]]) : (i1, tuple<i32>) -> tuple<i1, tuple<i32>> -// CHECK-12N: %[[V2:.*]] = "test.op"(%[[V1]]) : (tuple<i1, tuple<i32>>) -> tuple<i1, tuple<i32>> -// CHECK-12N: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) <{index = 0 : i32}> : (tuple<i1, tuple<i32>>) -> i1 -// CHECK-12N: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) <{index = 1 : i32}> : (tuple<i1, tuple<i32>>) -> tuple<i32> -// CHECK-12N: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple<i32>) -> i32 -// CHECK-12N: return %[[V3]], %[[V5]] : i1, i32 func.func @nested_unconverted_op_result(%arg: tuple<i1, tuple<i32>>) -> tuple<i1, tuple<i32>> { %0 = "test.op"(%arg) : (tuple<i1, tuple<i32>>) -> (tuple<i1, tuple<i32>>) return %0 : tuple<i1, tuple<i32>> @@ -160,7 +117,6 @@ func.func @nested_unconverted_op_result(%arg: tuple<i1, tuple<i32>>) -> tuple<i1 // This makes sure to test the cases if 1:0, 1:1, and 1:N decompositions. // CHECK-LABEL: func private @callee(i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) -// CHECK-12N-LABEL: func private @callee(i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) func.func private @callee(tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) // CHECK-LABEL: func @caller( @@ -172,15 +128,6 @@ func.func private @callee(tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tup // CHECK-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) { // CHECK: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) // CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6 -// CHECK-12N-LABEL: func @caller( -// CHECK-12N-SAME: %[[I1:.*]]: i1, -// CHECK-12N-SAME: %[[I2:.*]]: i2, -// CHECK-12N-SAME: %[[I3:.*]]: i3, -// CHECK-12N-SAME: %[[I4:.*]]: i4, -// CHECK-12N-SAME: %[[I5:.*]]: i5, -// CHECK-12N-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) { -// CHECK-12N: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) -// CHECK-12N: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6 func.func @caller(%arg0: tuple<>, %arg1: i1, %arg2: tuple<i2>, %arg3: i3, %arg4: tuple<i4, i5>, %arg5: i6) -> (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) { %0, %1, %2, %3, %4, %5 = call @callee(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) return %0, %1, %2, %3, %4, %5 : tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6 diff --git a/mlir/test/lib/Conversion/CMakeLists.txt b/mlir/test/lib/Conversion/CMakeLists.txt index 19975f671b081d..c09496be729be2 100644 --- a/mlir/test/lib/Conversion/CMakeLists.txt +++ b/mlir/test/lib/Conversion/CMakeLists.txt @@ -1,5 +1,4 @@ add_subdirectory(ConvertToSPIRV) add_subdirectory(FuncToLLVM) add_subdirectory(MathToVCIX) -add_subdirectory(OneToNTypeConversion) add_subdirectory(VectorToSPIRV) diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt deleted file mode 100644 index b72302202f72b0..00000000000000 --- a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt +++ /dev/null @@ -1,21 +0,0 @@ -add_mlir_library(MLIRTestOneToNTypeConversionPass - TestOneToNTypeConversionPass.cpp - - EXCLUDE_FROM_LIBMLIR - - LINK_LIBS PUBLIC - MLIRFuncDialect - MLIRFuncTransforms - MLIRIR - MLIRPass - MLIRSCFDialect - MLIRSCFTransforms - MLIRTestDialect - MLIRTransformUtils - ) - -target_include_directories(MLIRTestOneToNTypeConversionPass - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test - ${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test - ) diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp deleted file mode 100644 index b18dfd8bb22cb1..00000000000000 --- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp +++ /dev/null @@ -1,261 +0,0 @@ -//===- TestOneToNTypeConversionPass.cpp - Test pass 1:N type conv. utils --===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "TestDialect.h" -#include "TestOps.h" -#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" -#include "mlir/Dialect/SCF/Transforms/Patterns.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/OneToNTypeConversion.h" - -using namespace mlir; - -namespace { -/// Test pass that exercises the (poor-man's) 1:N type conversion mechanisms -/// in `applyPartialOneToNConversion` by converting built-in tuples to the -/// elements they consist of as well as some dummy ops operating on these -/// tuples. -struct TestOneToNTypeConversionPass - : public PassWrapper<TestOneToNTypeConversionPass, - OperationPass<ModuleOp>> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneToNTypeConversionPass) - - TestOneToNTypeConversionPass() = default; - TestOneToNTypeConversionPass(const TestOneToNTypeConversionPass &pass) - : PassWrapper(pass) {} - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert<test::TestDialect>(); - } - - StringRef getArgument() const final { - return "test-one-to-n-type-conversion"; - } - - StringRef getDescription() const final { - return "Test pass for 1:N type conversion"; - } - - Option<bool> convertFuncOps{*this, "convert-func-ops", - llvm::cl::desc("Enable conversion on func ops"), - llvm::cl::init(false)}; - - Option<bool> convertSCFOps{*this, "convert-scf-ops", - llvm::cl::desc("Enable conversion on scf ops"), - llvm::cl::init(false)}; - - Option<bool> convertTupleOps{*this, "convert-tuple-ops", - llvm::cl::desc("Enable conversion on tuple ops"), - llvm::cl::init(false)}; - - void runOnOperation() override; -}; - -} // namespace - -namespace mlir { -namespace test { -void registerTestOneToNTypeConversionPass() { - PassRegistration<TestOneToNTypeConversionPass>(); -} -} // namespace test -} // namespace mlir - -namespace { - -/// Test pattern on for the `make_tuple` op from the test dialect that converts -/// this kind of op into it's "decomposed" form, i.e., the elements of the tuple -/// that is being produced by `test.make_tuple`, which are really just the -/// operands of this op. -class ConvertMakeTupleOp - : public OneToNOpConversionPattern<::test::MakeTupleOp> { -public: - using OneToNOpConversionPattern< - ::test::MakeTupleOp>::OneToNOpConversionPattern; - - LogicalResult - matchAndRewrite(::test::MakeTupleOp op, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { - // Simply replace the current op with the converted operands. - rewriter.replaceOp(op, adaptor.getFlatOperands(), - adaptor.getResultMapping()); - return success(); - } -}; - -/// Test pattern on for the `get_tuple_element` op from the test dialect that -/// converts this kind of op into it's "decomposed" form, i.e., instead of -/// "physically" extracting one element from the tuple, we forward the one -/// element of the decomposed form that is being extracted (or the several -/// elements in case that element is a nested tuple). -class ConvertGetTupleElementOp - : public OneToNOpConversionPattern<::test::GetTupleElementOp> { -public: - using OneToNOpConversionPattern< - ::test::GetTupleElementOp>::OneToNOpConversionPattern; - - LogicalResult - matchAndRewrite(::test::GetTupleElementOp op, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { - // Construct mapping for tuple element types. - auto stateType = cast<TupleType>(op->getOperand(0).getType()); - TypeRange originalElementTypes = stateType.getTypes(); - OneToNTypeMapping elementMapping(originalElementTypes); - if (failed(typeConverter->convertSignatureArgs(originalElementTypes, - elementMapping))) - return failure(); - - // Compute converted operands corresponding to original input tuple. - assert(adaptor.getOperands().size() == 1 && - "expected 'get_tuple_element' to have one operand"); - ValueRange convertedTuple = adaptor.getOperands()[0]; - - // Got those converted operands that correspond to the index-th element ofq - // the original input tuple. - size_t index = op.getIndex(); - ValueRange extractedElement = - elementMapping.getConvertedValues(convertedTuple, index); - - rewriter.replaceOp(op, extractedElement, adaptor.getResultMapping()); - - return success(); - } -}; - -} // namespace - -static void -populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter, - RewritePatternSet &patterns) { - patterns.add< - // clang-format off - ConvertMakeTupleOp, - ConvertGetTupleElementOp - // clang-format on - >(typeConverter, patterns.getContext()); -} - -/// Creates a sequence of `test.get_tuple_element` ops for all elements of a -/// given tuple value. If some tuple elements are, in turn, tuples, the elements -/// of those are extracted recursively such that the returned values have the -/// same types as `resultTypes.getFlattenedTypes()`. -/// -/// This function has been copied (with small adaptions) from -/// TestDecomposeCallGraphTypes.cpp. -static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder, - TypeRange resultTypes, - ValueRange inputs, - Location loc) { - if (inputs.size() != 1) - return {}; - Value input = inputs.front(); - - TupleType inputType = dyn_cast<TupleType>(input.getType()); - if (!inputType) - return {}; - - SmallVector<Value> values; - for (auto [idx, elementType] : llvm::enumerate(inputType.getTypes())) { - Value element = builder.create<::test::GetTupleElementOp>( - loc, elementType, input, builder.getI32IntegerAttr(idx)); - if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) { - // Recurse if the current element is also a tuple. - SmallVector<Type> flatRecursiveTypes; - nestedTupleType.getFlattenedTypes(flatRecursiveTypes); - std::optional<SmallVector<Value>> resursiveValues = - buildGetTupleElementOps(builder, flatRecursiveTypes, element, loc); - if (!resursiveValues.has_value()) - return {}; - values.append(resursiveValues.value()); - } else { - values.push_back(element); - } - } - return values; -} - -/// Creates a `test.make_tuple` op out of the given inputs building a tuple of -/// type `resultType`. If that type is nested, each nested tuple is built -/// recursively with another `test.make_tuple` op. -/// -/// This function has been copied (with small adaptions) from -/// TestDecomposeCallGraphTypes.cpp. -static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType, - ValueRange inputs, Location loc) { - // Build one value for each element at this nesting level. - SmallVector<Value> elements; - elements.reserve(resultType.getTypes().size()); - ValueRange::iterator inputIt = inputs.begin(); - for (Type elementType : resultType.getTypes()) { - if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) { - // Determine how many input values are needed for the nested elements of - // the nested TupleType and advance inputIt by that number. - // TODO: We only need the *number* of nested types, not the types itself. - // Maybe it's worth adding a more efficient overload? - SmallVector<Type> nestedFlattenedTypes; - nestedTupleType.getFlattenedTypes(nestedFlattenedTypes); - size_t numNestedFlattenedTypes = nestedFlattenedTypes.size(); - ValueRange nestedFlattenedelements(inputIt, - inputIt + numNestedFlattenedTypes); - inputIt += numNestedFlattenedTypes; - - // Recurse on the values for the nested TupleType. - Value res = buildMakeTupleOp(builder, nestedTupleType, - nestedFlattenedelements, loc); - if (!res) - return Value(); - - // The tuple constructed by the conversion is the element value. - elements.push_back(res); - } else { - // Base case: take one input as is. - elements.push_back(*inputIt++); - } - } - - // Assemble the tuple from the elements. - return builder.create<::test::MakeTupleOp>(loc, resultType, elements); -} - -void TestOneToNTypeConversionPass::runOnOperation() { - ModuleOp module = getOperation(); - auto *context = &getContext(); - - // Assemble type converter. - TypeConverter typeConverter; - - typeConverter.addConversion([](Type type) { return type; }); - typeConverter.addConversion( - [](TupleType tupleType, SmallVectorImpl<Type> &types) { - tupleType.getFlattenedTypes(types); - return success(); - }); - - typeConverter.addArgumentMaterialization(buildMakeTupleOp); - typeConverter.addSourceMaterialization(buildMakeTupleOp); - typeConverter.addTargetMaterialization(buildGetTupleElementOps); - // Test the other target materialization variant that takes the original type - // as additional argument. This materialization function always fails. - typeConverter.addTargetMaterialization( - [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, - Location loc, Type originalType) -> SmallVector<Value> { return {}; }); - - // Assemble patterns. - RewritePatternSet patterns(context); - if (convertTupleOps) - populateDecomposeTuplesTestPatterns(typeConverter, patterns); - if (convertFuncOps) - populateFuncTypeConversionPatterns(typeConverter, patterns); - if (convertSCFOps) - scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns); - - // Run conversion. - if (failed(applyPartialOneToNConversion(module, typeConverter, - std::move(patterns)))) - return signalPassFailure(); -} diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index 3563d66fa9e798..670f13caa9fafb 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -40,7 +40,6 @@ if(MLIR_INCLUDE_TESTS) MLIRTestDialect MLIRTestDynDialect MLIRTestIR - MLIRTestOneToNTypeConversionPass MLIRTestPass MLIRTestReducer MLIRTestTransforms diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 960f7037a1b61f..3542d7898f32cb 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -132,7 +132,6 @@ void registerTestMeshSimplificationsPass(); void registerTestMultiBuffering(); void registerTestNextAccessPass(); void registerTestNVGPULowerings(); -void registerTestOneToNTypeConversionPass(); void registerTestOpaqueLoc(); void registerTestOpLoweringPasses(); void registerTestPadFusion(); @@ -271,7 +270,6 @@ void registerTestPasses() { mlir::test::registerTestMultiBuffering(); mlir::test::registerTestNextAccessPass(); mlir::test::registerTestNVGPULowerings(); - mlir::test::registerTestOneToNTypeConversionPass(); mlir::test::registerTestOpaqueLoc(); mlir::test::registerTestOpLoweringPasses(); mlir::test::registerTestPadFusion(); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index f1192d069fa5f5..c7a3ded63b48a0 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7986,7 +7986,6 @@ cc_library( "include/mlir/Transforms/GreedyPatternRewriteDriver.h", "include/mlir/Transforms/Inliner.h", "include/mlir/Transforms/LoopInvariantCodeMotionUtils.h", - "include/mlir/Transforms/OneToNTypeConversion.h", "include/mlir/Transforms/RegionUtils.h", "include/mlir/Transforms/WalkPatternRewriteDriver.h", ], @@ -9900,7 +9899,6 @@ cc_binary( "//mlir/test:TestMemRef", "//mlir/test:TestMesh", "//mlir/test:TestNVGPU", - "//mlir/test:TestOneToNTypeConversion", "//mlir/test:TestPDLL", "//mlir/test:TestPass", "//mlir/test:TestReducer", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel index 7d51a3829e9120..a010809274e4c2 100644 --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -631,23 +631,6 @@ cc_library( ], ) -cc_library( - name = "TestOneToNTypeConversion", - srcs = glob(["lib/Conversion/OneToNTypeConversion/*.cpp"]), - includes = ["lib/Dialect/Test"], - deps = [ - ":TestDialect", - "//mlir:FuncDialect", - "//mlir:FuncTransforms", - "//mlir:IR", - "//mlir:Pass", - "//mlir:SCFDialect", - "//mlir:SCFTransforms", - "//mlir:TransformUtils", - "//mlir:Transforms", - ], -) - cc_library( name = "TestVectorToSPIRV", srcs = glob(["lib/Conversion/VectorToSPIRV/*.cpp"]), _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits