https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/140268
>From d71c0b7f45582ece43016eb98367251e54e75280 Mon Sep 17 00:00:00 2001 From: Jean Perier <jper...@nvidia.com> Date: Fri, 16 May 2025 08:09:37 -0700 Subject: [PATCH 1/2] [flang] translate derived type array init to attribute if possible --- .../Optimizer/CodeGen/LLVMInsertChainFolder.h | 31 +++ .../include/flang/Optimizer/Dialect/FIROps.td | 5 + flang/lib/Optimizer/CodeGen/CMakeLists.txt | 1 + flang/lib/Optimizer/CodeGen/CodeGen.cpp | 51 +++-- .../CodeGen/LLVMInsertChainFolder.cpp | 204 ++++++++++++++++++ flang/lib/Optimizer/Dialect/FIROps.cpp | 15 ++ .../Fir/convert-and-fold-insert-on-range.fir | 33 +++ 7 files changed, 319 insertions(+), 21 deletions(-) create mode 100644 flang/include/flang/Optimizer/CodeGen/LLVMInsertChainFolder.h create mode 100644 flang/lib/Optimizer/CodeGen/LLVMInsertChainFolder.cpp create mode 100644 flang/test/Fir/convert-and-fold-insert-on-range.fir diff --git a/flang/include/flang/Optimizer/CodeGen/LLVMInsertChainFolder.h b/flang/include/flang/Optimizer/CodeGen/LLVMInsertChainFolder.h new file mode 100644 index 0000000000000..d577c4c0fa70b --- /dev/null +++ b/flang/include/flang/Optimizer/CodeGen/LLVMInsertChainFolder.h @@ -0,0 +1,31 @@ +//===-- LLVMInsertChainFolder.h -- insertvalue chain folder ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Helper to fold LLVM dialect llvm.insertvalue chain representing constants +// into an Attribute representation. +// This sits in Flang because it is incomplete and tailored for flang needs. 
+// +//===----------------------------------------------------------------------===// + +namespace mlir { +class Attribute; +class OpBuilder; +class Value; +} // namespace mlir + +namespace fir { + +/// Attempt to fold an llvm.insertvalue chain into an attribute representation +/// suitable as llvm.constant operand. The returned value will be a null pointer +/// if this is not an llvm.insertvalue result or if the chain is not a constant, +/// or cannot be represented as an Attribute. The operations are not deleted, +/// but some llvm.insertvalue value operands may be folded with the builder on +/// the way. +mlir::Attribute tryFoldingLLVMInsertChain(mlir::Value insertChainResult, +                                          mlir::OpBuilder &builder); +} // namespace fir diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 458b780806144..dc66885f776f0 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -2129,6 +2129,11 @@ def fir_InsertOnRangeOp : fir_OneResultOp<"insert_on_range", [NoMemoryEffect]> {      $seq `,` $val custom<CustomRangeSubscript>($coor) attr-dict `:` functional-type(operands, results)    }]; +  let extraClassDeclaration = [{ +    /// Is this insert_on_range inserting on all the values of the result type?
+ bool isFullRange(); + }]; + let hasVerifier = 1; } diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt index 04480bac552b7..980307db315d9 100644 --- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt +++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt @@ -3,6 +3,7 @@ add_flang_library(FIRCodeGen CodeGen.cpp CodeGenOpenMP.cpp FIROpPatterns.cpp + LLVMInsertChainFolder.cpp LowerRepackArrays.cpp PreCGRewrite.cpp TBAABuilder.cpp diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index ad9119ba4a031..ed76a77ced047 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -14,6 +14,7 @@ #include "flang/Optimizer/CodeGen/CodeGenOpenMP.h" #include "flang/Optimizer/CodeGen/FIROpPatterns.h" +#include "flang/Optimizer/CodeGen/LLVMInsertChainFolder.h" #include "flang/Optimizer/CodeGen/TypeConverter.h" #include "flang/Optimizer/Dialect/FIRAttr.h" #include "flang/Optimizer/Dialect/FIRCG/CGOps.h" @@ -2412,15 +2413,38 @@ struct InsertOnRangeOpConversion doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - llvm::SmallVector<std::int64_t> dims; - auto type = adaptor.getOperands()[0].getType(); + auto arrayType = adaptor.getSeq().getType(); // Iteratively extract the array dimensions from the type. + llvm::SmallVector<std::int64_t> dims; + mlir::Type type = arrayType; while (auto t = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(type)) { dims.push_back(t.getNumElements()); type = t.getElementType(); } + // Avoid generating long insert chain that are very slow to fold back + // (which is required in globals when later generating LLVM IR). Attempt to + // fold the inserted element value to an attribute and build an ArrayAttr + // for the resulting array. 
+ if (range.isFullRange()) { + if (mlir::Attribute cst = + fir::tryFoldingLLVMInsertChain(adaptor.getVal(), rewriter)) { + mlir::Attribute dimVal = cst; + for (auto dim : llvm::reverse(dims)) { + // Use std::vector in case the number of elements is big. + std::vector<mlir::Attribute> elements(dim, dimVal); + dimVal = mlir::ArrayAttr::get(range.getContext(), elements); + } + // Replace insert chain with constant. + rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(range, arrayType, + dimVal); + return mlir::success(); + } + } + + // The inserted value cannot be folded to an attribute, turn the + // insert_range into an llvm.insertvalue chain. llvm::SmallVector<std::int64_t> lBounds; llvm::SmallVector<std::int64_t> uBounds; @@ -2434,8 +2458,8 @@ struct InsertOnRangeOpConversion auto &subscripts = lBounds; auto loc = range.getLoc(); - mlir::Value lastOp = adaptor.getOperands()[0]; - mlir::Value insertVal = adaptor.getOperands()[1]; + mlir::Value lastOp = adaptor.getSeq(); + mlir::Value insertVal = adaptor.getVal(); while (subscripts != uBounds) { lastOp = rewriter.create<mlir::LLVM::InsertValueOp>( @@ -3131,7 +3155,7 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> { // initialization is on the full range. 
auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>(); for (auto insertOp : insertOnRangeOps) { - if (isFullRange(insertOp.getCoor(), insertOp.getType())) { + if (insertOp.isFullRange()) { auto seqTyAttr = convertType(insertOp.getType()); auto *op = insertOp.getVal().getDefiningOp(); auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op); @@ -3161,22 +3185,7 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> { return mlir::success(); } - bool isFullRange(mlir::DenseIntElementsAttr indexes, - fir::SequenceType seqTy) const { - auto extents = seqTy.getShape(); - if (indexes.size() / 2 != static_cast<int64_t>(extents.size())) - return false; - auto cur_index = indexes.value_begin<int64_t>(); - for (unsigned i = 0; i < indexes.size(); i += 2) { - if (*(cur_index++) != 0) - return false; - if (*(cur_index++) != extents[i / 2] - 1) - return false; - } - return true; - } - - // TODO: String comparaison should be avoided. Replace linkName with an + // TODO: String comparisons should be avoided. Replace linkName with an // enumeration. mlir::LLVM::Linkage convertLinkage(std::optional<llvm::StringRef> optLinkage) const { diff --git a/flang/lib/Optimizer/CodeGen/LLVMInsertChainFolder.cpp b/flang/lib/Optimizer/CodeGen/LLVMInsertChainFolder.cpp new file mode 100644 index 0000000000000..0fc8697b735cf --- /dev/null +++ b/flang/lib/Optimizer/CodeGen/LLVMInsertChainFolder.cpp @@ -0,0 +1,204 @@ +//===-- LLVMInsertChainFolder.cpp -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/CodeGen/LLVMInsertChainFolder.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "flang-insert-folder" + +#include <deque> + +namespace { +// Helper class to construct the attribute elements of an aggregate value being +// folded without creating a full mlir::Attribute representation for each step +// of the insert value chain, which would both be expensive in terms of +// compilation time and memory (since the intermediate Attribute would survive, +// unused, inside the mlir context). +class InsertChainBackwardFolder { +  // Type for the current value of an element of the aggregate value being +  // constructed by the insert chain. +  // At any point of the insert chain, the value of an element is either: +  //  - nullptr: not yet known, the insert has not yet been seen. +  //  - an mlir::Attribute: the element is fully defined. +  //  - a nested InsertChainBackwardFolder: the element is itself an aggregate +  //    and its sub-elements have been partially defined (inserts with multiple +  //    indices have been seen). + +  // The insertion folder assumes backward walk of the insert chain. Once an +  // element or sub-element has been defined, it is not overridden by new +  // insertions (last insert wins).
+ using InFlightValue = + llvm::PointerUnion<mlir::Attribute, InsertChainBackwardFolder *>; + +public: + InsertChainBackwardFolder( + mlir::Type type, std::deque<InsertChainBackwardFolder> *folderStorage) + : values(getNumElements(type), mlir::Attribute{}), + folderStorage{folderStorage}, type{type} {} + + /// Push + bool pushValue(mlir::Attribute val, llvm::ArrayRef<int64_t> at); + + mlir::Attribute finalize(mlir::Attribute defaultFieldValue); + +private: + static int64_t getNumElements(mlir::Type type) { + if (auto structTy = + llvm::dyn_cast_if_present<mlir::LLVM::LLVMStructType>(type)) + return structTy.getBody().size(); + if (auto arrayTy = + llvm::dyn_cast_if_present<mlir::LLVM::LLVMArrayType>(type)) + return arrayTy.getNumElements(); + return 0; + } + + static mlir::Type getSubElementType(mlir::Type type, int64_t field) { + if (auto arrayTy = + llvm::dyn_cast_if_present<mlir::LLVM::LLVMArrayType>(type)) + return arrayTy.getElementType(); + if (auto structTy = + llvm::dyn_cast_if_present<mlir::LLVM::LLVMStructType>(type)) + return structTy.getBody()[field]; + return {}; + } + + // Current element value of the aggregate value being built. + llvm::SmallVector<InFlightValue> values; + // std::deque is used to allocate storage for nested list and guarantee the + // stability of the InsertChainBackwardFolder* used as element value. + std::deque<InsertChainBackwardFolder> *folderStorage; + // Type of the aggregate value being built. + mlir::Type type; +}; +} // namespace + +// Helper to fold the value being inserted by an llvm.insert_value. +// This may call tryFoldingLLVMInsertChain if the value is an aggregate and +// was itself constructed by a different insert chain. 
+static mlir::Attribute getAttrIfConstant(mlir::Value val, + mlir::OpBuilder &rewriter) { + if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>()) + return cst.getValue(); + if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) + return fir::tryFoldingLLVMInsertChain(val, rewriter); + if (val.getDefiningOp<mlir::LLVM::ZeroOp>()) + return mlir::LLVM::ZeroAttr::get(val.getContext()); + if (val.getDefiningOp<mlir::LLVM::UndefOp>()) + return mlir::LLVM::UndefAttr::get(val.getContext()); + if (mlir::Operation *op = val.getDefiningOp()) { + unsigned resNum = llvm::cast<mlir::OpResult>(val).getResultNumber(); + llvm::SmallVector<mlir::Value> results; + if (mlir::succeeded(rewriter.tryFold(op, results)) && + results.size() > resNum) { + if (auto cst = results[resNum].getDefiningOp<mlir::LLVM::ConstantOp>()) + return cst.getValue(); + } + } + if (auto trunc = val.getDefiningOp<mlir::LLVM::TruncOp>()) + if (auto attr = getAttrIfConstant(trunc.getArg(), rewriter)) + if (auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(attr)) + return mlir::IntegerAttr::get(trunc.getType(), intAttr.getInt()); + LLVM_DEBUG(llvm::dbgs() << "cannot fold insert value operand: " << val + << "\n"); + return {}; +} + +mlir::Attribute +InsertChainBackwardFolder::finalize(mlir::Attribute defaultFieldValue) { + std::vector<mlir::Attribute> attrs; + attrs.reserve(values.size()); + for (InFlightValue &inFlight : values) { + if (!inFlight) { + attrs.push_back(defaultFieldValue); + } else if (auto attr = llvm::dyn_cast<mlir::Attribute>(inFlight)) { + attrs.push_back(attr); + } else { + auto *inFlightList = llvm::cast<InsertChainBackwardFolder *>(inFlight); + attrs.push_back(inFlightList->finalize(defaultFieldValue)); + } + } + return mlir::ArrayAttr::get(type.getContext(), attrs); +} + +bool InsertChainBackwardFolder::pushValue(mlir::Attribute val, + llvm::ArrayRef<int64_t> at) { + if (at.size() == 0 || at[0] >= static_cast<int64_t>(values.size())) + return false; + InFlightValue &inFlight 
= values[at[0]]; + if (!inFlight) { + if (at.size() == 1) { + inFlight = val; + return true; + } + // This is the first insert to a nested field. Create a + // InsertChainBackwardFolder for the current element value. + InsertChainBackwardFolder &inFlightList = folderStorage->emplace_back( + getSubElementType(type, at[0]), folderStorage); + inFlight = &inFlightList; + return inFlightList.pushValue(val, at.drop_front()); + } + // Keep last inserted value if already set. + if (llvm::isa<mlir::Attribute>(inFlight)) + return true; + auto *inFlightList = llvm::cast<InsertChainBackwardFolder *>(inFlight); + if (at.size() == 1) { + if (!llvm::isa<mlir::LLVM::ZeroAttr, mlir::LLVM::UndefAttr>(val)) { + LLVM_DEBUG(llvm::dbgs() + << "insert chain sub-element partially overwritten initial " + "value is not zero or undef\n"); + return false; + } + inFlight = inFlightList->finalize(val); + return true; + } + return inFlightList->pushValue(val, at.drop_front()); +} + +mlir::Attribute fir::tryFoldingLLVMInsertChain(mlir::Value val, + mlir::OpBuilder &rewriter) { + if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>()) + return cst.getValue(); + if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) { + LLVM_DEBUG(llvm::dbgs() << "trying to fold insert chain:" << val << "\n"); + if (auto structTy = + llvm::dyn_cast<mlir::LLVM::LLVMStructType>(insert.getType())) { + mlir::LLVM::InsertValueOp currentInsert = insert; + mlir::LLVM::InsertValueOp lastInsert; + std::deque<InsertChainBackwardFolder> folderStorage; + InsertChainBackwardFolder inFlightList(structTy, &folderStorage); + while (currentInsert) { + mlir::Attribute attr = + getAttrIfConstant(currentInsert.getValue(), rewriter); + if (!attr) + return {}; + if (!inFlightList.pushValue(attr, currentInsert.getPosition())) + return {}; + lastInsert = currentInsert; + currentInsert = currentInsert.getContainer() + .getDefiningOp<mlir::LLVM::InsertValueOp>(); + } + mlir::Attribute defaultVal; + if (lastInsert) { + if 
(lastInsert.getContainer().getDefiningOp<mlir::LLVM::ZeroOp>()) + defaultVal = mlir::LLVM::ZeroAttr::get(val.getContext()); + else if (lastInsert.getContainer().getDefiningOp<mlir::LLVM::UndefOp>()) + defaultVal = mlir::LLVM::UndefAttr::get(val.getContext()); + } + if (!defaultVal) { + LLVM_DEBUG(llvm::dbgs() + << "insert chain initial value is not Zero or Undef\n"); + return {}; + } + return inFlightList.finalize(defaultVal); + } + } + return {}; +} diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index d85b38c467857..e12af7782a578 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -2365,6 +2365,21 @@ llvm::LogicalResult fir::InsertOnRangeOp::verify() { return mlir::success(); } +bool fir::InsertOnRangeOp::isFullRange() { + auto extents = getType().getShape(); + mlir::DenseIntElementsAttr indexes = getCoor(); + if (indexes.size() / 2 != static_cast<int64_t>(extents.size())) + return false; + auto cur_index = indexes.value_begin<int64_t>(); + for (unsigned i = 0; i < indexes.size(); i += 2) { + if (*(cur_index++) != 0) + return false; + if (*(cur_index++) != extents[i / 2] - 1) + return false; + } + return true; +} + //===----------------------------------------------------------------------===// // InsertValueOp //===----------------------------------------------------------------------===// diff --git a/flang/test/Fir/convert-and-fold-insert-on-range.fir b/flang/test/Fir/convert-and-fold-insert-on-range.fir new file mode 100644 index 0000000000000..df18614d80b63 --- /dev/null +++ b/flang/test/Fir/convert-and-fold-insert-on-range.fir @@ -0,0 +1,33 @@ +// Test codegen of constant insert_on_range without symbol reference into mlir.constant. 
+// RUN: fir-opt --cg-rewrite --split-input-file --fir-to-llvm-ir %s | FileCheck %s + +module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr<270> = dense<32> : vector<4xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, i64 = dense<64> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, f80 = dense<128> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, i1 = dense<8> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f128 = dense<128> : vector<2xi64>, "dlti.endianness" = "little", "dlti.mangling_mode" = "e", "dlti.stack_alignment" = 128 : i64>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu"} { + fir.global @derived_array : !fir.array<2x!fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}>> { + %c0 = arith.constant 0 : index + %0 = fir.undefined !fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}> + %1 = fir.zero_bits !fir.heap<!fir.array<?xf64>> + %2 = fir.shape %c0 : (index) -> !fir.shape<1> + %3 = fir.embox %1(%2) : (!fir.heap<!fir.array<?xf64>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf64>>> + %4 = fir.insert_value %0, %3, ["comp", !fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}>] : (!fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}>, !fir.box<!fir.heap<!fir.array<?xf64>>>) -> !fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}> + %5 = fir.undefined !fir.array<2x!fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}>> + %6 = fir.insert_on_range %5, %4 from (0) to (1) : (!fir.array<2x!fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}>>, !fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}>) -> 
!fir.array<2x!fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}>> + fir.has_value %6 : !fir.array<2x!fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}>> + } +} + +//CHECK-LABEL: llvm.mlir.global external @derived_array() +//CHECK: %[[CST:.*]] = llvm.mlir.constant([ +//CHECK-SAME: [ +//CHECK-SAME: [#llvm.zero, 8, 20240719 : i32, 1 : i8, 28 : i8, 2 : i8, 0 : i8, +//CHECK-SAME: [ +//CHECK-SAME: [1, 0 : index, 8] +//CHECK-SAME: ] +//CHECK-SAME: ], +//CHECK-SAME: [ +//CHECK-SAME: [#llvm.zero, 8, 20240719 : i32, 1 : i8, 28 : i8, 2 : i8, 0 : i8, +//CHECK-SAME: [ +//CHECK-SAME: [1, 0 : index, 8] +//CHECK-SAME: ] +//CHECK-SAME: ]) : +//CHECK-SAME: !llvm.array<2 x struct<"sometype", (struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>)>> +//CHECK: llvm.return %[[CST]] : !llvm.array<2 x struct<"sometype", (struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>)>> >From 796a1e0269baf1c77ffabf47a8fa155356bc9096 Mon Sep 17 00:00:00 2001 From: Jean Perier <jper...@nvidia.com> Date: Mon, 19 May 2025 01:37:14 -0700 Subject: [PATCH 2/2] use map_to_vector and FailureOr --- .../Optimizer/CodeGen/LLVMInsertChainFolder.h | 7 ++- flang/lib/Optimizer/CodeGen/CodeGen.cpp | 7 +-- .../CodeGen/LLVMInsertChainFolder.cpp | 54 ++++++++++--------- 3 files changed, 39 insertions(+), 29 deletions(-) diff --git a/flang/include/flang/Optimizer/CodeGen/LLVMInsertChainFolder.h b/flang/include/flang/Optimizer/CodeGen/LLVMInsertChainFolder.h index d577c4c0fa70b..321bda91aa6fe 100644 --- a/flang/include/flang/Optimizer/CodeGen/LLVMInsertChainFolder.h +++ b/flang/include/flang/Optimizer/CodeGen/LLVMInsertChainFolder.h @@ -12,6 +12,8 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Support/LogicalResult.h" + namespace mlir { class Attribute; class OpBuilder; @@ -26,6 +28,7 @@ namespace fir { /// or cannot be represented as an Attribute. 
The operations are not deleted, /// but some llvm.insertvalue value operands may be folded with the builder on /// the way. -mlir::Attribute tryFoldingLLVMInsertChain(mlir::Value insertChainResult, - mlir::OpBuilder &builder); +llvm::FailureOr<mlir::Attribute> +tryFoldingLLVMInsertChain(mlir::Value insertChainResult, + mlir::OpBuilder &builder); } // namespace fir diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index ed76a77ced047..70c90fae34086 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -2428,9 +2428,10 @@ struct InsertOnRangeOpConversion // fold the inserted element value to an attribute and build an ArrayAttr // for the resulting array. if (range.isFullRange()) { - if (mlir::Attribute cst = - fir::tryFoldingLLVMInsertChain(adaptor.getVal(), rewriter)) { - mlir::Attribute dimVal = cst; + llvm::FailureOr<mlir::Attribute> cst = + fir::tryFoldingLLVMInsertChain(adaptor.getVal(), rewriter); + if (llvm::succeeded(cst)) { + mlir::Attribute dimVal = *cst; for (auto dim : llvm::reverse(dims)) { // Use std::vector in case the number of elements is big. std::vector<mlir::Attribute> elements(dim, dimVal); diff --git a/flang/lib/Optimizer/CodeGen/LLVMInsertChainFolder.cpp b/flang/lib/Optimizer/CodeGen/LLVMInsertChainFolder.cpp index 0fc8697b735cf..5b522f2647916 100644 --- a/flang/lib/Optimizer/CodeGen/LLVMInsertChainFolder.cpp +++ b/flang/lib/Optimizer/CodeGen/LLVMInsertChainFolder.cpp @@ -67,7 +67,7 @@ class InsertChainBackwardFolder { if (auto structTy = llvm::dyn_cast_if_present<mlir::LLVM::LLVMStructType>(type)) return structTy.getBody()[field]; - return {}; + return nullptr; } // Current element value of the aggregate value being built. @@ -83,12 +83,18 @@ class InsertChainBackwardFolder { // Helper to fold the value being inserted by an llvm.insert_value. 
// This may call tryFoldingLLVMInsertChain if the value is an aggregate and // was itself constructed by a different insert chain. +// Returns a nullptr Attribute if the value could not be folded. static mlir::Attribute getAttrIfConstant(mlir::Value val, mlir::OpBuilder &rewriter) { if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>()) return cst.getValue(); - if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) - return fir::tryFoldingLLVMInsertChain(val, rewriter); + if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) { + llvm::FailureOr<mlir::Attribute> attr = + fir::tryFoldingLLVMInsertChain(val, rewriter); + if (succeeded(attr)) + return *attr; + return nullptr; + } if (val.getDefiningOp<mlir::LLVM::ZeroOp>()) return mlir::LLVM::ZeroAttr::get(val.getContext()); if (val.getDefiningOp<mlir::LLVM::UndefOp>()) @@ -108,23 +114,20 @@ static mlir::Attribute getAttrIfConstant(mlir::Value val, return mlir::IntegerAttr::get(trunc.getType(), intAttr.getInt()); LLVM_DEBUG(llvm::dbgs() << "cannot fold insert value operand: " << val << "\n"); - return {}; + return nullptr; } mlir::Attribute InsertChainBackwardFolder::finalize(mlir::Attribute defaultFieldValue) { - std::vector<mlir::Attribute> attrs; - attrs.reserve(values.size()); - for (InFlightValue &inFlight : values) { - if (!inFlight) { - attrs.push_back(defaultFieldValue); - } else if (auto attr = llvm::dyn_cast<mlir::Attribute>(inFlight)) { - attrs.push_back(attr); - } else { - auto *inFlightList = llvm::cast<InsertChainBackwardFolder *>(inFlight); - attrs.push_back(inFlightList->finalize(defaultFieldValue)); - } - } + llvm::SmallVector<mlir::Attribute> attrs = llvm::map_to_vector( + values, [&](InFlightValue inFlight) -> mlir::Attribute { + if (!inFlight) + return defaultFieldValue; + if (auto attr = llvm::dyn_cast<mlir::Attribute>(inFlight)) + return attr; + return llvm::cast<InsertChainBackwardFolder *>(inFlight)->finalize( + defaultFieldValue); + }); return 
mlir::ArrayAttr::get(type.getContext(), attrs); } @@ -140,8 +143,11 @@ bool InsertChainBackwardFolder::pushValue(mlir::Attribute val, } // This is the first insert to a nested field. Create a // InsertChainBackwardFolder for the current element value. - InsertChainBackwardFolder &inFlightList = folderStorage->emplace_back( - getSubElementType(type, at[0]), folderStorage); + mlir::Type subType = getSubElementType(type, at[0]); + if (!subType) + return false; + InsertChainBackwardFolder &inFlightList = + folderStorage->emplace_back(subType, folderStorage); inFlight = &inFlightList; return inFlightList.pushValue(val, at.drop_front()); } @@ -162,8 +168,8 @@ bool InsertChainBackwardFolder::pushValue(mlir::Attribute val, return inFlightList->pushValue(val, at.drop_front()); } -mlir::Attribute fir::tryFoldingLLVMInsertChain(mlir::Value val, - mlir::OpBuilder &rewriter) { +llvm::FailureOr<mlir::Attribute> +fir::tryFoldingLLVMInsertChain(mlir::Value val, mlir::OpBuilder &rewriter) { if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>()) return cst.getValue(); if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) { @@ -178,9 +184,9 @@ mlir::Attribute fir::tryFoldingLLVMInsertChain(mlir::Value val, mlir::Attribute attr = getAttrIfConstant(currentInsert.getValue(), rewriter); if (!attr) - return {}; + return llvm::failure(); if (!inFlightList.pushValue(attr, currentInsert.getPosition())) - return {}; + return llvm::failure(); lastInsert = currentInsert; currentInsert = currentInsert.getContainer() .getDefiningOp<mlir::LLVM::InsertValueOp>(); @@ -195,10 +201,10 @@ mlir::Attribute fir::tryFoldingLLVMInsertChain(mlir::Value val, if (!defaultVal) { LLVM_DEBUG(llvm::dbgs() << "insert chain initial value is not Zero or Undef\n"); - return {}; + return llvm::failure(); } return inFlightList.finalize(defaultVal); } } - return {}; + return llvm::failure(); } _______________________________________________ llvm-branch-commits mailing list 
llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits