================
@@ -2592,6 +2595,281 @@ static RValue
EmitHipStdParUnsupportedBuiltin(CodeGenFunction *CGF,
return RValue::get(CGF->Builder.CreateCall(UBF, Args));
}
+namespace {
+
+// PaddingClearer is a utility class that clears the padding bits of a
+// C++ type. It traverses the type recursively, collecting the occupied
+// bit intervals, and then computes the padding intervals from them.
+// Finally, it clears the padding bits by writing zeros to the padding
+// intervals byte by byte. If a byte contains only some padding bits,
+// it writes zeros to just those bits. This is the case for bit-fields.
+struct PaddingClearer {
+ // Binds the clearer to the function being emitted and caches the character
+ // width (in bits) reported by that function's AST context.
+ PaddingClearer(CodeGenFunction &F)
+     : CGF(F),
+       CharWidth(F.getContext().getCharWidth()) {}
+
+ // Emits the stores that zero the padding of a `Ty` object addressed by `Ptr`.
+ //
+ // The type is walked with an explicit LIFO work list: each visited item
+ // either records an occupied bit interval or enqueues its sub-objects.
+ // The occupied intervals are then merged, the gaps between them (up to the
+ // full size of `Ty`) are computed, and each gap is cleared.
+ void run(Value *Ptr, QualType Ty) {
+   // Reset per-run state so results of an earlier call cannot leak in.
+   OccuppiedIntervals.clear();
+   Queue.clear();
+
+   // Seed with the whole object at bit offset 0, virtual bases included.
+   Queue.push_back(Data{0, Ty, true});
+   do {
+     // Copy before popping: Visit may push new items onto the queue.
+     const Data Next = Queue.back();
+     Queue.pop_back();
+     Visit(Next);
+   } while (!Queue.empty());
+
+   MergeOccuppiedIntervals();
+
+   const uint64_t TotalBits = CGF.getContext().getTypeSize(Ty);
+   for (const BitInterval &Gap : GetPaddingIntervals(TotalBits))
+     ClearPadding(Ptr, Gap);
+ }
+
+private:
+ // A half-open range of bit offsets, [First, Last), measured from the start
+ // of the object being processed.
+ struct BitInterval {
+ // Offset of the first bit that belongs to the interval (inclusive).
+ uint64_t First;
+ // Offset one past the final bit of the interval (exclusive).
+ uint64_t Last;
+ };
+
+ // One pending work-list entry: a (sub-)object still to be visited.
+ struct Data {
+ // Bit offset of this sub-object from the start of the outermost object.
+ uint64_t StartBitOffset;
+ // Type of the sub-object to visit.
+ QualType Ty;
+ // Forwarded to VisitStruct: when false, the virtual bases of a record are
+ // not enqueued (entries created for base-class subobjects set it to false).
+ bool VisitVirtualBase;
+ };
+
+ // Dispatches one work-list entry according to its type.
+ //
+ //  - constant-size arrays, C++ records and complex types are decomposed by
+ //    the matching Visit* helper, which enqueues their sub-objects;
+ //  - atomic types are re-enqueued as their unqualified value type;
+ //  - anything else is treated as fully occupied: its interval spans the
+ //    size in bits that the LLVM data layout reports for its in-memory type.
+ void Visit(Data const &D) {
+   // Ask the ASTContext for the array type instead of dyn_cast'ing the
+   // QualType: a dyn_cast only succeeds when the outermost type node is the
+   // array itself, so typedef'd or otherwise sugared arrays would be missed
+   // and wrongly recorded as one fully occupied blob.
+   if (const ConstantArrayType *AT =
+           CGF.getContext().getAsConstantArrayType(D.Ty)) {
+     VisitArray(AT, D.StartBitOffset);
+     return;
+   }
+
+   if (auto *Record = D.Ty->getAsCXXRecordDecl()) {
+     VisitStruct(Record, D.StartBitOffset, D.VisitVirtualBase);
+     return;
+   }
+
+   if (D.Ty->isAtomicType()) {
+     // Same offset and flags, only the type is unwrapped.
+     auto Unwrapped = D;
+     Unwrapped.Ty = D.Ty.getAtomicUnqualifiedType();
+     Queue.push_back(Unwrapped);
+     return;
+   }
+
+   if (const auto *Complex = D.Ty->getAs<ComplexType>()) {
+     VisitComplex(Complex, D.StartBitOffset);
+     return;
+   }
+
+   // Leaf type: every bit of its in-memory representation counts as occupied.
+   auto *Type = CGF.ConvertTypeForMem(D.Ty);
+   auto SizeBit = CGF.CGM.getModule()
+                      .getDataLayout()
+                      .getTypeSizeInBits(Type)
+                      .getKnownMinValue();
+   OccuppiedIntervals.push_back(
+       BitInterval{D.StartBitOffset, D.StartBitOffset + SizeBit});
+ }
+
+ // Enqueues every element of a constant-size array for visiting.
+ //
+ // Elements are placed `size rounded up to alignment` apart, starting at
+ // `StartBitOffset`; each is visited with virtual bases enabled.
+ void VisitArray(const ConstantArrayType *AT, uint64_t StartBitOffset) {
+   const QualType ElemTy = AT->getElementType();
+   const auto &Ctx = CGF.getContext();
+
+   // The element stride does not depend on the index, so compute it once.
+   const auto Stride =
+       Ctx.getTypeSizeInChars(ElemTy).alignTo(Ctx.getTypeAlignInChars(ElemTy));
+   const uint64_t Count = AT->getSize().getLimitedValue();
+
+   for (uint64_t I = 0; I != Count; ++I)
+     Queue.push_back(Data{StartBitOffset + I * Stride.getQuantity() * CharWidth,
+                          ElemTy, /*VisitVirtualBase*/ true});
+ }
+
+ // Decomposes a C++ record located at `StartBitOffset`.
+ //
+ // Records the record's own vtable pointer (if it has one) as occupied,
+ // enqueues its non-virtual bases, then its virtual bases when
+ // `VisitVirtualBase` is set, and finally handles its fields: bit-fields are
+ // recorded directly as occupied intervals, other fields are enqueued.
+ void VisitStruct(const CXXRecordDecl *R, uint64_t StartBitOffset,
+                  bool VisitVirtualBase) {
+   const ASTRecordLayout &Layout = CGF.getContext().getASTRecordLayout(R);
+
+   if (Layout.hasOwnVFPtr()) {
+     // The vfptr interval starts at the record's own offset and is one
+     // pointer wide.
+     const auto &DL = CGF.CGM.getModule().getDataLayout();
+     OccuppiedIntervals.push_back(BitInterval{
+         StartBitOffset, StartBitOffset + DL.getPointerSizeInBits()});
+   }
+
+   // Enqueue one base-class subobject. Base subobjects never visit their
+   // virtual bases themselves.
+   const auto EnqueueBase = [&](const CXXBaseSpecifier &Spec, bool IsVirtual) {
+     const auto *BaseDecl = Spec.getType()->getAsCXXRecordDecl();
+     if (!BaseDecl)
+       return;
+     const auto ByteOffset = IsVirtual
+                                 ? Layout.getVBaseClassOffset(BaseDecl)
+                                 : Layout.getBaseClassOffset(BaseDecl);
+     Queue.push_back(
+         Data{StartBitOffset + ByteOffset.getQuantity() * CharWidth,
+              Spec.getType(), /*VisitVirtualBase*/ false});
+   };
+
+   for (const auto &Base : R->bases()) {
+     if (Base.isVirtual())
+       continue;
+     EnqueueBase(Base, /*IsVirtual*/ false);
+   }
+
+   if (VisitVirtualBase)
+     for (const auto &VBase : R->vbases())
+       EnqueueBase(VBase, /*IsVirtual*/ true);
+
+   for (const auto *Field : R->fields()) {
+     // getFieldOffset is already expressed in bits.
+     const uint64_t FieldStart =
+         StartBitOffset + Layout.getFieldOffset(Field->getFieldIndex());
+     if (!Field->isBitField()) {
+       Queue.push_back(
+           Data{FieldStart, Field->getType(), /*VisitVirtualBase*/ true});
+       continue;
+     }
+     // A bit-field occupies exactly its declared width.
+     OccuppiedIntervals.push_back(
+         BitInterval{FieldStart, FieldStart + Field->getBitWidthValue()});
+   }
+ }
+
+ // Enqueues the two components of a complex type: the first at
+ // `StartBitOffset`, the second one aligned element size further on.
+ void VisitComplex(const ComplexType *CT, uint64_t StartBitOffset) {
+   const QualType PartTy = CT->getElementType();
+   const auto &Ctx = CGF.getContext();
+   const auto SecondPartBytes =
+       Ctx.getTypeSizeInChars(PartTy).alignTo(Ctx.getTypeAlignInChars(PartTy));
+
+   const uint64_t PartOffsets[] = {
+       StartBitOffset,
+       StartBitOffset + SecondPartBytes.getQuantity() * CharWidth};
+   for (const uint64_t PartOffset : PartOffsets)
+     Queue.push_back(Data{PartOffset, PartTy, /*VisitVirtualBase*/ true});
+ }
+
+ // Normalises OccuppiedIntervals: sorts the intervals by (First, Last) and
+ // fuses every run of overlapping or directly adjacent intervals into one.
+ void MergeOccuppiedIntervals() {
+   std::sort(OccuppiedIntervals.begin(), OccuppiedIntervals.end(),
+             [](const BitInterval &A, const BitInterval &B) {
+               if (A.First != B.First)
+                 return A.First < B.First;
+               return A.Last < B.Last;
+             });
+
+   std::vector<BitInterval> Merged;
+   Merged.reserve(OccuppiedIntervals.size());
+
+   for (const BitInterval &Cur : OccuppiedIntervals) {
+     // Touching or overlapping the interval being built: extend it.
+     if (!Merged.empty() && Cur.First <= Merged.back().Last) {
+       BitInterval &Tail = Merged.back();
+       Tail.Last = std::max(Tail.Last, Cur.Last);
+       continue;
+     }
+     // First interval, or a real gap before this one: start a new interval.
+     Merged.push_back(Cur);
+   }
+
+   OccuppiedIntervals = Merged;
+ }
+
+ // Returns the gaps left by OccuppiedIntervals inside [0, SizeInBits).
+ //
+ // The scan keeps a cursor at the end of the previous occupied interval and
+ // reports a gap whenever the next interval starts beyond it; it relies on
+ // the intervals being visited in ascending order (run() calls
+ // MergeOccuppiedIntervals first). A final gap is reported if the last
+ // occupied interval ends before SizeInBits.
+ std::vector<BitInterval> GetPaddingIntervals(uint64_t SizeInBits) const {
+   std::vector<BitInterval> Results;
+   // Fast path: a single interval covering the whole object means there is
+   // no padding. Read the last element with back(); dereferencing end() is
+   // undefined behaviour.
+   if (OccuppiedIntervals.size() == 1 &&
+       OccuppiedIntervals.front().First == 0 &&
+       OccuppiedIntervals.back().Last == SizeInBits) {
+     return Results;
+   }
+   // At most one gap before each occupied interval plus one trailing gap.
+   Results.reserve(OccuppiedIntervals.size() + 1);
+   uint64_t CurrentPos = 0;
+   for (const BitInterval &OccupiedInterval : OccuppiedIntervals) {
+     if (OccupiedInterval.First > CurrentPos) {
+       Results.push_back(BitInterval{CurrentPos, OccupiedInterval.First});
+     }
+     CurrentPos = OccupiedInterval.Last;
+   }
+   if (SizeInBits > CurrentPos) {
+     Results.push_back(BitInterval{CurrentPos, SizeInBits});
+   }
+   return Results;
+ }
+
+ // Emits IR that zeroes the bits of `PaddingInterval` in the object at `Ptr`.
+ //
+ // Whole bytes inside the interval are overwritten with a zero store. A byte
+ // that is only partly covered is updated with load / and / store, where the
+ // and-mask has a 0 for every padding bit and a 1 for every bit to keep.
+ // Bit `i` of a byte is taken to be the bit with value `1 << i`.
+ // NOTE(review): that numbering matches little-endian bit-field layout;
+ // confirm the behaviour wanted on big-endian targets.
+ // The masks are 8 bits wide, i.e. this assumes CharWidth == 8.
+ void ClearPadding(Value *Ptr, const BitInterval &PaddingInterval) {
+   // An empty interval has nothing to clear; without this guard the
+   // single-byte path would still emit a redundant load/store.
+   if (PaddingInterval.First >= PaddingInterval.Last)
+     return;
+
+   auto *I8Ptr = CGF.Builder.CreateBitCast(Ptr, CGF.Int8PtrTy);
+   auto *Zero = ConstantInt::get(CGF.Int8Ty, 0);
+
+   // Address of the byte `ByteIndex` bytes past Ptr.
+   const auto ByteAddress = [&](uint64_t ByteIndex) {
+     auto *Index = ConstantInt::get(CGF.IntTy, ByteIndex);
+     auto *Element = CGF.Builder.CreateGEP(CGF.Int8Ty, I8Ptr, Index);
+     return Address(Element, CGF.Int8Ty, CharUnits::One());
+   };
+
+   // Zeroes exactly the bits set in `PaddingMask` within one byte. The mask
+   // is inverted and truncated back to 8 bits before it becomes a constant,
+   // so integer promotion of `~` cannot leak high bits into the i8 constant.
+   const auto ClearBitsInByte = [&](uint64_t ByteIndex, uint8_t PaddingMask) {
+     Address ElementAddr = ByteAddress(ByteIndex);
+     auto *Loaded = CGF.Builder.CreateLoad(ElementAddr);
+     auto *KeepMask = ConstantInt::get(CGF.Int8Ty,
+                                       static_cast<uint8_t>(~PaddingMask));
+     auto *NewValue = CGF.Builder.CreateAnd(Loaded, KeepMask);
+     CGF.Builder.CreateStore(NewValue, ElementAddr);
+   };
+
+   // Mask with the lowest `Count` bits set (Count is in [0, 8]).
+   const auto LowBits = [](uint64_t Count) {
+     return static_cast<uint8_t>((1u << Count) - 1u);
+   };
+
+   // Calculate byte indices and bit positions
+   auto StartByte = PaddingInterval.First / CharWidth;
+   auto StartBit = PaddingInterval.First % CharWidth;
+   auto EndByte = PaddingInterval.Last / CharWidth;
+   auto EndBit = PaddingInterval.Last % CharWidth;
+
+   if (StartByte == EndByte) {
+     // Interval is within a single byte: padding is bits [StartBit, EndBit).
+     ClearBitsInByte(StartByte, static_cast<uint8_t>(LowBits(EndBit) &
+                                                     ~LowBits(StartBit)));
+     return;
+   }
+
+   // Partial start byte: padding is bits [StartBit, 8).
+   if (StartBit != 0) {
+     ClearBitsInByte(StartByte, static_cast<uint8_t>(~LowBits(StartBit)));
+     ++StartByte;
+   }
+
+   // Full bytes in the middle are overwritten outright.
+   for (auto Offset = StartByte; Offset < EndByte; ++Offset)
+     CGF.Builder.CreateStore(Zero, ByteAddress(Offset));
+
+   // Partial end byte: padding is bits [0, EndBit); the bits above it belong
+   // to the next occupied interval and must be preserved.
+   if (EndBit != 0)
+     ClearBitsInByte(EndByte, LowBits(EndBit));
+ }
+
+ CodeGenFunction &CGF;
+ const uint64_t CharWidth;
+ std::deque<Data> Queue;
----------------
huixie90 wrote:
Very good suggestion! Indeed, in practice the `T` in `atomic<T>` is usually a
very flat struct or even a builtin type, so the vector is usually small
https://github.com/llvm/llvm-project/pull/75371
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits