https://github.com/NexMing created https://github.com/llvm/llvm-project/pull/168703
This patch begins a long-term effort to establish an incremental FIR to MLIR lowering path while preserving the existing FIR to LLVM pipeline. It introduces the core conversion infrastructure and demonstrates the approach with example conversions such as converting `fir.load` to `memref.load`. The new lowering can coexist with the current pipeline, enabling mixed-mode lowering without affecting existing behavior. Future patches will extend the coverage to more operations and types. >From 9186d8ed4c61f5277c08b57b34cbb120dfcacf28 Mon Sep 17 00:00:00 2001 From: yanming <[email protected]> Date: Thu, 13 Nov 2025 13:37:15 +0800 Subject: [PATCH 1/4] [FIR][Lowering] Add FIRToMLIR pass. --- .../include/flang/Optimizer/Support/InitFIR.h | 8 ++++- .../flang/Optimizer/Transforms/Passes.td | 11 +++++++ flang/lib/Optimizer/Transforms/CMakeLists.txt | 1 + .../Optimizer/Transforms/ConvertFIRToMLIR.cpp | 30 +++++++++++++++++++ 4 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 flang/lib/Optimizer/Transforms/ConvertFIRToMLIR.cpp diff --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h index 67e9287ddad4f..b90badf8ede0f 100644 --- a/flang/include/flang/Optimizer/Support/InitFIR.h +++ b/flang/include/flang/Optimizer/Support/InitFIR.h @@ -32,6 +32,8 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/OpenACC/Transforms/Passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -54,7 +56,8 @@ namespace fir::support { mlir::NVVM::NVVMDialect, mlir::gpu::GPUDialect, \ mlir::index::IndexDialect, mif::MIFDialect -#define FLANG_CODEGEN_DIALECT_LIST FIRCodeGenDialect, mlir::LLVM::LLVMDialect +#define FLANG_CODEGEN_DIALECT_LIST \ + FIRCodeGenDialect, mlir::memref::MemRefDialect, 
mlir::LLVM::LLVMDialect // The definitive list of dialects used by flang. #define FLANG_DIALECT_LIST \ @@ -129,6 +132,9 @@ inline void registerMLIRPassesForFortranTools() { mlir::affine::registerAffineLoopTilingPass(); mlir::affine::registerAffineDataCopyGenerationPass(); + mlir::registerMem2RegPass(); + mlir::memref::registerMemRefPasses(); + mlir::registerLowerAffinePass(); } diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td index bb2509b1747d5..0bf1537b2215c 100644 --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -87,6 +87,17 @@ def FIRToSCFPass : Pass<"fir-to-scf"> { ]; } +def ConvertFIRToMLIRPass : Pass<"fir-to-mlir", "mlir::ModuleOp"> { + let summary = "Convert the FIR dialect module to MLIR standard dialects."; + let description = [{ + Convert the FIR dialect module to MLIR standard dialects. + }]; + let dependentDialects = [ + "fir::FIROpsDialect", "fir::FIRCodeGenDialect", "mlir::scf::SCFDialect", + "mlir::memref::MemRefDialect", "mlir::affine::AffineDialect" + ]; +} + def AnnotateConstantOperands : Pass<"annotate-constant"> { let summary = "Annotate constant operands to all FIR operations"; let description = [{ diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt index 0388439f89a54..a6423b3dea5a9 100644 --- a/flang/lib/Optimizer/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt @@ -36,6 +36,7 @@ add_flang_library(FIRTransforms SimplifyFIROperations.cpp OptimizeArrayRepacking.cpp ConvertComplexPow.cpp + ConvertFIRToMLIR.cpp MIFOpConversion.cpp DEPENDS diff --git a/flang/lib/Optimizer/Transforms/ConvertFIRToMLIR.cpp b/flang/lib/Optimizer/Transforms/ConvertFIRToMLIR.cpp new file mode 100644 index 0000000000000..a24d011da50c9 --- /dev/null +++ b/flang/lib/Optimizer/Transforms/ConvertFIRToMLIR.cpp @@ -0,0 +1,30 @@ +//===-- ConvertFIRToMLIR.cpp 
----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Transforms/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" + +namespace fir { +#define GEN_PASS_DEF_CONVERTFIRTOMLIRPASS +#include "flang/Optimizer/Transforms/Passes.h.inc" +} // namespace fir + +namespace { +class ConvertFIRToMLIRPass + : public fir::impl::ConvertFIRToMLIRPassBase<ConvertFIRToMLIRPass> { +public: + void runOnOperation() override; +}; +} // namespace + +void ConvertFIRToMLIRPass::runOnOperation() { + // TODO: +} >From 1de4ae5f9d688710da4c6bb71e5c8271186de1a8 Mon Sep 17 00:00:00 2001 From: yanming <[email protected]> Date: Wed, 19 Nov 2025 16:30:04 +0800 Subject: [PATCH 2/4] [FIR][Lowering] Add a flag to select lowering through MLIR. 
--- clang/include/clang/Options/Options.td | 5 +++++ clang/lib/Driver/ToolChains/Flang.cpp | 1 + flang/include/flang/Lower/LoweringOptions.def | 3 +++ .../flang/Optimizer/Passes/Pipelines.h | 2 ++ flang/include/flang/Tools/CrossToolHelpers.h | 1 + flang/lib/Frontend/CompilerInvocation.cpp | 4 ++++ flang/lib/Frontend/FrontendActions.cpp | 1 + flang/lib/Optimizer/Passes/Pipelines.cpp | 20 ++++++++++++++----- flang/test/Driver/frontend-forwarding.f90 | 2 ++ 9 files changed, 34 insertions(+), 5 deletions(-) diff --git a/clang/include/clang/Options/Options.td b/clang/include/clang/Options/Options.td index cda11fdc94230..cd8409de8c5a9 100644 --- a/clang/include/clang/Options/Options.td +++ b/clang/include/clang/Options/Options.td @@ -7217,6 +7217,11 @@ def flang_deprecated_no_hlfir : Flag<["-"], "flang-deprecated-no-hlfir">, Flags<[HelpHidden]>, Visibility<[FlangOption, FC1Option]>, HelpText<"Do not use HLFIR lowering (deprecated)">; +def flang_experimental_lower_through_mlir + : Flag<["-"], "flang-experimental-lower-through-mlir">, + Flags<[HelpHidden]>, Visibility<[FlangOption, FC1Option]>, + HelpText<"Lower from FIR through MLIR to LLVM (experimental)">; + //===----------------------------------------------------------------------===// // FLangOption + CoreOption + NoXarchOption //===----------------------------------------------------------------------===// diff --git a/clang/lib/Driver/ToolChains/Flang.cpp b/clang/lib/Driver/ToolChains/Flang.cpp index 270904de544d6..e294ac59af73d 100644 --- a/clang/lib/Driver/ToolChains/Flang.cpp +++ b/clang/lib/Driver/ToolChains/Flang.cpp @@ -222,6 +222,7 @@ void Flang::addCodegenOptions(const ArgList &Args, {options::OPT_fdo_concurrent_to_openmp_EQ, options::OPT_flang_experimental_hlfir, options::OPT_flang_deprecated_no_hlfir, + options::OPT_flang_experimental_lower_through_mlir, options::OPT_fno_ppc_native_vec_elem_order, options::OPT_fppc_native_vec_elem_order, options::OPT_finit_global_zero, options::OPT_fno_init_global_zero, 
options::OPT_frepack_arrays, diff --git a/flang/include/flang/Lower/LoweringOptions.def b/flang/include/flang/Lower/LoweringOptions.def index 39f197d8d35c8..01fc96b78df50 100644 --- a/flang/include/flang/Lower/LoweringOptions.def +++ b/flang/include/flang/Lower/LoweringOptions.def @@ -38,6 +38,9 @@ ENUM_LOWERINGOPT(Underscoring, unsigned, 1, 1) /// (i.e. wraps around as two's complement). Off by default. ENUM_LOWERINGOPT(IntegerWrapAround, unsigned, 1, 0) +/// If true, lower from FIR through MLIR to LLVM +ENUM_LOWERINGOPT(LowerThroughMLIR, unsigned, 1, 0) + /// If true (default), follow Fortran 2003 rules for (re)allocating /// the allocatable on the left side of the intrinsic assignment, /// if LHS and RHS have mismatching shapes/types. diff --git a/flang/include/flang/Optimizer/Passes/Pipelines.h b/flang/include/flang/Optimizer/Passes/Pipelines.h index 70b9341347244..fc1ebaf7d24f7 100644 --- a/flang/include/flang/Optimizer/Passes/Pipelines.h +++ b/flang/include/flang/Optimizer/Passes/Pipelines.h @@ -18,10 +18,12 @@ #include "flang/Optimizer/Passes/CommandLineOpts.h" #include "flang/Optimizer/Transforms/Passes.h" #include "flang/Tools/CrossToolHelpers.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/OpenMP/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" diff --git a/flang/include/flang/Tools/CrossToolHelpers.h b/flang/include/flang/Tools/CrossToolHelpers.h index e964882ef6dac..0dcb99e1eb5b1 100644 --- a/flang/include/flang/Tools/CrossToolHelpers.h +++ b/flang/include/flang/Tools/CrossToolHelpers.h @@ -137,6 +137,7 @@ struct MLIRToLLVMPassPipelineConfig : public FlangEPCallBacks { bool 
EnableOpenMP = false; ///< Enable OpenMP lowering. bool EnableOpenMPSimd = false; ///< Enable OpenMP simd-only mode. bool SkipConvertComplexPow = false; ///< Do not run complex pow conversion. + bool LowerThroughMLIR = false; ///< Lower from FIR through MLIR to LLVM std::string InstrumentFunctionEntry = ""; ///< Name of the instrument-function that is called on each ///< function-entry diff --git a/flang/lib/Frontend/CompilerInvocation.cpp b/flang/lib/Frontend/CompilerInvocation.cpp index 893121fe01f27..8c3fde0a27153 100644 --- a/flang/lib/Frontend/CompilerInvocation.cpp +++ b/flang/lib/Frontend/CompilerInvocation.cpp @@ -1580,6 +1580,10 @@ bool CompilerInvocation::createFromArgs( invoc.loweringOpts.setLowerToHighLevelFIR(false); } + // -flang-experimental-lower-through-mlir + invoc.loweringOpts.setLowerThroughMLIR( + args.hasArg(clang::options::OPT_flang_experimental_lower_through_mlir)); + // -fno-ppc-native-vector-element-order if (args.hasArg(clang::options::OPT_fno_ppc_native_vec_elem_order)) { invoc.loweringOpts.setNoPPCNativeVecElemOrder(true); diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp index 159d08a2797b3..0cb241f209522 100644 --- a/flang/lib/Frontend/FrontendActions.cpp +++ b/flang/lib/Frontend/FrontendActions.cpp @@ -769,6 +769,7 @@ void CodeGenAction::generateLLVMIR() { config.NSWOnLoopVarInc = false; config.ComplexRange = opts.getComplexRange(); + config.LowerThroughMLIR = invoc.getLoweringOpts().getLowerThroughMLIR(); // Create the pass pipeline fir::createMLIRToLLVMPassPipeline(pm, config, getCurrentFile()); diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index 103e736accca0..6aa81a1b44c6b 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -109,6 +109,18 @@ void addDebugInfoPass(mlir::PassManager &pm, void addFIRToLLVMPass(mlir::PassManager &pm, const MLIRToLLVMPassPipelineConfig &config) { + if 
(disableFirToLlvmIr) + return; + + if (config.LowerThroughMLIR) { + pm.addPass(createConvertFIRToMLIRPass()); + pm.addPass(mlir::memref::createFoldMemRefAliasOpsPass()); + pm.addPass(mlir::createMem2Reg()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); + } + fir::FIRToLLVMPassOptions options; options.ignoreMissingTypeDescriptors = ignoreMissingTypeDescriptors; options.skipExternalRttiDefinition = skipExternalRttiDefinition; @@ -117,13 +129,11 @@ void addFIRToLLVMPass(mlir::PassManager &pm, options.typeDescriptorsRenamedForAssembly = !disableCompilerGeneratedNamesConversion; options.ComplexRange = config.ComplexRange; - addPassConditionally(pm, disableFirToLlvmIr, - [&]() { return fir::createFIRToLLVMPass(options); }); + pm.addPass(fir::createFIRToLLVMPass(options)); + // The dialect conversion framework may leave dead unrealized_conversion_cast // ops behind, so run reconcile-unrealized-casts to clean them up. - addPassConditionally(pm, disableFirToLlvmIr, [&]() { - return mlir::createReconcileUnrealizedCastsPass(); - }); + pm.addPass(mlir::createReconcileUnrealizedCastsPass()); } void addLLVMDialectToLLVMPass(mlir::PassManager &pm, diff --git a/flang/test/Driver/frontend-forwarding.f90 b/flang/test/Driver/frontend-forwarding.f90 index 952937168c95d..ab9e5e8b4d088 100644 --- a/flang/test/Driver/frontend-forwarding.f90 +++ b/flang/test/Driver/frontend-forwarding.f90 @@ -20,6 +20,7 @@ ! RUN: -fversion-loops-for-stride \ ! RUN: -flang-experimental-hlfir \ ! RUN: -flang-deprecated-no-hlfir \ +! RUN: -flang-experimental-lower-through-mlir \ ! RUN: -fno-ppc-native-vector-element-order \ ! RUN: -fppc-native-vector-element-order \ ! RUN: -mllvm -print-before-all \ @@ -51,6 +52,7 @@ ! CHECK: "-fversion-loops-for-stride" ! CHECK: "-flang-experimental-hlfir" ! CHECK: "-flang-deprecated-no-hlfir" +! CHECK: "-flang-experimental-lower-through-mlir" ! 
CHECK: "-fno-ppc-native-vector-element-order" ! CHECK: "-fppc-native-vector-element-order" ! CHECK: "-Rpass" >From c86bcbffee9f812ce2d27c18900ca2ac708e3c69 Mon Sep 17 00:00:00 2001 From: yanming <[email protected]> Date: Wed, 19 Nov 2025 18:27:39 +0800 Subject: [PATCH 3/4] [FIR][Lowering] Add lowering for `fir.convert` between `fir.ref` and `memref` type. --- flang/lib/Optimizer/CodeGen/CodeGen.cpp | 25 ++++++++++++++- flang/test/Fir/convert-to-llvm.fir | 41 +++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index f96d45d3f6b66..7959c846a2418 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -1008,6 +1008,29 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> { rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(convert, toTy, op0); return mlir::success(); } + // Pointer to MemRef conversion. + if (mlir::isa<mlir::MemRefType>(toFirTy)) { + auto dstMemRef = mlir::MemRefDescriptor::poison(rewriter, loc, toTy); + dstMemRef.setAlignedPtr(rewriter, loc, op0); + dstMemRef.setOffset( + rewriter, loc, + createIndexAttrConstant(rewriter, loc, getIndexType(), 0)); + rewriter.replaceOp(convert, {dstMemRef}); + return mlir::success(); + } + } else if (mlir::isa<mlir::MemRefType>(fromFirTy) && + mlir::isa<mlir::LLVM::LLVMPointerType>(toTy)) { + // MemRef to pointer conversion. 
+ auto srcMemRef = mlir::MemRefDescriptor(op0); + mlir::Type elementType = typeConverter->convertType( + mlir::cast<mlir::MemRefType>(fromFirTy).getElementType()); + mlir::Value srcBasePtr = srcMemRef.alignedPtr(rewriter, loc); + mlir::Value srcOffset = srcMemRef.offset(rewriter, loc); + mlir::Value srcPtr = + mlir::LLVM::GEPOp::create(rewriter, loc, srcBasePtr.getType(), + elementType, srcBasePtr, srcOffset); + rewriter.replaceOp(convert, srcPtr); + return mlir::success(); } return emitError(loc) << "cannot convert " << fromTy << " to " << toTy; } @@ -4326,7 +4349,7 @@ class FIRToLLVMLowering target.addLegalDialect<mlir::gpu::GPUDialect>(); // required NOPs for applying a full conversion - target.addLegalOp<mlir::ModuleOp>(); + target.addLegalOp<mlir::ModuleOp, mlir::UnrealizedConversionCastOp>(); // If we're on Windows, we might need to rename some libm calls. bool isMSVC = fir::getTargetTriple(mod).isOSMSVCRT(); diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir index 864368740be02..41c7ea992c29c 100644 --- a/flang/test/Fir/convert-to-llvm.fir +++ b/flang/test/Fir/convert-to-llvm.fir @@ -882,6 +882,47 @@ func.func @convert_record(%arg0 : !fir.type<_QMmod1Trec{i:i32,f:f64,c:!llvm.stru // ----- +// Test `fir.convert` operation conversion between `memref` and `fir.ref`. 
+ +func.func @convert_to_memref(%arg0 : !fir.ref<i32>) -> memref<i32, strided<[], offset: ?>>{ + %0 = fir.convert %arg0 : (!fir.ref<i32>) -> memref<i32, strided<[], offset: ?>> + return %0 : memref<i32, strided<[], offset: ?>> +} + +// CHECK-LABEL: llvm.func @convert_to_memref( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64)> { +// CHECK: %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[ARG0]], %[[MLIR_0]][1] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[MLIR_1:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[MLIR_1]], %[[INSERTVALUE_0]][2] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: llvm.return %[[INSERTVALUE_1]] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: } + +// ----- + +// Test `fir.convert` operation conversion between `memref` and `fir.ref`. + +func.func @convert_from_memref(%arg0 : memref<i32, strided<[], offset: ?>>) -> !fir.ref<i32> { + %0 = fir.convert %arg0 : (memref<i32, strided<[], offset: ?>>) -> !fir.ref<i32> + return %0 : !fir.ref<i32> +} + +// CHECK-LABEL: llvm.func @convert_from_memref( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr, +// CHECK-SAME: %[[ARG2:.*]]: i64) -> !llvm.ptr { +// CHECK: %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[ARG0]], %[[MLIR_0]][0] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[ARG1]], %[[INSERTVALUE_0]][1] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[INSERTVALUE_2:.*]] = llvm.insertvalue %[[ARG2]], %[[INSERTVALUE_1]][2] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[EXTRACTVALUE_0:.*]] = llvm.extractvalue %[[INSERTVALUE_2]][1] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[EXTRACTVALUE_1:.*]] = llvm.extractvalue %[[INSERTVALUE_2]][2] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[GETELEMENTPTR_0:.*]] = llvm.getelementptr
%[[EXTRACTVALUE_0]]{{\[}}%[[EXTRACTVALUE_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32 +// CHECK: llvm.return %[[GETELEMENTPTR_0]] : !llvm.ptr +// CHECK: } + +// ----- + // Test `fir.store` --> `llvm.store` conversion func.func @test_store_index(%val_to_store : index, %addr : !fir.ref<index>) { >From c9eea4bf0ae40c4a6d10cb261af14e3970f9ba6d Mon Sep 17 00:00:00 2001 From: yanming <[email protected]> Date: Wed, 19 Nov 2025 18:56:14 +0800 Subject: [PATCH 4/4] [FIR][Lowering] Add fir to mlir core dialect patterns. --- .../Optimizer/Transforms/ConvertFIRToMLIR.cpp | 198 +++++++++++++++++- flang/test/Fir/convert-to-mlir.fir | 135 ++++++++++++ 2 files changed, 332 insertions(+), 1 deletion(-) create mode 100644 flang/test/Fir/convert-to-mlir.fir diff --git a/flang/lib/Optimizer/Transforms/ConvertFIRToMLIR.cpp b/flang/lib/Optimizer/Transforms/ConvertFIRToMLIR.cpp index a24d011da50c9..bd4374e9a4aa2 100644 --- a/flang/lib/Optimizer/Transforms/ConvertFIRToMLIR.cpp +++ b/flang/lib/Optimizer/Transforms/ConvertFIRToMLIR.cpp @@ -6,11 +6,13 @@ // //===----------------------------------------------------------------------===// +#include "flang/Optimizer/Dialect/FIRCG/CGOps.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Transforms/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Transforms/DialectConversion.h" namespace fir { #define GEN_PASS_DEF_CONVERTFIRTOMLIRPASS @@ -23,8 +25,202 @@ class ConvertFIRToMLIRPass public: void runOnOperation() override; }; + +class FIRLoadOpLowering : public mlir::OpConversionPattern<fir::LoadOp> { +public: + using mlir::OpConversionPattern<fir::LoadOp>::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(fir::LoadOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + if (!mlir::MemRefType::isValidElementType(op.getType()) || !mlir::isa<mlir::MemRefType>(adaptor.getMemref().getType())) // Skip addresses the type converter left unconverted (e.g. fir.heap, fir.ptr). + return mlir::failure(); + +
rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, adaptor.getMemref(), + mlir::ValueRange{}); + return mlir::success(); + } +}; + +class FIRStoreOpLowering : public mlir::OpConversionPattern<fir::StoreOp> { +public: + using mlir::OpConversionPattern<fir::StoreOp>::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(fir::StoreOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + if (!mlir::MemRefType::isValidElementType(op.getValue().getType()) || !mlir::isa<mlir::MemRefType>(adaptor.getMemref().getType())) // Skip addresses the type converter left unconverted (e.g. fir.heap, fir.ptr). + return mlir::failure(); + + rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>( + op, adaptor.getValue(), adaptor.getMemref(), mlir::ValueRange{}); + return mlir::success(); + } +}; + +class FIRConvertOpLowering : public mlir::OpConversionPattern<fir::ConvertOp> { +public: + using mlir::OpConversionPattern<fir::ConvertOp>::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(fir::ConvertOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto srcVal = adaptor.getValue(); + auto srcType = srcVal.getType(); + auto dstType = getTypeConverter()->convertType(op.getType()); + + if (srcType == dstType) { + rewriter.replaceOp(op, mlir::ValueRange{srcVal}); + } else if (srcType.isSignlessIntOrIndex() && dstType.isSignlessIntOrIndex() && !srcType.isInteger(1)) { // arith casts need signless types; i1 would be sign-extended by extsi/index_cast, so leave it to FIR-to-LLVM. + if (srcType.isIndex() || dstType.isIndex()) { + rewriter.replaceOpWithNewOp<mlir::arith::IndexCastOp>(op, dstType, + srcVal); + } else if (srcType.getIntOrFloatBitWidth() < + dstType.getIntOrFloatBitWidth()) { + rewriter.replaceOpWithNewOp<mlir::arith::ExtSIOp>(op, dstType, srcVal); + } else { + rewriter.replaceOpWithNewOp<mlir::arith::TruncIOp>(op, dstType, srcVal); + } + } else if (srcType.isFloat() && dstType.isFloat() && srcType.getIntOrFloatBitWidth() != dstType.getIntOrFloatBitWidth()) { // extf/truncf require differing widths (e.g. bf16 <-> f16 is left to FIR-to-LLVM). + if (srcType.getIntOrFloatBitWidth() < dstType.getIntOrFloatBitWidth()) { + rewriter.replaceOpWithNewOp<mlir::arith::ExtFOp>(op, dstType, srcVal); + } else { + rewriter.replaceOpWithNewOp<mlir::arith::TruncFOp>(op, dstType, srcVal); + } + } else { + return mlir::failure(); + } + + return
mlir::success(); + } +}; + +class FIRAllocOpLowering : public mlir::OpConversionPattern<fir::AllocaOp> { +public: + using mlir::OpConversionPattern<fir::AllocaOp>::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(fir::AllocaOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + if (!mlir::MemRefType::isValidElementType(op.getAllocatedType()) || + op.hasLenParams()) + return mlir::failure(); + + auto dstType = getTypeConverter()->convertType(op.getType()); + auto allocaOp = mlir::memref::AllocaOp::create( + rewriter, op.getLoc(), + mlir::MemRefType::get({}, op.getAllocatedType())); + allocaOp->setAttrs(op->getAttrs()); + rewriter.replaceOpWithNewOp<mlir::memref::CastOp>(op, dstType, allocaOp); + return mlir::success(); + } +}; + +class FIRXArrayCoorOpLowering + : public mlir::OpConversionPattern<fir::cg::XArrayCoorOp> { +public: + using mlir::OpConversionPattern<fir::cg::XArrayCoorOp>::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(fir::cg::XArrayCoorOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + if (!mlir::isa<fir::ReferenceType>(op.getMemref().getType()) || !mlir::isa<mlir::MemRefType>(adaptor.getMemref().getType()) || + !op.getShift().empty() || !op.getSlice().empty() || !op.getSubcomponent().empty() || !llvm::all_of(adaptor.getShape(), [](mlir::Value v) { return v.getType().isIndex(); })) // The rewrite below subtracts 1 from each index, so it only handles default lower bounds, no slice/component, and index-typed extents. + return mlir::failure(); + mlir::Location loc = op.getLoc(); + auto metadata = mlir::memref::ExtractStridedMetadataOp::create( + rewriter, loc, adaptor.getMemref()); + auto base = metadata.getBaseBuffer(); + auto offset = metadata.getOffset(); + mlir::ValueRange shape = adaptor.getShape(); + unsigned rank = op.getRank(); + + assert(rank > 0 && "expected rank to be greater than zero"); + + auto sizes = llvm::to_vector_of<mlir::OpFoldResult>(llvm::reverse(shape)); + mlir::SmallVector<mlir::OpFoldResult> strides(rank); + + strides[rank - 1] = rewriter.getIndexAttr(1); + mlir::Value stride = mlir::arith::ConstantIndexOp::create(rewriter, loc, 1); + for (unsigned i = 1; i < rank; ++i) { + stride = mlir::arith::MulIOp::create(rewriter, loc, stride, shape[i - 1]); + strides[rank - 1 - i] = stride; + } + +
mlir::Value memref = mlir::memref::ReinterpretCastOp::create( + rewriter, loc, base, offset, sizes, strides); + + mlir::SmallVector<mlir::OpFoldResult> oneAttrs(rank, + rewriter.getIndexAttr(1)); + auto one = mlir::arith::ConstantIndexOp::create(rewriter, loc, 1); + auto offsets = llvm::map_to_vector( + llvm::reverse(adaptor.getIndices()), + [&](mlir::Value idx) -> mlir::OpFoldResult { + if (idx.getType().isInteger()) + idx = mlir::arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), idx); + + assert(idx.getType().isIndex() && "expected index type"); + idx = mlir::arith::SubIOp::create(rewriter, loc, idx, one); + return idx; + }); + + auto subview = mlir::memref::SubViewOp::create( + rewriter, loc, + mlir::cast<mlir::MemRefType>( + getTypeConverter()->convertType(op.getType())), + memref, offsets, oneAttrs, oneAttrs); + + rewriter.replaceOp(op, mlir::ValueRange{subview}); + return mlir::success(); + } +}; + } // namespace +static mlir::TypeConverter prepareTypeConverter() { + mlir::TypeConverter converter; + converter.addConversion([](mlir::Type ty) { return ty; }); + converter.addConversion([](fir::ReferenceType ty) -> std::optional<mlir::Type> { + auto eleTy = ty.getElementType(); + if (auto sequenceTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) + eleTy = sequenceTy.getElementType(); + if (!mlir::MemRefType::isValidElementType(eleTy)) return std::nullopt; // Not representable as a memref: fall back to the identity conversion above. + auto layout = mlir::StridedLayoutAttr::get(ty.getContext(), + mlir::ShapedType::kDynamic, {}); + return mlir::MemRefType::get({}, eleTy, layout); + }); + + // Use fir.convert as the bridge so that we don't need to pull in patterns for + // other dialects.
+ auto materializeProcedure = [](mlir::OpBuilder &builder, mlir::Type type, + mlir::ValueRange inputs, + mlir::Location loc) -> mlir::Value { + auto convertOp = fir::ConvertOp::create(builder, loc, type, inputs); + return convertOp; + }; + + converter.addSourceMaterialization(materializeProcedure); + converter.addTargetMaterialization(materializeProcedure); + return converter; +} + void ConvertFIRToMLIRPass::runOnOperation() { - // TODO: + mlir::MLIRContext *ctx = &getContext(); + mlir::ModuleOp theModule = getOperation(); + mlir::TypeConverter converter = prepareTypeConverter(); + mlir::RewritePatternSet patterns(ctx); + + patterns.add<FIRAllocOpLowering, FIRLoadOpLowering, FIRStoreOpLowering, + FIRConvertOpLowering, FIRXArrayCoorOpLowering>(converter, ctx); + + mlir::ConversionTarget target(getContext()); + + target.addLegalDialect<mlir::arith::ArithDialect, mlir::affine::AffineDialect, + mlir::memref::MemRefDialect, mlir::scf::SCFDialect>(); + + if (mlir::failed(mlir::applyPartialConversion(theModule, target, + std::move(patterns)))) { + signalPassFailure(); + } } diff --git a/flang/test/Fir/convert-to-mlir.fir b/flang/test/Fir/convert-to-mlir.fir new file mode 100644 index 0000000000000..3265349969e83 --- /dev/null +++ b/flang/test/Fir/convert-to-mlir.fir @@ -0,0 +1,135 @@ +// RUN: fir-opt --split-input-file --fir-to-mlir %s | FileCheck %s + +//=================================================== +// SUMMARY: Tests for FIR --> MLIR core dialects conversion +//=================================================== + +// Test `fir.load` --> `memref.load` conversion + +func.func @test_load_f32(%addr : !fir.ref<f32>) -> f32 { + %0 = fir.load %addr : !fir.ref<f32> + return %0 : f32 +} + +// CHECK-LABEL: func.func @test_load_f32( +// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<f32>) -> f32 { +// CHECK: %[[CONVERT_0:.*]] = fir.convert %[[ARG0]] : (!fir.ref<f32>) -> memref<f32, strided<[], offset: ?>> +// CHECK: %[[LOAD_0:.*]] = memref.load %[[CONVERT_0]][] : memref<f32, 
strided<[], offset: ?>> +// CHECK: return %[[LOAD_0]] : f32 +// CHECK: } + +// ----- + +// Test `fir.store` --> `memref.store` conversion + +func.func @test_store_f32(%val : f32, %addr : !fir.ref<f32>) { + fir.store %val to %addr : !fir.ref<f32> + return +} + +// CHECK-LABEL: func.func @test_store_f32( +// CHECK-SAME: %[[ARG0:.*]]: f32, +// CHECK-SAME: %[[ARG1:.*]]: !fir.ref<f32>) { +// CHECK: %[[CONVERT_0:.*]] = fir.convert %[[ARG1]] : (!fir.ref<f32>) -> memref<f32, strided<[], offset: ?>> +// CHECK: memref.store %[[ARG0]], %[[CONVERT_0]][] : memref<f32, strided<[], offset: ?>> +// CHECK: return +// CHECK: } + +// ----- + +// Test `fir.convert` operation conversion between Integer and Index types. + +func.func @convert_between_int_and_index(%arg0 : i32) -> i64 { + %0 = fir.convert %arg0 : (i32) -> index + %1 = fir.convert %0 : (index) -> i64 + return %1 : i64 +} + +// CHECK-LABEL: func.func @convert_between_int_and_index( +// CHECK-SAME: %[[ARG0:.*]]: i32) -> i64 { +// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[ARG0]] : i32 to index +// CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[INDEX_CAST_0]] : index to i64 +// CHECK: return %[[INDEX_CAST_1]] : i64 +// CHECK: } + +// ----- + +// Test `fir.convert` operation conversion between Integer types. + +func.func @convert_between_int(%arg0 : i32) -> i16 { + %0 = fir.convert %arg0 : (i32) -> i64 + %1 = fir.convert %0 : (i64) -> i16 + return %1 : i16 +} + +// CHECK-LABEL: func.func @convert_between_int( +// CHECK-SAME: %[[ARG0:.*]]: i32) -> i16 { +// CHECK: %[[EXTSI_0:.*]] = arith.extsi %[[ARG0]] : i32 to i64 +// CHECK: %[[TRUNCI_0:.*]] = arith.trunci %[[EXTSI_0]] : i64 to i16 +// CHECK: return %[[TRUNCI_0]] : i16 +// CHECK: } + +// ----- + +// Test `fir.convert` operation conversion between Float types.
+ +func.func @convert_between_fp(%arg0 : f32) -> f16 { + %0 = fir.convert %arg0 : (f32) -> f64 + %1 = fir.convert %0 : (f64) -> f16 + return %1 : f16 +} + +// CHECK-LABEL: func.func @convert_between_fp( +// CHECK-SAME: %[[ARG0:.*]]: f32) -> f16 { +// CHECK: %[[EXTF_0:.*]] = arith.extf %[[ARG0]] : f32 to f64 +// CHECK: %[[TRUNCF_0:.*]] = arith.truncf %[[EXTF_0]] : f64 to f16 +// CHECK: return %[[TRUNCF_0]] : f16 +// CHECK: } + +// ----- + +// Test `fir.alloca` --> `memref.alloca` conversion + +func.func @test_alloca_f32() -> !fir.ref<f32> { + %1 = fir.alloca f32 + return %1 : !fir.ref<f32> +} + +// CHECK-LABEL: func.func @test_alloca_f32() -> !fir.ref<f32> { +// CHECK: %[[ALLOCA_0:.*]] = memref.alloca() {in_type = f32} : memref<f32> +// CHECK: %[[CAST_0:.*]] = memref.cast %[[ALLOCA_0]] : memref<f32> to memref<f32, strided<[], offset: ?>> +// CHECK: %[[CONVERT_0:.*]] = fir.convert %[[CAST_0]] : (memref<f32, strided<[], offset: ?>>) -> !fir.ref<f32> +// CHECK: return %[[CONVERT_0]] : !fir.ref<f32> +// CHECK: } + +// ----- + +// Test `fircg.ext_array_coor` conversion. 
+ +func.func @test_ext_array_coor(%arg0: !fir.ref<!fir.array<100x200xf32>>, %i : i64, %j : i64) -> !fir.ref<f32> { + %c200 = arith.constant 200 : index + %c100 = arith.constant 100 : index + %0 = fircg.ext_array_coor %arg0(%c100, %c200)<%i, %j> : (!fir.ref<!fir.array<100x200xf32>>, index, index, i64, i64) -> !fir.ref<f32> + return %0 : !fir.ref<f32> +} + +// CHECK-LABEL: func.func @test_ext_array_coor( +// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100x200xf32>>, +// CHECK-SAME: %[[ARG1:.*]]: i64, +// CHECK-SAME: %[[ARG2:.*]]: i64) -> !fir.ref<f32> { +// CHECK: %[[CONVERT_0:.*]] = fir.convert %[[ARG0]] : (!fir.ref<!fir.array<100x200xf32>>) -> memref<f32, strided<[], offset: ?>> +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 200 : index +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 100 : index +// CHECK: %[[VAL_0:.*]], %[[EXTRACT_STRIDED_METADATA_0:.*]] = memref.extract_strided_metadata %[[CONVERT_0]] : memref<f32, strided<[], offset: ?>> -> memref<f32>, index +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 1 : index +// CHECK: %[[MULI_0:.*]] = arith.muli %[[CONSTANT_2]], %[[CONSTANT_1]] : index +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[EXTRACT_STRIDED_METADATA_0]]], sizes: {{\[}}%[[CONSTANT_0]], %[[CONSTANT_1]]], strides: {{\[}}%[[MULI_0]], 1] : memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>> +// CHECK: %[[CONSTANT_3:.*]] = arith.constant 1 : index +// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[ARG2]] : i64 to index +// CHECK: %[[SUBI_0:.*]] = arith.subi %[[INDEX_CAST_0]], %[[CONSTANT_3]] : index +// CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[ARG1]] : i64 to index +// CHECK: %[[SUBI_1:.*]] = arith.subi %[[INDEX_CAST_1]], %[[CONSTANT_3]] : index +// CHECK: %[[SUBVIEW_0:.*]] = memref.subview %[[REINTERPRET_CAST_0]]{{\[}}%[[SUBI_0]], %[[SUBI_1]]] [1, 1] [1, 1] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<f32, strided<[], offset: ?>> +// CHECK: %[[CONVERT_1:.*]] = fir.convert 
%[[SUBVIEW_0]] : (memref<f32, strided<[], offset: ?>>) -> !fir.ref<f32> +// CHECK: return %[[CONVERT_1]] : !fir.ref<f32> +// CHECK: } _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
