llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Pppp1116 (Pppp1116)

<details>
<summary>Changes</summary>

Fixes #<!-- -->190299.

## What changed

This teaches HLSL codegen to materialize embedded resource members back into a
global struct object each time codegen emits a reference to that object. That
lets struct methods access resource fields like `this->Buf` through the
ordinary member path, even when the resource was flattened into synthesized
global declarations earlier in semantic analysis.

The change walks the object's single-inheritance base (before its own fields),
nested structs, and constant-size arrays in the same order as the associated
resource declarations, and repopulates each embedded resource field from its
synthesized global resource.

A new codegen test covers direct and nested-array method calls on structs with 
embedded resources.

## Why

Today Clang flattens global struct resources and creates the associated
resource globals, but the original struct object's resource fields are still
uninitialized at codegen time. As a result, method bodies that access those
members end up reading the uninitialized field storage instead of the
synthesized resource handle.

## Impact

HLSL users can now call struct methods that read or write embedded resources
without manually rewriting those accesses at the call site.



---
Full diff: https://github.com/llvm/llvm-project/pull/190373.diff


4 Files Affected:

- (modified) clang/lib/CodeGen/CGExpr.cpp (+6-2) 
- (modified) clang/lib/CodeGen/CGHLSLRuntime.cpp (+141) 
- (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+2) 
- (added) clang/test/CodeGenHLSL/resources/resources-in-struct-methods.hlsl 
(+39) 


``````````diff
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 23802cdeb4811..4f2d8368e600a 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -3663,8 +3663,12 @@ LValue CodeGenFunction::EmitDeclRefLValue(const DeclRefExpr *E) {
 
   if (const auto *VD = dyn_cast<VarDecl>(ND)) {
     // Check if this is a global variable.
-    if (VD->hasLinkage() || VD->isStaticDataMember())
-      return EmitGlobalVarDeclLValue(*this, E, VD);
+    if (VD->hasLinkage() || VD->isStaticDataMember()) {
+      LValue LV = EmitGlobalVarDeclLValue(*this, E, VD);
+      if (getLangOpts().HLSL)
+        CGM.getHLSLRuntime().populateGlobalStructResources(*this, VD, LV);
+      return LV;
+    }
 
     Address addr = Address::invalid();
 
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 6182663111f5a..ac6f286ad4888 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -32,6 +32,7 @@
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DerivedTypes.h"
@@ -116,6 +117,131 @@ static const ValueDecl *getArrayDecl(const ArraySubscriptExpr *ASE) {
   return getArrayDecl(E);
 }
 
+static DeclRefExpr *createVarDeclRefExpr(ASTContext &AST, const VarDecl *VD,
+                                         SourceLocation Loc) {
+  return DeclRefExpr::Create(AST, NestedNameSpecifierLoc(), SourceLocation(),
+                             const_cast<VarDecl *>(VD),
+                             /*RefersToEnclosingVariableOrCapture=*/false, Loc,
+                             VD->getType(), VK_LValue);
+}
+
+static bool containsStructResources(QualType Ty) {
+  Ty = Ty.getCanonicalType();
+  const clang::Type *T = Ty.getTypePtr();
+  if (T->isHLSLResourceRecord() || T->isHLSLResourceRecordArray())
+    return true;
+
+  if (const auto *CAT = dyn_cast<ConstantArrayType>(T))
+    return containsStructResources(CAT->getElementType());
+
+  const auto *RD = Ty->getAsCXXRecordDecl();
+  if (!RD)
+    return false;
+
+  assert(RD->getNumBases() <= 1 && "HLSL doesn't support multiple inheritance");
+  if (auto Base = RD->bases_begin();
+      Base != RD->bases_end() && containsStructResources(Base->getType()))
+    return true;
+
+  return llvm::any_of(RD->fields(), [](const FieldDecl *FD) {
+    return containsStructResources(FD->getType());
+  });
+}
+
+class StructResourcePopulator {
+  CGHLSLRuntime &Runtime;
+  CodeGenFunction &CGF;
+  const VarDecl *RootDecl;
+  ArrayRef<const HLSLAssociatedResourceDeclAttr *> ResourceAttrs;
+  size_t ResourceIndex = 0;
+
+  ASTContext &getASTContext() const { return CGF.getContext(); }
+
+  SourceLocation getLoc() const { return RootDecl->getLocation(); }
+
+  const VarDecl *getNextResourceDecl() {
+    assert(ResourceIndex < ResourceAttrs.size() &&
+           "not enough associated resource declarations");
+    return ResourceAttrs[ResourceIndex++]->getResDecl();
+  }
+
+  LValue getArrayElementLValue(LValue Base, QualType ElementTy,
+                               uint64_t Index) {
+    Address ElemAddr = CGF.Builder.CreateConstArrayGEP(Base.getAddress(), Index);
+    return CGF.MakeAddrLValue(ElemAddr, ElementTy, Base.getBaseInfo(),
+                              CGF.CGM.getTBAAInfoForSubobject(Base, ElementTy));
+  }
+
+  LValue getDirectBaseLValue(LValue Base, const CXXRecordDecl *Derived,
+                             const CXXBaseSpecifier &BaseSpec) {
+    QualType BaseTy = BaseSpec.getType();
+    const auto *BaseDecl = BaseTy->getAsCXXRecordDecl();
+    Address BaseAddr = CGF.GetAddressOfDirectBaseInCompleteClass(
+        Base.getAddress(), Derived, BaseDecl, BaseSpec.isVirtual());
+    return CGF.MakeAddrLValue(BaseAddr, BaseTy, Base.getBaseInfo(),
+                              CGF.CGM.getTBAAInfoForSubobject(Base, BaseTy));
+  }
+
+  void populateResource(LValue Dest, QualType ResourceTy) {
+    const VarDecl *ResDecl = getNextResourceDecl();
+    if (ResourceTy->isHLSLResourceRecordArray()) {
+      Expr *ResRef = createVarDeclRefExpr(getASTContext(), ResDecl, getLoc());
+      bool Emitted = Runtime.emitResourceArrayCopy(Dest, ResRef, CGF);
+      assert(Emitted && "expected global associated resource array");
+      (void)Emitted;
+      return;
+    }
+
+    LValue Src =
+        CGF.EmitDeclRefLValue(createVarDeclRefExpr(getASTContext(), ResDecl,
+                                                   getLoc()));
+    CGF.EmitAggregateAssign(Dest, Src, ResourceTy);
+  }
+
+public:
+  StructResourcePopulator(
+      CGHLSLRuntime &Runtime, CodeGenFunction &CGF, const VarDecl *RootDecl,
+      ArrayRef<const HLSLAssociatedResourceDeclAttr *> ResourceAttrs)
+      : Runtime(Runtime), CGF(CGF), RootDecl(RootDecl),
+        ResourceAttrs(ResourceAttrs) {}
+
+  void populate(LValue Base, QualType Ty) {
+    Ty = Ty.getCanonicalType();
+    const clang::Type *T = Ty.getTypePtr();
+    if (T->isHLSLResourceRecord() || T->isHLSLResourceRecordArray()) {
+      populateResource(Base, Ty);
+      return;
+    }
+
+    if (const auto *CAT = dyn_cast<ConstantArrayType>(T)) {
+      QualType ElementTy = CAT->getElementType();
+      if (!containsStructResources(ElementTy))
+        return;
+
+      for (uint64_t I = 0, E = CAT->getSize().getZExtValue(); I < E; ++I)
+        populate(getArrayElementLValue(Base, ElementTy, I), ElementTy);
+      return;
+    }
+
+    const auto *RD = Ty->getAsCXXRecordDecl();
+    if (!RD)
+      return;
+
+    assert(RD->getNumBases() <= 1 && "HLSL doesn't support multiple inheritance");
+    const auto *BaseIt = RD->bases_begin();
+    if (BaseIt != RD->bases_end() && containsStructResources(BaseIt->getType()))
+      populate(getDirectBaseLValue(Base, RD, *BaseIt), BaseIt->getType());
+
+    for (const FieldDecl *FD : RD->fields()) {
+      if (!containsStructResources(FD->getType()))
+        continue;
+      populate(CGF.EmitLValueForField(Base, FD), FD->getType());
+    }
+  }
+
+  bool finished() const { return ResourceIndex == ResourceAttrs.size(); }
+};
+
 // Get the total size of the array, or 0 if the array is unbounded.
 static int getTotalArraySize(ASTContext &AST, const clang::Type *Ty) {
   Ty = Ty->getUnqualifiedDesugaredType();
@@ -1368,6 +1494,21 @@ bool CGHLSLRuntime::emitResourceArrayCopy(LValue &LHS, Expr *RHSExpr,
   return EndIndex.has_value();
 }
 
+void CGHLSLRuntime::populateGlobalStructResources(CodeGenFunction &CGF,
+                                                  const VarDecl *VD,
+                                                  LValue Root) {
+  llvm::SmallVector<const HLSLAssociatedResourceDeclAttr *, 8> ResourceAttrs;
+  for (const auto *Attr : VD->specific_attrs<HLSLAssociatedResourceDeclAttr>())
+    ResourceAttrs.push_back(Attr);
+  if (ResourceAttrs.empty())
+    return;
+
+  StructResourcePopulator Populator(*this, CGF, VD, ResourceAttrs);
+  Populator.populate(Root, VD->getType());
+  assert(Populator.finished() &&
+         "expected all associated resource declarations to be consumed");
+}
+
 RawAddress CGHLSLRuntime::createBufferMatrixTempAddress(const LValue &LV,
                                                         SourceLocation Loc,
                                                         CodeGenFunction &CGF) {
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index b1c5b3318a11e..8090e84614363 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -292,6 +292,8 @@ class CGHLSLRuntime {
   emitResourceArraySubscriptExpr(const ArraySubscriptExpr *E,
                                  CodeGenFunction &CGF);
   bool emitResourceArrayCopy(LValue &LHS, Expr *RHSExpr, CodeGenFunction &CGF);
+  void populateGlobalStructResources(CodeGenFunction &CGF, const VarDecl *VD,
+                                     LValue Root);
 
   std::optional<LValue> emitBufferArraySubscriptExpr(
       const ArraySubscriptExpr *E, CodeGenFunction &CGF,
diff --git a/clang/test/CodeGenHLSL/resources/resources-in-struct-methods.hlsl b/clang/test/CodeGenHLSL/resources/resources-in-struct-methods.hlsl
new file mode 100644
index 0000000000000..af3ed8b08f8ed
--- /dev/null
+++ b/clang/test/CodeGenHLSL/resources/resources-in-struct-methods.hlsl
@@ -0,0 +1,39 @@
+// RUN: %clang_cc1 -triple dxil--shadermodel6.6-compute -x hlsl -finclude-default-header \
+// RUN:   -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
+
+struct BufferWrapper {
+  RWBuffer<float> Buf;
+
+  void Store(float Value) {
+    Buf[0] = Value;
+  }
+};
+
+struct Outer {
+  BufferWrapper Elements[2];
+};
+
+BufferWrapper One : register(u0);
+BufferWrapper Two : register(u1);
+Outer Nested : register(u2);
+
+[numthreads(1, 1, 1)]
+void main() {
+  One.Store(1.0);
+  Two.Store(2.0);
+  Nested.Elements[1].Store(3.0);
+}
+
+// CHECK-LABEL: define void @main()
+// CHECK: %[[ONE_FIELD:.*]] = getelementptr inbounds nuw %struct.BufferWrapper, ptr @One, i32 0, i32 0
+// CHECK: store {{.*}}, ptr %[[ONE_FIELD]]
+// CHECK: call void @{{.*Store.*}}(ptr noundef @One, float noundef 1.000000e+00)
+// CHECK: %[[TWO_FIELD:.*]] = getelementptr inbounds nuw %struct.BufferWrapper, ptr @Two, i32 0, i32 0
+// CHECK: store {{.*}}, ptr %[[TWO_FIELD]]
+// CHECK: call void @{{.*Store.*}}(ptr noundef @Two, float noundef 2.000000e+00)
+// CHECK: %[[NESTED_ELEM0_FIELD:.*]] = getelementptr inbounds nuw %struct.Outer, ptr @Nested, i32 0, i32 0, i64 0, i32 0
+// CHECK: store {{.*}}, ptr %[[NESTED_ELEM0_FIELD]]
+// CHECK: %[[NESTED_ELEM1_FIELD:.*]] = getelementptr inbounds nuw %struct.Outer, ptr @Nested, i32 0, i32 0, i64 1, i32 0
+// CHECK: store {{.*}}, ptr %[[NESTED_ELEM1_FIELD]]
+// CHECK: %[[NESTED_ELEM1:.*]] = getelementptr inbounds nuw %struct.Outer, ptr @Nested, i32 0, i32 0, i64 1
+// CHECK: call void @{{.*Store.*}}(ptr noundef %[[NESTED_ELEM1]], float noundef 3.000000e+00)

``````````

</details>


https://github.com/llvm/llvm-project/pull/190373
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to