llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Joshua Batista (bob80905)

<details>
<summary>Changes</summary>

This PR adds validation for register numbers.
Register numbers must never exceed UINT32_MAX (4294967295).
Additionally, each element of a resource array is bound to the next
sequential register, and those elements' register numbers must not exceed
UINT32_MAX (4294967295) either. Even though these elements are not explicitly
given a register number, their effective register numbers are also validated.
This also covers resources nested inside structs (including arrays of such
structs) as well as plain resource arrays.

Fixes https://github.com/llvm/llvm-project/issues/136809

---
Full diff: https://github.com/llvm/llvm-project/pull/174028.diff


3 Files Affected:

- (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+1) 
- (modified) clang/lib/Sema/SemaHLSL.cpp (+102) 
- (added) clang/test/SemaHLSL/resource_binding_attr_error_uint32_max.hlsl (+47) 


``````````diff
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td 
b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 6c6a26614ad0e..3807d1dab728a 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -13336,6 +13336,7 @@ def warn_hlsl_register_type_c_packoffset: 
Warning<"binding type 'c' ignored in b
 def warn_hlsl_deprecated_register_type_b: Warning<"binding type 'b' only 
applies to constant buffers. The 'bool constant' binding type is no longer 
supported">, InGroup<LegacyConstantRegisterBinding>, DefaultError;
 def warn_hlsl_deprecated_register_type_i: Warning<"binding type 'i' ignored. 
The 'integer constant' binding type is no longer supported">, 
InGroup<LegacyConstantRegisterBinding>, DefaultError;
 def err_hlsl_unsupported_register_number : Error<"register number should be an 
integer">;
+def err_hlsl_register_number_too_large : Error<"register number should not 
exceed UINT32_MAX, 4294967295">; // also when implicitly bound elements overflow
 def err_hlsl_expected_space : Error<"invalid space specifier '%0' used; 
expected 'space' followed by an integer, like space1">;
 def err_hlsl_space_on_global_constant : Error<"register space cannot be 
specified on global constants">;
 def warn_hlsl_implicit_binding : Warning<"resource has implicit register 
binding">, InGroup<HLSLImplicitBinding>, DefaultIgnore;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index a6de1cd550212..7760b1f791374 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2361,6 +2361,99 @@ static bool DiagnoseHLSLRegisterAttribute(Sema &S, 
SourceLocation &ArgLoc,
   return ValidateMultipleRegisterAnnotations(S, D, RegType);
 }
 
+// True only for an all-digit S whose numeric value exceeds UINT32_MAX.
+static bool ExceedsUInt32Max(llvm::StringRef S) {
+  constexpr size_t MaxDigits = 10; // UINT32_MAX = 4294967295
+  if (S.find_first_not_of("0123456789") != llvm::StringRef::npos)
+    return false; // diagnosed later as "should be an integer"
+  S = S.ltrim('0'); // leading zeros do not change the value
+  if (S.size() != MaxDigits)
+    return S.size() > MaxDigits;
+  return S.compare("4294967295") > 0; // same length: lexical == numeric
+}
+
+// Advance SlotCount by the number of HLSL resource slots that Ty occupies,
+// scaled by Multiplier (the product of enclosing array sizes). Returns false
+// once SlotCount exceeds Limit; arithmetic saturates so huge arrays cannot
+// wrap the 64-bit counter. NOTE(review): every resource kind is counted,
+// regardless of the register class being bound; confirm this is intended.
+static bool AccumulateHLSLResourceSlots(QualType Ty, llvm::APInt &SlotCount,
+                                        const llvm::APInt &Limit,
+                                        ASTContext &Ctx,
+                                        uint64_t Multiplier = 1) {
+  const Type *T = Ty.getCanonicalType().getTypePtr();
+  if (SlotCount.ugt(Limit)) // already past the limit
+    return false;
+
+  // Arrays: recurse into the element type with a scaled multiplier.
+  if (const auto *AT = dyn_cast<ArrayType>(T)) {
+    uint64_t Count = 1; // TODO: unbounded/non-constant arrays count as 1
+    if (const auto *CAT = dyn_cast<ConstantArrayType>(AT))
+      Count = CAT->getSize().getZExtValue();
+    return AccumulateHLSLResourceSlots(
+        AT->getElementType(), SlotCount, Limit, Ctx,
+        llvm::SaturatingMultiply(Multiplier, Count));
+  }
+
+  // Resource leaf: one slot per element of the enclosing (flattened) arrays.
+  if (T->isHLSLResourceRecord()) {
+    SlotCount =
+        SlotCount.uadd_sat(llvm::APInt(SlotCount.getBitWidth(), Multiplier));
+    return SlotCount.ule(Limit);
+  }
+
+  // Structs: count resources in base classes first, then in own fields.
+  if (const auto *RT = dyn_cast<RecordType>(T)) {
+    const RecordDecl *RD = RT->getDecl();
+    if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(RD))
+      for (const CXXBaseSpecifier &Base : CXXRD->bases())
+        if (!AccumulateHLSLResourceSlots(Base.getType(), SlotCount, Limit, Ctx,
+                                         Multiplier))
+          return false;
+    for (const FieldDecl *Field : RD->fields())
+      if (!AccumulateHLSLResourceSlots(Field->getType(), SlotCount, Limit, Ctx,
+                                       Multiplier))
+        return false;
+    return true;
+  }
+
+  // Anything else (scalars, vectors, ...) occupies no resource slots.
+  return true;
+}
+
+// Return true if the register number in SlotNumStr, or any slot implicitly
+// consumed by the resources declared in TheDecl, would exceed UINT32_MAX.
+// Malformed numbers return false and are diagnosed later by the caller.
+static bool ValidateRegisterNumber(StringRef SlotNumStr, Decl *TheDecl,
+                                   ASTContext &Ctx) {
+  if (ExceedsUInt32Max(SlotNumStr))
+    return true;
+
+  llvm::APInt SlotNum;
+  if (SlotNumStr.getAsInteger(10, SlotNum))
+    return false;
+  // Widen to 64 bits so the running slot count can pass 2^32 without
+  // wrapping; APInt comparisons also require equal bit widths.
+  SlotNum = SlotNum.zext(64);
+
+  // One past the last valid register: a binding whose last slot is exactly
+  // UINT32_MAX ends with SlotNum == 2^32 and is still accepted.
+  llvm::APInt EndLimit(64, uint64_t(UINT32_MAX) + 1);
+  if (auto *VD = dyn_cast<VarDecl>(TheDecl)) {
+    // Afterwards SlotNum is the first register past everything VD binds.
+    AccumulateHLSLResourceSlots(VD->getType(), SlotNum, EndLimit, Ctx);
+    return SlotNum.ugt(EndLimit);
+  }
+  // A cbuffer occupies only its own register; resources cannot be declared
+  // inside it, so nothing can push the number any higher.
+  if (isa<HLSLBufferDecl>(TheDecl))
+    return SlotNum.uge(EndLimit);
+
+  // No other declaration kind is expected here; conservatively report it as
+  // invalid.
+  return true;
+}
+
 void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
   if (VarDecl *VD = dyn_cast<VarDecl>(TheDecl)) {
     QualType Ty = VD->getType();
@@ -2420,6 +2513,15 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, 
const ParsedAttr &AL) {
       return;
     }
     StringRef SlotNumStr = Slot.substr(1);
+
+    // Validate register number. It should not exceed UINT32_MAX,
+    // including if the resource type is an array that starts
+    // before UINT32_MAX, but ends afterwards.
+    if (ValidateRegisterNumber(SlotNumStr, TheDecl, getASTContext())) {
+      Diag(SlotLoc, diag::err_hlsl_register_number_too_large);
+      return;
+    }
+
     unsigned N;
     if (SlotNumStr.getAsInteger(10, N)) {
       Diag(SlotLoc, diag::err_hlsl_unsupported_register_number);
diff --git a/clang/test/SemaHLSL/resource_binding_attr_error_uint32_max.hlsl 
b/clang/test/SemaHLSL/resource_binding_attr_error_uint32_max.hlsl
new file mode 100644
index 0000000000000..70a226337ef6a
--- /dev/null
+++ b/clang/test/SemaHLSL/resource_binding_attr_error_uint32_max.hlsl
@@ -0,0 +1,47 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -o - 
-fsyntax-only %s -verify
+
+// test semantic validation for register numbers that exceed UINT32_MAX
+
+struct S {
+  RWBuffer<float> A[4];
+  RWBuffer<int> B[10];
+};
+
+// do some more nesting
+struct S2 {
+  S a[3];
+};
+
+// test that S.A carries the register number over the limit and emits the error
+// expected-error@+1 {{register number should not exceed UINT32_MAX, 
4294967295}}
+S s : register(u4294967294); // UINT32_MAX - 1
+
+// test the error is also triggered when analyzing S.B
+// expected-error@+1 {{register number should not exceed UINT32_MAX, 
4294967295}}
+S s2 : register(u4294967289);
+
+
+// test the error is also triggered when analyzing S2.a[1].B
+// expected-error@+1 {{register number should not exceed UINT32_MAX, 
4294967295}}
+S2 s3 : register(u4294967275);
+// test a multi-dimensional resource array (100 elements)
+// expected-error@+1 {{register number should not exceed UINT32_MAX, 
4294967295}}
+RWBuffer<float> Buf[10][10] : register(u4294967234);
+
+// test a standard resource array
+// expected-error@+1 {{register number should not exceed UINT32_MAX, 
4294967295}}
+RWBuffer<float> Buf2[10] : register(u4294967294); 
+
+// test directly an excessively high register number.
+// expected-error@+1 {{register number should not exceed UINT32_MAX, 
4294967295}}
+RWBuffer<float> A : register(u9995294967294);
+
+// test an excessively high register number on a cbuffer
+// expected-error@+1 {{register number should not exceed UINT32_MAX, 
4294967295}}
+cbuffer MyCB : register(b9995294967294) {
+  float F[4];
+  int   I[10];
+};
+
+// no errors expected: the 100 elements occupy u4294967194 through u4294967293
+RWBuffer<float> Buf3[10][10] : register(u4294967194); 

``````````

</details>


https://github.com/llvm/llvm-project/pull/174028
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to