llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-hlsl

Author: Steven Perron (s-perron)

<details>
<summary>Changes</summary>

The vk::constant_id attribute is used to indicate that a global const variable
represents a specialization constant in SPIR-V. This PR adds this
attribute to clang.

The documentation for the attribute is 
[here](https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/SPIR-V.rst#specialization-constants).

The strategy is to modify the initializer to get the value of a
specialization constant from a builtin defined in the SPIR-V backend.


---

Patch is 39.25 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/143544.diff


16 Files Affected:

- (modified) clang/include/clang/Basic/Attr.td (+8) 
- (modified) clang/include/clang/Basic/AttrDocs.td (+15) 
- (modified) clang/include/clang/Basic/Builtins.td (+13) 
- (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+12) 
- (modified) clang/include/clang/Sema/SemaHLSL.h (+4-1) 
- (modified) clang/lib/Basic/Attributes.cpp (+2-1) 
- (modified) clang/lib/CodeGen/CGHLSLBuiltins.cpp (+72) 
- (modified) clang/lib/CodeGen/CodeGenFunction.h (+11) 
- (modified) clang/lib/Sema/SemaDecl.cpp (+14) 
- (modified) clang/lib/Sema/SemaDeclAttr.cpp (+3) 
- (modified) clang/lib/Sema/SemaHLSL.cpp (+119-1) 
- (added) clang/test/AST/HLSL/vk.spec-constant.usage.hlsl (+130) 
- (renamed) clang/test/CodeGenHLSL/vk-features/SpirvType.alignment.hlsl () 
- (renamed) clang/test/CodeGenHLSL/vk-features/SpirvType.hlsl () 
- (added) clang/test/CodeGenHLSL/vk-features/vk.spec-constant.hlsl (+210) 
- (added) clang/test/SemaHLSL/vk.spec-constant.error.hlsl (+37) 


``````````diff
diff --git a/clang/include/clang/Basic/Attr.td 
b/clang/include/clang/Basic/Attr.td
index f889e41c8699f..d3f39de6a3e85 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4993,6 +4993,14 @@ def HLSLVkExtBuiltinInput : InheritableAttr {
   let Documentation = [HLSLVkExtBuiltinInputDocs];
 }
 
+def HLSLVkConstantId : InheritableAttr {
+  let Spellings = [CXX11<"vk", "constant_id">];
+  let Args = [IntArgument<"Id">];
+  let Subjects = SubjectList<[ExternalGlobalVar]>;
+  let LangOpts = [HLSL];
+  let Documentation = [VkConstantIdDocs];
+}
+
 def RandomizeLayout : InheritableAttr {
   let Spellings = [GCC<"randomize_layout">];
   let Subjects = SubjectList<[Record]>;
diff --git a/clang/include/clang/Basic/AttrDocs.td 
b/clang/include/clang/Basic/AttrDocs.td
index ea3c43f38d9fe..b3eafb79c5d4a 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -8252,6 +8252,21 @@ and 
https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
   }];
 }
 
+def VkConstantIdDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+The ``vk::constant_id`` attribute specify the id for a SPIR-V specialization
+constant. The attribute applies to const global scalar variables. The variable 
must be initialized with a C++11 constexpr.
+In SPIR-V, the
+variable will be replaced with an `OpSpecConstant` with the given id.
+The syntax is:
+
+.. code-block:: text
+
+  ``[[vk::constant_id(<Id>)]] const T Name = <Init>``
+}];
+}
+
 def RootSignatureDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
diff --git a/clang/include/clang/Basic/Builtins.td 
b/clang/include/clang/Basic/Builtins.td
index 68cd3d790e78a..d65b3a5d2f447 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5065,6 +5065,19 @@ def HLSLGroupMemoryBarrierWithGroupSync: 
LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void()";
 }
 
+class HLSLScalarTemplate
+    : Template<["bool", "char", "short", "int", "long long int",
+                "unsigned short", "unsigned int", "unsigned long long int",
+                "__fp16", "float", "double"],
+               ["_bool", "_char", "_short", "_int", "_longlong", "_ushort",
+                "_uint", "_ulonglong", "_half", "_float", "_double"]>;
+
+def HLSLGetSpirvSpecConstant : LangBuiltin<"HLSL_LANG">, HLSLScalarTemplate {
+  let Spellings = ["__builtin_get_spirv_spec_constant"];
+  let Attributes = [NoThrow, Const, Pure];
+  let Prototype = "T(unsigned int, T)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td 
b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 1f283b776a02c..23a490225dd19 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12919,6 +12919,18 @@ def err_spirv_enum_not_int : Error<
 def err_spirv_enum_not_valid : Error<
    "invalid value for %select{storage class}0 argument">;
 
+def err_specialization_const_lit_init
+    : Error<"variable with 'vk::constant_id' attribute cannot have an "
+            "initializer that is not a constexpr">;
+def err_specialization_const_missing_initializer
+    : Error<
+          "variable with 'vk::constant_id' attribute must have an 
initializer">;
+def err_specialization_const_missing_const
+    : Error<"variable with 'vk::constant_id' attribute must be const">;
+def err_specialization_const_is_not_int_or_float
+    : Error<"variable with 'vk::constant_id' attribute must be an enum, bool, "
+            "integer, or floating point value">;
+
 // errors of expect.with.probability
 def err_probability_not_constant_float : Error<
    "probability argument to __builtin_expect_with_probability must be constant 
"
diff --git a/clang/include/clang/Sema/SemaHLSL.h 
b/clang/include/clang/Sema/SemaHLSL.h
index 66d09f49680be..099d9c35684e8 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -98,6 +98,8 @@ class SemaHLSL : public SemaBase {
   HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
                                       int Min, int Max, int Preferred,
                                       int SpelledArgsCount);
+  HLSLVkConstantIdAttr *
+  mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id);
   HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                                   llvm::Triple::EnvironmentType ShaderType);
   HLSLParamModifierAttr *
@@ -122,6 +124,7 @@ class SemaHLSL : public SemaBase {
   void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL);
   void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
   void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
+  void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
@@ -156,7 +159,7 @@ class SemaHLSL : public SemaBase {
   QualType getInoutParameterType(QualType Ty);
 
   bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init);
-
+  bool handleInitialization(VarDecl *VDecl, Expr *&Init);
   void deduceAddressSpace(VarDecl *Decl);
 
 private:
diff --git a/clang/lib/Basic/Attributes.cpp b/clang/lib/Basic/Attributes.cpp
index 905046685934b..5f74365f0ed00 100644
--- a/clang/lib/Basic/Attributes.cpp
+++ b/clang/lib/Basic/Attributes.cpp
@@ -213,7 +213,8 @@ getScopeFromNormalizedScopeName(StringRef ScopeName) {
       .Case("vk", AttributeCommonInfo::Scope::VK)
       .Case("msvc", AttributeCommonInfo::Scope::MSVC)
       .Case("omp", AttributeCommonInfo::Scope::OMP)
-      .Case("riscv", AttributeCommonInfo::Scope::RISCV);
+      .Case("riscv", AttributeCommonInfo::Scope::RISCV)
+      .Case("vk", AttributeCommonInfo::Scope::HLSL);
 }
 
 unsigned AttributeCommonInfo::calculateAttributeSpellingListIndex() const {
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp 
b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index abebc201808b0..74b5f92638600 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -12,6 +12,7 @@
 
 #include "CGBuiltin.h"
 #include "CGHLSLRuntime.h"
+#include "CodeGenFunction.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -774,6 +775,77 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned 
BuiltinID,
     return EmitRuntimeCall(
         Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
   }
+  case Builtin::BI__builtin_get_spirv_spec_constant_bool:
+  case Builtin::BI__builtin_get_spirv_spec_constant_short:
+  case Builtin::BI__builtin_get_spirv_spec_constant_ushort:
+  case Builtin::BI__builtin_get_spirv_spec_constant_int:
+  case Builtin::BI__builtin_get_spirv_spec_constant_uint:
+  case Builtin::BI__builtin_get_spirv_spec_constant_longlong:
+  case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:
+  case Builtin::BI__builtin_get_spirv_spec_constant_half:
+  case Builtin::BI__builtin_get_spirv_spec_constant_float:
+  case Builtin::BI__builtin_get_spirv_spec_constant_double: {
+    llvm::Function *SpecConstantFn = getSpecConstantFunction(E->getType());
+    llvm::Value *SpecId = EmitScalarExpr(E->getArg(0));
+    llvm::Value *DefaultVal = EmitScalarExpr(E->getArg(1));
+    llvm::Value *Args[] = {SpecId, DefaultVal};
+    return Builder.CreateCall(SpecConstantFn, Args);
+  }
   }
   return nullptr;
 }
+
+llvm::Function *clang::CodeGen::CodeGenFunction::getSpecConstantFunction(
+    const clang::QualType &SpecConstantType) {
+
+  // Find or create the declaration for the function.
+  llvm::Module *M = &CGM.getModule();
+  std::string MangledName = getSpecConstantFunctionName(SpecConstantType);
+  llvm::Function *SpecConstantFn = M->getFunction(MangledName);
+
+  if (!SpecConstantFn) {
+    llvm::Type *IntType = ConvertType(getContext().IntTy);
+    llvm::Type *RetTy = ConvertType(SpecConstantType);
+    llvm::Type *ArgTypes[] = {IntType, RetTy};
+    llvm::FunctionType *FnTy = llvm::FunctionType::get(RetTy, ArgTypes, false);
+    SpecConstantFn = llvm::Function::Create(
+        FnTy, llvm::GlobalValue::ExternalLinkage, MangledName, M);
+  }
+  return SpecConstantFn;
+}
+
+std::string clang::CodeGen::CodeGenFunction::getSpecConstantFunctionName(
+    const clang::QualType &SpecConstantType) {
+  // The parameter types for our conceptual intrinsic function.
+  ASTContext &Context = getContext();
+  QualType ClangParamTypes[] = {Context.IntTy, SpecConstantType};
+
+  // Create a temporary FunctionDecl for the builtin fuction. It won't be
+  // added to the AST.
+  FunctionProtoType::ExtProtoInfo EPI;
+  QualType FnType =
+      Context.getFunctionType(SpecConstantType, ClangParamTypes, EPI);
+  DeclarationName FuncName = &Context.Idents.get("__spirv_SpecConstant");
+  FunctionDecl *FnDeclForMangling = FunctionDecl::Create(
+      Context, Context.getTranslationUnitDecl(), SourceLocation(),
+      SourceLocation(), FuncName, FnType, /*TSI=*/nullptr, SC_Extern);
+
+  // Attach the created parameter declarations to the function declaration.
+  SmallVector<ParmVarDecl *, 2> ParamDecls;
+  for (QualType ParamType : ClangParamTypes) {
+    ParmVarDecl *PD = ParmVarDecl::Create(
+        Context, FnDeclForMangling, SourceLocation(), SourceLocation(),
+        /*IdentifierInfo*/ nullptr, ParamType, /*TSI*/ nullptr, SC_None,
+        /*DefaultArg*/ nullptr);
+    ParamDecls.push_back(PD);
+  }
+  FnDeclForMangling->setParams(ParamDecls);
+
+  // Get the mangled name.
+  std::string Name;
+  llvm::raw_string_ostream MangledNameStream(Name);
+  MangleContext *Mangler = Context.createMangleContext();
+  Mangler->mangleName(FnDeclForMangling, MangledNameStream);
+  MangledNameStream.flush();
+  return Name;
+}
diff --git a/clang/lib/CodeGen/CodeGenFunction.h 
b/clang/lib/CodeGen/CodeGenFunction.h
index a5ab9df01dba9..badc9feeb8c2b 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4850,6 +4850,17 @@ class CodeGenFunction : public CodeGenTypeCache {
   llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
                                    ReturnValueSlot ReturnValue);
+
+  // Returns a builtin function that the SPIR-V backend will expand into a spec
+  // constant.
+  llvm::Function *
+  getSpecConstantFunction(const clang::QualType &SpecConstantType);
+
+  // Returns the mangled name for a builtin function that the SPIR-V backend
+  // will expand into a spec Constant.
+  std::string
+  getSpecConstantFunctionName(const clang::QualType &SpecConstantType);
+
   llvm::Value *EmitDirectXBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitSPIRVBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index bbd63372c168b..9d09eec26f7a2 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -2889,6 +2889,8 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
     NewAttr = S.HLSL().mergeWaveSizeAttr(D, *WS, WS->getMin(), WS->getMax(),
                                          WS->getPreferred(),
                                          WS->getSpelledArgsCount());
+  else if (const auto *CI = dyn_cast<HLSLVkConstantIdAttr>(Attr))
+    NewAttr = S.HLSL().mergeVkConstantIdAttr(D, *CI, CI->getId());
   else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
     NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
   else if (isa<SuppressAttr>(Attr))
@@ -13755,6 +13757,10 @@ void Sema::AddInitializerToDecl(Decl *RealDecl, Expr 
*Init, bool DirectInit) {
     return;
   }
 
+  if (getLangOpts().HLSL)
+    if (!HLSL().handleInitialization(VDecl, Init))
+      return;
+
   // Get the decls type and save a reference for later, since
   // CheckInitializerTypes may change it.
   QualType DclT = VDecl->getType(), SavT = DclT;
@@ -14215,6 +14221,14 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) {
       }
     }
 
+    // HLSL variable with the `vk::constant_id` attribute must be initialized.
+    if (!Var->isInvalidDecl() && Var->hasAttr<HLSLVkConstantIdAttr>()) {
+      Diag(Var->getLocation(),
+           diag::err_specialization_const_missing_initializer);
+      Var->setInvalidDecl();
+      return;
+    }
+
     if (!Var->isInvalidDecl() && RealDecl->hasAttr<LoaderUninitializedAttr>()) 
{
       if (Var->getStorageClass() == SC_Extern) {
         Diag(Var->getLocation(), diag::err_loader_uninitialized_extern_decl)
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index da0e3265767d8..e49bdda1a402b 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7560,6 +7560,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, 
const ParsedAttr &AL,
   case ParsedAttr::AT_HLSLVkExtBuiltinInput:
     S.HLSL().handleVkExtBuiltinInputAttr(D, AL);
     break;
+  case ParsedAttr::AT_HLSLVkConstantId:
+    S.HLSL().handleVkConstantIdAttr(D, AL);
+    break;
   case ParsedAttr::AT_HLSLSV_GroupThreadID:
     S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
     break;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 9065cc5a1d4a5..5764507f35882 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -119,6 +119,40 @@ static ResourceClass getResourceClass(RegisterType RT) {
   llvm_unreachable("unexpected RegisterType value");
 }
 
+static Builtin::ID getSpecConstBuiltinId(QualType Type) {
+  const auto *BT = dyn_cast<BuiltinType>(Type);
+  if (!BT) {
+    if (!Type->isEnumeralType())
+      return Builtin::NotBuiltin;
+    return Builtin::BI__builtin_get_spirv_spec_constant_int;
+  }
+
+  switch (BT->getKind()) {
+  case BuiltinType::Bool:
+    return Builtin::BI__builtin_get_spirv_spec_constant_bool;
+  case BuiltinType::Short:
+    return Builtin::BI__builtin_get_spirv_spec_constant_short;
+  case BuiltinType::Int:
+    return Builtin::BI__builtin_get_spirv_spec_constant_int;
+  case BuiltinType::LongLong:
+    return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
+  case BuiltinType::UShort:
+    return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
+  case BuiltinType::UInt:
+    return Builtin::BI__builtin_get_spirv_spec_constant_uint;
+  case BuiltinType::ULongLong:
+    return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
+  case BuiltinType::Half:
+    return Builtin::BI__builtin_get_spirv_spec_constant_half;
+  case BuiltinType::Float:
+    return Builtin::BI__builtin_get_spirv_spec_constant_float;
+  case BuiltinType::Double:
+    return Builtin::BI__builtin_get_spirv_spec_constant_double;
+  default:
+    return Builtin::NotBuiltin;
+  }
+}
+
 DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
                                                       ResourceClass ResClass) {
   assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
@@ -607,6 +641,41 @@ HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
   return Result;
 }
 
+HLSLVkConstantIdAttr *
+SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
+                                int Id) {
+
+  auto &TargetInfo = getASTContext().getTargetInfo();
+  if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
+    Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
+    return nullptr;
+  }
+
+  auto *VD = cast<VarDecl>(D);
+
+  if (getSpecConstBuiltinId(VD->getType()) == Builtin::NotBuiltin) {
+    Diag(VD->getLocation(), 
diag::err_specialization_const_is_not_int_or_float);
+    return nullptr;
+  }
+
+  if (!VD->getType().isConstQualified()) {
+    Diag(VD->getLocation(), diag::err_specialization_const_missing_const);
+    return nullptr;
+  }
+
+  if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
+    if (CI->getId() != Id) {
+      Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+
+  HLSLVkConstantIdAttr *Result =
+      ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
+  return Result;
+}
+
 HLSLShaderAttr *
 SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                           llvm::Triple::EnvironmentType ShaderType) {
@@ -1125,6 +1194,15 @@ void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, 
const ParsedAttr &AL) {
                  HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
 }
 
+void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
+  uint32_t Id;
+  if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
+    return;
+  HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
+  if (NewAttr)
+    D->addAttr(NewAttr);
+}
+
 bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
   const auto *VT = T->getAs<VectorType>();
 
@@ -3154,6 +3232,7 @@ static bool IsDefaultBufferConstantDecl(VarDecl *VD) {
   return VD->getDeclContext()->isTranslationUnit() &&
          QT.getAddressSpace() == LangAS::Default &&
          VD->getStorageClass() != SC_Static &&
+         !VD->hasAttr<HLSLVkConstantIdAttr>() &&
          !isInvalidConstantBufferLeafElementType(QT.getTypePtr());
 }
 
@@ -3221,7 +3300,8 @@ void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
     const Type *VarType = VD->getType().getTypePtr();
     while (VarType->isArrayType())
       VarType = VarType->getArrayElementTypeNoTypeQual();
-    if (VarType->isHLSLResourceRecord()) {
+    if (VarType->isHLSLResourceRecord() ||
+        VD->hasAttr<HLSLVkConstantIdAttr>()) {
       // Make the variable for resources static. The global externally visible
       // storage is accessed through the handle, which is a member. The 
variable
       // itself is not externally visible.
@@ -3644,3 +3724,41 @@ bool SemaHLSL::transformInitList(const InitializedEntity 
&Entity,
     Init->updateInit(Ctx, I, NewInit->getInit(I));
   return true;
 }
+
+bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
+  const HLSLVkConstantIdAttr *ConstIdAttr =
+      VDecl->getAttr<HLSLVkConstantIdAttr>();
+  if (!ConstIdAttr)
+    return true;
+
+  ASTContext &Context = SemaRef.getASTContext();
+
+  APValue InitValue;
+  if (!Init->isCXX11ConstantExpr(Context, &InitValue)) {
+    Diag(VDecl->getLocation(), diag::err_specialization_const_lit_init);
+    VDecl->setInvalidDecl();
+    return false;
+  }
+
+  Builtin::ID BID = getSpecConstBuiltinId(VDecl->getType());
+
+  // Argument 1: The ID from the attribute
+  int ConstantID = ConstIdAttr->getId();
+  llvm::APInt IDVal(Context.getIntWidth(Context.IntTy), ConstantID);
+  Expr *IdExpr = IntegerLiteral::Create(Context, IDVal, Context.IntTy,
+                                        ConstIdAttr->getLocation());
+
+  SmallVector<Expr *, 2> Args = {IdExpr, Init};
+  Expr *C = SemaRef.BuildBuiltinCallExpr(Init->getExprLoc(), BID, Args);
+  if (C->getType()->getCanonicalTypeUnqualified() !=
+      VDecl->getType()->getCanonicalTypeUnqualified()) {
+    C = SemaRef
+            .BuildCStyleCastExpr(SourceLocation(),
+                                 Context.getTrivialTypeSourceInfo(
+                                     Init->getType(), Init->getExprLoc()),
+                                 SourceLocation(), C)
+            .get();
+  }
+  Init = C;
+  return true;
+}
diff --git a/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl 
b/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl
new file mode 100644
index 0000000000000..c0955c1ea7b43
--- /dev/null
+++ b/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl
@@ -0,0 +1,130 @@
+// RUN: %clang_cc1 -finclude-default-header -triple 
spirv-unknown-vulkan-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+
+// CHECK: VarDecl {{.*}} bool_const 'const hlsl_private bool' static cinit
+// CHECK-NEXT: CallExpr {{.*}} 'bool'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'bool (*)(unsigned int, bool) noexcept' 
<FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'bool (unsigned int, bool) noexcept' lvalue 
Function {{.*}} '__builtin_get_spirv_spec_constant_bool' 'bool (unsigned int, 
bool) noexcept'
+// CHECK-NEXT: ImplicitCastExpr {{.*}...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/143544
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to