llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang-codegen Author: Julian Brown (jtb20) <details> <summary>Changes</summary> This patch is an early outline of a rework of several mapping features, intended to support 'containing structures' in the middle of expressions as well as at the top level (https://github.com/llvm/llvm-project/issues/141042). Key ideas: - struct information is gathered in several pre-passes that run before `generateInfoForComponentList`. - "PartialStruct" is turned into a map (`PartialStructMap`), keyed on the "effective base" of each containing structure in a set of expressions in OpenMP 'map' clauses. - the reverse iterator over component lists (which visits the base decl, then walks out to the full expression) gets a new 'ComponentListRefPtrPteeIterator' adapter that (a) visits reference-type list components twice, first as the pointer and then as the pointee, and (b) provides a few useful utility methods. Current state: a couple of tests work up to a point, but I'm hitting runtime problems that will probably be helped by the in-progress patches to support ATTACH operations. This is obviously still full of debug code and not at all ready for review! Posting FYI. 
--- Patch is 78.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153672.diff 1 Files Affected: - (modified) clang/lib/CodeGen/CGOpenMPRuntime.cpp (+1281-211) ``````````diff diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp index f1698a0bec373..3587096f8c6ec 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -18,25 +18,32 @@ #include "CGRecordLayout.h" #include "CodeGenFunction.h" #include "TargetInfo.h" +#include "clang-c/Index.h" #include "clang/AST/APValue.h" +#include "clang/AST/ASTContext.h" #include "clang/AST/Attr.h" #include "clang/AST/Decl.h" +#include "clang/AST/ExprOpenMP.h" #include "clang/AST/OpenMPClause.h" #include "clang/AST/StmtOpenMP.h" #include "clang/AST/StmtVisitor.h" #include "clang/Basic/OpenMPKinds.h" #include "clang/Basic/SourceManager.h" +#include "clang/CodeGen/CodeGenABITypes.h" #include "clang/CodeGen/ConstantInitBuilder.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/FoldingSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Bitcode/BitcodeReader.h" +#include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Value.h" #include "llvm/Support/AtomicOrdering.h" +#include "llvm/Support/TypeSize.h" #include "llvm/Support/raw_ostream.h" #include <cassert> #include <cstdint> @@ -6755,6 +6762,51 @@ llvm::Value *CGOpenMPRuntime::emitNumThreadsForTargetDirective( return NumThreadsVal; } +class EffectiveBaseMapKey { + llvm::FoldingSetNodeID ID; + bool Indirect = false; + +public: + EffectiveBaseMapKey(ASTContext &Ctx, const Expr *EB, bool Ind) { + EB->Profile(ID, Ctx, /*Canonical=*/true); + Indirect = Ind; + } + + EffectiveBaseMapKey(llvm::FoldingSetNodeID FromID) : ID(FromID) { } + + llvm::FoldingSetNodeID getID() const { return ID; } + 
bool getIndirect() const { return Indirect; } +}; + +template <> struct llvm::DenseMapInfo<EffectiveBaseMapKey> { + + static EffectiveBaseMapKey getEmptyKey() { + llvm::FoldingSetNodeID ID; + ID.AddInteger(std::numeric_limits<unsigned>::max()); + return EffectiveBaseMapKey(ID); + } + + static EffectiveBaseMapKey getTombstoneKey() { + llvm::FoldingSetNodeID ID; + for (unsigned I = 0; I < sizeof(ID) / sizeof(unsigned); ++I) { + ID.AddInteger(std::numeric_limits<unsigned>::max()); + } + return EffectiveBaseMapKey(ID); + } + + static unsigned getHashValue(const EffectiveBaseMapKey &Val) { + auto ID = Val.getID(); + ID.AddBoolean(Val.getIndirect()); + return ID.ComputeHash(); + } + + static bool isEqual(const EffectiveBaseMapKey &LHS, + const EffectiveBaseMapKey &RHS) { + return LHS.getID() == RHS.getID() && + LHS.getIndirect() == RHS.getIndirect(); + } +}; + namespace { LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE(); @@ -6787,7 +6839,6 @@ class MappableExprsHandler { public: MappingExprInfo(const ValueDecl *MapDecl, const Expr *MapExpr = nullptr) : MapDecl(MapDecl), MapExpr(MapExpr) {} - const ValueDecl *getMapDecl() const { return MapDecl; } const Expr *getMapExpr() const { return MapExpr; } }; @@ -6825,22 +6876,50 @@ class MappableExprsHandler { } }; + struct MappableExprMetadata { + OMPClauseMappableExprCommon::MappableExprComponentListRef Components; + const MapData *MD = nullptr; + bool CompleteExpression = false; + Address Base = Address::invalid(); + Address Pointer = Address::invalid(); + + //MappableExprMetadata() {} + + /*MappableExprMetadata(OMPClauseMappableExprCommon::MappableExprComponentListRef Components, + const MapData *MD, bool CompleteExpression) + : Components(Components), MD(MD), + CompleteExpression(CompleteExpression) {} + + MappableExprMetadata(MappableExprMetadata &Other) { + + }*/ + }; + + using ExprComponentMap = llvm::MapVector<EffectiveBaseMapKey, MappableExprMetadata>; + /// Map between a struct and the its lowest & highest elements which have 
been /// mapped. /// [ValueDecl *] --> {LE(FieldIndex, Pointer), /// HE(FieldIndex, Pointer)} struct StructRangeInfoTy { - MapCombinedInfoTy PreliminaryMapData; + //MapCombinedInfoTy PreliminaryMapData; + const Expr *BaseExpr = nullptr; + ExprComponentMap ChildComponents; + unsigned MemberDepth = -1u; std::pair<unsigned /*FieldIndex*/, Address /*Pointer*/> LowestElem = { 0, Address::invalid()}; std::pair<unsigned /*FieldIndex*/, Address /*Pointer*/> HighestElem = { 0, Address::invalid()}; Address Base = Address::invalid(); + Address BaseAddr = Address::invalid(); Address LB = Address::invalid(); bool IsArraySection = false; - bool HasCompleteRecord = false; + const MapData *ContainingStructMap = nullptr; }; + // A map from effective base addresses to struct range info for that base. + using PartialStructMap = llvm::DenseMap<EffectiveBaseMapKey, StructRangeInfoTy>; + private: /// Kind that defines how a device pointer has to be returned. struct MapInfo { @@ -7108,21 +7187,25 @@ class MappableExprsHandler { Address BP, Address LB, bool IsNonContiguous, uint64_t DimSize) : CGF(CGF), CombinedInfo(CombinedInfo), Flags(Flags), MapDecl(MapDecl), - MapExpr(MapExpr), BP(BP), LB(LB), IsNonContiguous(IsNonContiguous), - DimSize(DimSize) {} + MapExpr(MapExpr), BP(BP), IsNonContiguous(IsNonContiguous), + DimSize(DimSize), LB(LB) {} void processField( const OMPClauseMappableExprCommon::MappableComponent &MC, + llvm::DenseMap<const FieldDecl *, uint64_t> &Layout, const FieldDecl *FD, llvm::function_ref<LValue(CodeGenFunction &, const MemberExpr *)> EmitMemberExprBase) { - const RecordDecl *RD = FD->getParent(); - const ASTRecordLayout &RL = CGF.getContext().getASTRecordLayout(RD); - uint64_t FieldOffset = RL.getFieldOffset(FD->getFieldIndex()); + uint64_t FieldOffset = CGF.getContext().toBits(CharUnits::fromQuantity(Layout[FD])); uint64_t FieldSize = CGF.getContext().getTypeSize(FD->getType().getCanonicalType()); Address ComponentLB = Address::invalid(); + fprintf(stderr, 
"Process field: "); + MC.getAssociatedExpression()->dumpPretty(CGF.getContext()); + fprintf(stderr, " offset: %d index: %d cursor: %d\n", (int) FieldOffset, (int) FD->getFieldIndex(), + (int) Cursor); + if (FD->getType()->isLValueReferenceType()) { const auto *ME = cast<MemberExpr>(MC.getAssociatedExpression()); LValue BaseLVal = EmitMemberExprBase(CGF, ME); @@ -7133,8 +7216,6 @@ class MappableExprsHandler { CGF.EmitOMPSharedLValue(MC.getAssociatedExpression()).getAddress(); } - if (!LastParent) - LastParent = RD; if (FD->getParent() == LastParent) { if (FD->getFieldIndex() != LastIndex + 1) copyUntilField(FD, ComponentLB); @@ -7156,13 +7237,11 @@ class MappableExprsHandler { copySizedChunk(LBPtr, Size); } - void copyUntilEnd(Address HB) { - if (LastParent) { - const ASTRecordLayout &RL = - CGF.getContext().getASTRecordLayout(LastParent); - if ((uint64_t)CGF.getContext().toBits(RL.getSize()) <= Cursor) - return; - } + void copyUntilEnd(Address HB, CharUnits TypeSize) { + fprintf(stderr, "copyUntilEnd: Cursor=%d TypeSize=%d\n", + (int) Cursor, (int) CGF.getContext().toBits(TypeSize)); + if ((uint64_t)CGF.getContext().toBits(TypeSize) <= Cursor) + return; llvm::Value *LBPtr = LB.emitRawPointer(CGF); llvm::Value *Size = CGF.Builder.CreatePtrDiff( CGF.Int8Ty, CGF.Builder.CreateConstGEP(HB, 1).emitRawPointer(CGF), @@ -7184,6 +7263,544 @@ class MappableExprsHandler { } }; + /// Given a MemberExpr \c ME, find the containing structure as understood by + /// OpenMP (OpenMP 6.0, "2 Glossary"). + const Expr *getEffectiveBase(const MemberExpr *ME, const MemberExpr **IME, bool &Ind) const { + const Expr *Base = ME->getBase()->IgnoreParenImpCasts(); + + Ind = false; + + /*fprintf(stderr, "getEffectiveBase, input="); + ME->dumpPretty(CGF.getContext()); + fprintf(stderr, "\nbase="); + Base->dumpPretty(CGF.getContext()); + fprintf(stderr, "\n");*/ + + // Strip off any outer "." 
member accesses first + while (const auto *MEB = dyn_cast<MemberExpr>(Base)) { + if (ME->isArrow() || Base->getType()->isReferenceType()) { + break; + } else { + ME = MEB; + if (IME) + *IME = ME; + Base = ME->getBase()->IgnoreParenImpCasts(); + } + } + + /*fprintf(stderr, "now base="); + Base->dumpPretty(CGF.getContext()); + fprintf(stderr, "\n");*/ + + if (ME->isArrow() || Base->getType()->isReferenceType()) { + Ind = true; + return Base; + } + + return indirectOnce(Base, Ind); + } + + // Iterate a component list from the base of an expression to the complete + // expression. Components which are references are visited twice: firstly + // as the pointer, second as the pointee. + struct ComponentListRefPtrPteeIterator { + using iterator_category = std::forward_iterator_tag; + using value_type = const OMPClauseMappableExprCommon::MappableComponent; + using difference_type = std::ptrdiff_t; + using pointer = const OMPClauseMappableExprCommon::MappableComponent*; + using reference = const OMPClauseMappableExprCommon::MappableComponent&; + + std::reverse_iterator<const OMPClauseMappableExprCommon::MappableComponent *> Pos; + bool RefPtee = false; + const ValueDecl *BaseDecl = nullptr; + // We repeat on references -- this is the position in the underlying list. 
+ unsigned ComponentPos = 0; + + ComponentListRefPtrPteeIterator(std::reverse_iterator<const OMPClauseMappableExprCommon::MappableComponent *> From) : Pos(From) + { } + + reference operator*() const { return *Pos; } + pointer operator->() { return &*Pos; } + + void setBaseDecl(const ValueDecl *Decl) { + BaseDecl = Decl; + } + + bool isRefPtee() { + return RefPtee; + } + + bool isRef() { + const OMPClauseMappableExprCommon::MappableComponent *Comp = &*Pos; + if (isa<MemberExpr>(Pos->getAssociatedExpression())) { + const ValueDecl *MapDecl = Comp->getAssociatedDeclaration(); + assert(MapDecl && "Expected associated declaration for member expr"); + return MapDecl->getType()->isLValueReferenceType(); + } + return false; + } + + bool isPointer(bool AllowDeref) { + const OMPClauseMappableExprCommon::MappableComponent *Comp = &*Pos; + const Expr *AE = Comp->getAssociatedExpression(); + const auto *OASE = dyn_cast<ArraySectionExpr>(AE); + bool IsPointer = + isa<OMPArrayShapingExpr>(AE) || + (OASE && ArraySectionExpr::getBaseOriginalType(OASE).getCanonicalType()->isAnyPointerType()) || + AE->getType()->isAnyPointerType(); + + if (AllowDeref) + return IsPointer; + else if (!IsPointer) + return false; + + if (const auto *UO = dyn_cast<UnaryOperator>(AE)) + return UO->getOpcode() != UO_Deref; + + return !isa<BinaryOperator>(AE); + } + + unsigned getComponentPos() { + return ComponentPos; + } + + ComponentListRefPtrPteeIterator& operator++() { + if (isRef()) { + if (!RefPtee) + RefPtee = true; + else { + RefPtee = false; + ++Pos; + } + } else { + RefPtee = false; + // This could skip to outermost MemberExprs (over "."). 
+ ++Pos; + ++ComponentPos; + /*while (auto ME = dyn_cast<MemberExpr>(Pos->getAssociatedExpression()->IgnoreParenImpCasts())) { + if (ME->isArrow() || ME->getBase()->getType()->isReferenceType()) + break; + else + ++Pos; + }*/ + } + return *this; + } + + ComponentListRefPtrPteeIterator operator++(int) { + ComponentListRefPtrPteeIterator tmp = *this; + ++(*this); + return tmp; + } + + friend bool operator==(const ComponentListRefPtrPteeIterator& a, + const ComponentListRefPtrPteeIterator& b) { + return a.Pos == b.Pos && a.RefPtee == b.RefPtee; + } + + friend bool operator!=(const ComponentListRefPtrPteeIterator &a, + const ComponentListRefPtrPteeIterator &b) { + return a.Pos != b.Pos || a.RefPtee != b.RefPtee; + } + }; + + bool exprsEqual(Expr *One, Expr *Two) const { + if (One == nullptr || Two == nullptr) + return One == Two; + + if (One->getStmtClass() != Two->getStmtClass()) + return false; + + llvm::FoldingSetNodeID ProfOne, ProfTwo; + One->Profile(ProfOne, CGF.getContext(), true); + Two->Profile(ProfTwo, CGF.getContext(), true); + + return ProfOne == ProfTwo; + } + + const Expr *indirectOnce(const Expr *E, bool &Ind) const { + Ind = false; + + // Treat (*foo).bar the same as foo->bar + if (const auto *UO = dyn_cast<UnaryOperator>(E)) { + if (UO->getOpcode() == UO_Deref) { + Ind = true; + return UO->getSubExpr()->IgnoreParenImpCasts(); + } + } + + // Treat foo[0].bar the same as foo->bar + while (const auto *ASE = dyn_cast<ArraySubscriptExpr>(E)) { + const Expr *ArrayBase = ASE->getBase()->IgnoreParenImpCasts(); + const Expr *Index = ASE->getIdx(); + Expr::EvalResult Result; + if (!Index->EvaluateAsInt(Result, CGF.getContext())) { + return E; + } + llvm::APSInt ConstIndex = Result.Val.getInt(); + if (ConstIndex == 0) { + Ind = true; + E = ArrayBase; + } + if (!E->getType()->isArrayType()) + break; + } + + // Treat foo[:1].bar & foo[0:1].bar the same as foo->bar + if (const auto *ASecE = dyn_cast<ArraySectionExpr>(E)) { + const Expr *ArrayBase = 
ASecE->getBase()->IgnoreParenImpCasts(); + const Expr *LB = ASecE->getLowerBound(); + const Expr *Len = ASecE->getLength(); + bool LBZero = false, LenOne = false; + Expr::EvalResult Result; + if (!LB) { + LBZero = true; + } else if (LB->EvaluateAsInt(Result, CGF.getContext())) { + llvm::APSInt ConstLB = Result.Val.getInt(); + if (ConstLB == 0) + LBZero = true; + } + if (Len && Len->EvaluateAsInt(Result, CGF.getContext())) { + llvm::APSInt ConstLen = Result.Val.getInt(); + if (ConstLen == 1) { + LenOne = true; + } + } + if (LBZero && LenOne) { + Ind = true; + return ArrayBase; + } + } + + return E; + } + + bool componentsEqual(const OMPClauseMappableExprCommon::MappableComponent &One, + const OMPClauseMappableExprCommon::MappableComponent &Two) const { + if (One.isNonContiguous() != Two.isNonContiguous()) + return false; + + ValueDecl *DeclOne = One.getAssociatedDeclaration(); + ValueDecl *DeclTwo = Two.getAssociatedDeclaration(); + + if (DeclOne == nullptr || DeclTwo == nullptr) + return DeclOne == DeclTwo; + + if (DeclOne->getCanonicalDecl() != DeclTwo->getCanonicalDecl()) + return false; + + return exprsEqual(One.getAssociatedExpression(), + Two.getAssociatedExpression()); + } + + bool hasMemberExpr(const OMPClauseMappableExprCommon::MappableExprComponentListRef Components) const { + return llvm::any_of(Components, [](const OMPClauseMappableExprCommon::MappableComponent &M) { + return isa<MemberExpr>(M.getAssociatedExpression()); + }); + } + + void gatherStructDataForComponentList(const MapData &MD, OpenMPMapClauseKind MapType, ArrayRef<OpenMPMapModifierKind> MapModifiers, + OMPClauseMappableExprCommon::MappableExprComponentListRef Components, + MapCombinedInfoTy &CombinedInfo, PartialStructMap &PartialStructs, + bool IsImplicit, const ValueDecl *Mapper = nullptr, + const ValueDecl *BaseDecl = nullptr, const Expr *MapExpr = nullptr) const { + // Scan the components from the base to the complete expression. 
+ auto CI = ComponentListRefPtrPteeIterator(Components.rbegin()); + auto CE = ComponentListRefPtrPteeIterator(Components.rend()); + auto I = CI; + + I.setBaseDecl(BaseDecl); + + Address BPP = Address::invalid(); + Address BP = Address::invalid(); + + const Expr *AssocExpr = I->getAssociatedExpression(); + /*const auto *AE = dyn_cast<ArraySubscriptExpr>(AssocExpr); + const auto *OASE = dyn_cast<ArraySectionExpr>(AssocExpr); + const auto *OAShE = dyn_cast<OMPArrayShapingExpr>(AssocExpr);*/ + + auto &&EmitMemberExprBase = [](CodeGenFunction &CGF, + const MemberExpr *E) { + const Expr *BaseExpr = E->getBase(); + // If this is s.x, emit s as an lvalue. If it is s->x, emit s as a + // scalar. + LValue BaseLV; + if (E->isArrow()) { + LValueBaseInfo BaseInfo; + TBAAAccessInfo TBAAInfo; + Address Addr = + CGF.EmitPointerWithAlignment(BaseExpr, &BaseInfo, &TBAAInfo); + QualType PtrTy = BaseExpr->getType()->getPointeeType(); + BaseLV = CGF.MakeAddrLValue(Addr, PtrTy, BaseInfo, TBAAInfo); + } else { + BaseLV = CGF.EmitOMPSharedLValue(BaseExpr); + } + return BaseLV; + }; + + if (!hasMemberExpr(Components)) + return; + + /* + if (isa<MemberExpr>(AssocExpr)) { + // The base is the 'this' pointer. The content of the pointer is going + // to be the base of the field being mapped. + BP = CGF.LoadCXXThisAddress(); + } else if ((AE && isa<CXXThisExpr>(AE->getBase()->IgnoreParenImpCasts())) || + (OASE && + isa<CXXThisExpr>(OASE->getBase()->IgnoreParenImpCasts()))) { + BP = CGF.EmitOMPSharedLValue(AssocExpr).getAddress(); + } else if (OAShE && + isa<CXXThisExpr>(OAShE->getBase()->IgnoreParenCasts())) { + BP = Address( + CGF.EmitScalarExpr(OAShE->getBase()), + CGF.ConvertTypeForMem(OAShE->getBase()->getType()->getPointeeType()), + CGF.getContext().getTypeAlignInChars(OAShE->getBase()->getType())); + } else {*/ + // The base is the reference to the variable. + // BP = &Var. 
+ fprintf(stderr, "Init new BP from "); + AssocExpr->dumpPretty(CGF.getContext()); + fprintf(stderr, "\n"); + BP = CGF.EmitOMPSharedLValue(AssocExpr).getAddress(); + BPP = Address::invalid(); + QualType Ty = CI->getAssociatedDeclaration()->getType().getNonReferenceType(); + if (Ty->isAnyPointerType()) { + BPP = BP; + BP = CGF.EmitLoadOfPointer(BP, Ty->castAs<PointerType>()); + } + //} + + bool IsNonContiguous = CombinedInfo.NonContigInfo.IsNonContiguous; + // Maybe this needs to be "number of indirections". We want something + // similar to topological sort, but simpler. + unsigned MemberDepth = 0; + + MemberExpr *FirstMemberExpr = nullptr; + MemberExpr *LastMemberExpr = nullptr; + + for (; I != CE; ++I) { + StructRangeInfoTy *PartialStruct = nullptr; + bool Indirected = false; + + //auto Next = std::next(I); + if (auto ME = dyn_cast<MemberExpr>(I->getAssociatedExpression()->IgnoreParenImpCasts())) { + /*if (!FirstMemberExpr) { + QualType Ty = CI->getAssociatedDeclaration()->getType().getNonReferenceType(); + if (Ty->isAnyPointerType()) { + BP = CGF.EmitLoadOfPointer(BP, Ty->castAs<PointerType>()); + Indirected = true; + } + FirstMemberExpr = ME; + }*/ + ++MemberDepth; + } + + // Peek at the next outer expression to see if it's a "." member access + if (auto ME = dyn_cast<MemberExpr>(I->getAssociatedExpression()->IgnoreParenImpCasts())) { + auto Next = std::next(I); + if (Next != CE) { + if (auto NME = dyn_cast<MemberExpr>(Next->getAssociatedExpression()->IgnoreParenImpCasts())) { + if (!NME->isArrow()) + continue; + } + } + ... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/153672 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits