================ @@ -0,0 +1,204 @@ +//===-- LLVMInsertChainFolder.cpp -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/CodeGen/LLVMInsertChainFolder.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "flang-insert-folder" + +#include <deque> + +namespace { +// Helper class to construct the attribute elements of an aggregate value being +// folded without creating a full mlir::Attribute representation for each step +// of the insert value chain, which would both be expensive in terms of +// compilation time and memory (since the intermediate Attribute would survive, +// unused, inside the mlir context). +class InsertChainBackwardFolder { + // Type for the current value of an element of the aggregate value being + // constructed by the insert chain. + // At any point of the insert chain, the value of an element is either: + // - nullptr: not yet known, the insert has not yet been seen. + // - an mlir::Attribute: the element is fully defined. + // - a nested InsertChainBackwardFolder: the element is itself an aggregate + // and its sub-elements have been partially defined (insert with mutliple + // indices have been seen). + + // The insertion folder assumes backward walk of the insert chain. Once an + // element or sub-element has been defined, it is not overriden by new + // insertions (last insert wins). + using InFlightValue = + llvm::PointerUnion<mlir::Attribute, InsertChainBackwardFolder *>; + +public: + InsertChainBackwardFolder( + mlir::Type type, std::deque<InsertChainBackwardFolder> *folderStorage) + : values(getNumElements(type), mlir::Attribute{}), + folderStorage{folderStorage}, type{type} {} + + /// Push + bool pushValue(mlir::Attribute val, llvm::ArrayRef<int64_t> at); + + mlir::Attribute finalize(mlir::Attribute defaultFieldValue); + +private: + static int64_t getNumElements(mlir::Type type) { + if (auto structTy = + llvm::dyn_cast_if_present<mlir::LLVM::LLVMStructType>(type)) + return structTy.getBody().size(); + if (auto arrayTy = + llvm::dyn_cast_if_present<mlir::LLVM::LLVMArrayType>(type)) + return arrayTy.getNumElements(); + return 0; + } + + static mlir::Type getSubElementType(mlir::Type type, int64_t field) { + if (auto arrayTy = + llvm::dyn_cast_if_present<mlir::LLVM::LLVMArrayType>(type)) + return arrayTy.getElementType(); + if (auto structTy = + llvm::dyn_cast_if_present<mlir::LLVM::LLVMStructType>(type)) + return structTy.getBody()[field]; + return {}; + } + + // Current element value of the aggregate value being built. + llvm::SmallVector<InFlightValue> values; + // std::deque is used to allocate storage for nested list and guarantee the + // stability of the InsertChainBackwardFolder* used as element value. + std::deque<InsertChainBackwardFolder> *folderStorage; + // Type of the aggregate value being built. + mlir::Type type; +}; +} // namespace + +// Helper to fold the value being inserted by an llvm.insert_value. +// This may call tryFoldingLLVMInsertChain if the value is an aggregate and +// was itself constructed by a different insert chain. +static mlir::Attribute getAttrIfConstant(mlir::Value val, + mlir::OpBuilder &rewriter) { + if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>()) + return cst.getValue(); + if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) + return fir::tryFoldingLLVMInsertChain(val, rewriter); + if (val.getDefiningOp<mlir::LLVM::ZeroOp>()) + return mlir::LLVM::ZeroAttr::get(val.getContext()); + if (val.getDefiningOp<mlir::LLVM::UndefOp>()) + return mlir::LLVM::UndefAttr::get(val.getContext()); + if (mlir::Operation *op = val.getDefiningOp()) { + unsigned resNum = llvm::cast<mlir::OpResult>(val).getResultNumber(); + llvm::SmallVector<mlir::Value> results; + if (mlir::succeeded(rewriter.tryFold(op, results)) && + results.size() > resNum) { + if (auto cst = results[resNum].getDefiningOp<mlir::LLVM::ConstantOp>()) + return cst.getValue(); + } + } + if (auto trunc = val.getDefiningOp<mlir::LLVM::TruncOp>()) + if (auto attr = getAttrIfConstant(trunc.getArg(), rewriter)) + if (auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(attr)) + return mlir::IntegerAttr::get(trunc.getType(), intAttr.getInt()); + LLVM_DEBUG(llvm::dbgs() << "cannot fold insert value operand: " << val + << "\n"); + return {}; +} + +mlir::Attribute +InsertChainBackwardFolder::finalize(mlir::Attribute defaultFieldValue) { + std::vector<mlir::Attribute> attrs; + attrs.reserve(values.size()); + for (InFlightValue &inFlight : values) { + if (!inFlight) { + attrs.push_back(defaultFieldValue); + } else if (auto attr = llvm::dyn_cast<mlir::Attribute>(inFlight)) { + attrs.push_back(attr); + } else { + auto *inFlightList = llvm::cast<InsertChainBackwardFolder *>(inFlight); + attrs.push_back(inFlightList->finalize(defaultFieldValue)); + } + } + return mlir::ArrayAttr::get(type.getContext(), attrs); +} + +bool InsertChainBackwardFolder::pushValue(mlir::Attribute val, + llvm::ArrayRef<int64_t> at) { + if (at.size() == 0 || at[0] >= static_cast<int64_t>(values.size())) + return false; + InFlightValue &inFlight = values[at[0]]; + if (!inFlight) { + if (at.size() == 1) { + inFlight = val; + return true; + } + // This is the first insert to a nested field. Create a + // InsertChainBackwardFolder for the current element value. + InsertChainBackwardFolder &inFlightList = folderStorage->emplace_back( + getSubElementType(type, at[0]), folderStorage); + inFlight = &inFlightList; + return inFlightList.pushValue(val, at.drop_front()); + } + // Keep last inserted value if already set. + if (llvm::isa<mlir::Attribute>(inFlight)) + return true; + auto *inFlightList = llvm::cast<InsertChainBackwardFolder *>(inFlight); + if (at.size() == 1) { + if (!llvm::isa<mlir::LLVM::ZeroAttr, mlir::LLVM::UndefAttr>(val)) { + LLVM_DEBUG(llvm::dbgs() + << "insert chain sub-element partially overwritten initial " + "value is not zero or undef\n"); + return false; + } + inFlight = inFlightList->finalize(val); + return true; + } + return inFlightList->pushValue(val, at.drop_front()); +} + +mlir::Attribute fir::tryFoldingLLVMInsertChain(mlir::Value val, + mlir::OpBuilder &rewriter) { + if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>()) + return cst.getValue(); + if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) { + LLVM_DEBUG(llvm::dbgs() << "trying to fold insert chain:" << val << "\n"); + if (auto structTy = + llvm::dyn_cast<mlir::LLVM::LLVMStructType>(insert.getType())) { + mlir::LLVM::InsertValueOp currentInsert = insert; + mlir::LLVM::InsertValueOp lastInsert; + std::deque<InsertChainBackwardFolder> folderStorage; + InsertChainBackwardFolder inFlightList(structTy, &folderStorage); + while (currentInsert) { + mlir::Attribute attr = + getAttrIfConstant(currentInsert.getValue(), rewriter); + if (!attr) + return {}; + if (!inFlightList.pushValue(attr, currentInsert.getPosition())) + return {}; + lastInsert = currentInsert; + currentInsert = currentInsert.getContainer() + .getDefiningOp<mlir::LLVM::InsertValueOp>(); + } + mlir::Attribute defaultVal; + if (lastInsert) { + if (lastInsert.getContainer().getDefiningOp<mlir::LLVM::ZeroOp>()) + defaultVal = mlir::LLVM::ZeroAttr::get(val.getContext()); + else if (lastInsert.getContainer().getDefiningOp<mlir::LLVM::UndefOp>()) + defaultVal = mlir::LLVM::UndefAttr::get(val.getContext()); + } + if (!defaultVal) { + LLVM_DEBUG(llvm::dbgs() + << "insert chain initial value is not Zero or Undef\n"); + return {}; + } + return inFlightList.finalize(defaultVal); + } + } + return {}; ---------------- jeanPerier wrote:
Yes, `FailureOr<Attribute>` is clearer but it is heavier (mainly because you cannot do `if (FailureOr<T> x = ...)` which is a structured style I prefer when possible). I updated the API to use it. https://github.com/llvm/llvm-project/pull/140268 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits