https://github.com/Keenuts created https://github.com/llvm/llvm-project/pull/166793

TODO:
  - get the wg-hlsl proposal accepted.
  - fill in this description.
  - add tests

From 8576183000f0f69a21bb505d891751cfa26c95b5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <[email protected]>
Date: Wed, 5 Nov 2025 16:37:37 +0100
Subject: [PATCH] [HLSL][SPIR-V] Implement vk::push_constant

TODO:
  - get the wg-hlsl proposal accepted.
  - fill in this description.
  - add tests
---
 clang/include/clang/Basic/AddressSpaces.h    |  1 +
 clang/include/clang/Basic/Attr.td            |  8 ++++++++
 clang/include/clang/Basic/AttrDocs.td        |  5 +++++
 clang/include/clang/Sema/SemaHLSL.h          |  1 +
 clang/lib/AST/Type.cpp                       |  1 +
 clang/lib/AST/TypePrinter.cpp                |  2 ++
 clang/lib/Basic/TargetInfo.cpp               |  1 +
 clang/lib/Basic/Targets/AArch64.h            |  1 +
 clang/lib/Basic/Targets/AMDGPU.cpp           |  2 ++
 clang/lib/Basic/Targets/DirectX.h            |  1 +
 clang/lib/Basic/Targets/NVPTX.h              |  1 +
 clang/lib/Basic/Targets/SPIR.h               |  2 ++
 clang/lib/Basic/Targets/SystemZ.h            |  1 +
 clang/lib/Basic/Targets/TCE.h                |  1 +
 clang/lib/Basic/Targets/WebAssembly.h        |  1 +
 clang/lib/Basic/Targets/X86.h                |  1 +
 clang/lib/CodeGen/CodeGenModule.cpp          |  8 ++++++--
 clang/lib/Sema/SemaDecl.cpp                  |  3 +++
 clang/lib/Sema/SemaDeclAttr.cpp              |  3 +++
 clang/lib/Sema/SemaHLSL.cpp                  | 13 +++++++++++++
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 17 +++++++++--------
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp         |  2 ++
 llvm/lib/Target/SPIRV/SPIRVUtils.h           |  2 ++
 23 files changed, 68 insertions(+), 10 deletions(-)

diff --git a/clang/include/clang/Basic/AddressSpaces.h b/clang/include/clang/Basic/AddressSpaces.h
index 48e4a1c61fe02..7280b8fc923d2 100644
--- a/clang/include/clang/Basic/AddressSpaces.h
+++ b/clang/include/clang/Basic/AddressSpaces.h
@@ -62,6 +62,7 @@ enum class LangAS : unsigned {
   hlsl_private,
   hlsl_device,
   hlsl_input,
+  hlsl_push_constant,
 
   // Wasm specific address spaces.
   wasm_funcref,
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 1013bfc575747..e8392edceeaf1 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -5154,6 +5154,14 @@ def HLSLVkExtBuiltinInput : InheritableAttr {
   let Documentation = [HLSLVkExtBuiltinInputDocs];
 }
 
+def HLSLVkPushConstant : InheritableAttr {
+  let Spellings = [CXX11<"vk", "push_constant">];
+  let Args = [];
+  let Subjects = SubjectList<[GlobalVar], ErrorDiag>;
+  let LangOpts = [HLSL];
+  let Documentation = [HLSLVkPushConstantDocs];
+}
+
 def HLSLVkConstantId : InheritableAttr {
   let Spellings = [CXX11<"vk", "constant_id">];
   let Args = [IntArgument<"Id">];
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 1be9a96aa44de..aeced5bc25fb5 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -8838,6 +8838,11 @@ https://github.com/microsoft/hlsl-specs/blob/main/proposals/0011-inline-spirv.md
   }];
 }
 
+def HLSLVkPushConstantDocs : Documentation {
+  let Category = DocCatVariable;
+  let Content = [{ FIXME }];
+}
+
 def AnnotateTypeDocs : Documentation {
   let Category = DocCatType;
   let Heading = "annotate_type";
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 28b03ac4c4676..ed2d54469dde2 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -197,6 +197,7 @@ class SemaHLSL : public SemaBase {
   void handleSemanticAttr(Decl *D, const ParsedAttr &AL);
 
   void handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL);
+  void handleVkPushConstantAttr(Decl *D, const ParsedAttr &AL);
 
   bool CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
   QualType ProcessResourceTypeAttributes(QualType Wrapped);
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 4548af17e37f2..53082bcf78f6a 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -101,6 +101,7 @@ bool Qualifiers::isTargetAddressSpaceSupersetOf(LangAS A, LangAS B,
          (A == LangAS::Default && B == LangAS::hlsl_private) ||
          (A == LangAS::Default && B == LangAS::hlsl_device) ||
          (A == LangAS::Default && B == LangAS::hlsl_input) ||
+         (A == LangAS::Default && B == LangAS::hlsl_push_constant) ||
          // Conversions from target specific address spaces may be legal
          // depending on the target information.
          Ctx.getTargetInfo().isAddressSpaceSupersetOf(A, B);
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index c18b2eafc722c..8448dd3748e28 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -2749,6 +2749,8 @@ std::string Qualifiers::getAddrSpaceAsString(LangAS AS) {
     return "hlsl_device";
   case LangAS::hlsl_input:
     return "hlsl_input";
+  case LangAS::hlsl_push_constant:
+    return "hlsl_push_constant";
   case LangAS::wasm_funcref:
     return "__funcref";
   default:
diff --git a/clang/lib/Basic/TargetInfo.cpp b/clang/lib/Basic/TargetInfo.cpp
index f4d7c1288cc04..4e34181144ad6 100644
--- a/clang/lib/Basic/TargetInfo.cpp
+++ b/clang/lib/Basic/TargetInfo.cpp
@@ -52,6 +52,7 @@ static const LangASMap FakeAddrSpaceMap = {
     15, // hlsl_private
     16, // hlsl_device
     17, // hlsl_input
+    18, // hlsl_push_constant
     20, // wasm_funcref
 };
 
diff --git a/clang/lib/Basic/Targets/AArch64.h b/clang/lib/Basic/Targets/AArch64.h
index 7d0737b2e8df0..22bb49dd6aea8 100644
--- a/clang/lib/Basic/Targets/AArch64.h
+++ b/clang/lib/Basic/Targets/AArch64.h
@@ -48,6 +48,7 @@ static const unsigned ARM64AddrSpaceMap[] = {
     0, // hlsl_private
     0, // hlsl_device
     0, // hlsl_input
+    0, // hlsl_push_constant
     // Wasm address space values for this target are dummy values,
     // as it is only enabled for Wasm targets.
     20, // wasm_funcref
diff --git a/clang/lib/Basic/Targets/AMDGPU.cpp b/clang/lib/Basic/Targets/AMDGPU.cpp
index d4d696b8456b6..993a73a89c9e9 100644
--- a/clang/lib/Basic/Targets/AMDGPU.cpp
+++ b/clang/lib/Basic/Targets/AMDGPU.cpp
@@ -63,6 +63,7 @@ const LangASMap AMDGPUTargetInfo::AMDGPUDefIsGenMap = {
     llvm::AMDGPUAS::PRIVATE_ADDRESS, // hlsl_private
     llvm::AMDGPUAS::GLOBAL_ADDRESS,  // hlsl_device
     llvm::AMDGPUAS::PRIVATE_ADDRESS, // hlsl_input
+    llvm::AMDGPUAS::GLOBAL_ADDRESS,  // hlsl_push_constant
 };
 
 const LangASMap AMDGPUTargetInfo::AMDGPUDefIsPrivMap = {
@@ -91,6 +92,7 @@ const LangASMap AMDGPUTargetInfo::AMDGPUDefIsPrivMap = {
     llvm::AMDGPUAS::PRIVATE_ADDRESS,  // hlsl_private
     llvm::AMDGPUAS::GLOBAL_ADDRESS,   // hlsl_device
     llvm::AMDGPUAS::PRIVATE_ADDRESS,  // hlsl_input
+    llvm::AMDGPUAS::GLOBAL_ADDRESS,   // hlsl_push_constant
 };
 } // namespace targets
 } // namespace clang
diff --git a/clang/lib/Basic/Targets/DirectX.h b/clang/lib/Basic/Targets/DirectX.h
index a21a593365773..c0799a6f7610f 100644
--- a/clang/lib/Basic/Targets/DirectX.h
+++ b/clang/lib/Basic/Targets/DirectX.h
@@ -46,6 +46,7 @@ static const unsigned DirectXAddrSpaceMap[] = {
     0, // hlsl_private
     0, // hlsl_device
     0, // hlsl_input
+    0, // hlsl_push_constant
     // Wasm address space values for this target are dummy values,
     // as it is only enabled for Wasm targets.
     20, // wasm_funcref
diff --git a/clang/lib/Basic/Targets/NVPTX.h b/clang/lib/Basic/Targets/NVPTX.h
index f5c8396f398aa..6338a4f2f9036 100644
--- a/clang/lib/Basic/Targets/NVPTX.h
+++ b/clang/lib/Basic/Targets/NVPTX.h
@@ -50,6 +50,7 @@ static const unsigned NVPTXAddrSpaceMap[] = {
     0, // hlsl_private
     0, // hlsl_device
     0, // hlsl_input
+    0, // hlsl_push_constant
     // Wasm address space values for this target are dummy values,
     // as it is only enabled for Wasm targets.
     20, // wasm_funcref
diff --git a/clang/lib/Basic/Targets/SPIR.h b/clang/lib/Basic/Targets/SPIR.h
index 22b2799518dd0..94449231efb94 100644
--- a/clang/lib/Basic/Targets/SPIR.h
+++ b/clang/lib/Basic/Targets/SPIR.h
@@ -51,6 +51,7 @@ static const unsigned SPIRDefIsPrivMap[] = {
     10, // hlsl_private
     11, // hlsl_device
     7,  // hlsl_input
+    13, // hlsl_push_constant
     // Wasm address space values for this target are dummy values,
     // as it is only enabled for Wasm targets.
     20, // wasm_funcref
@@ -87,6 +88,7 @@ static const unsigned SPIRDefIsGenMap[] = {
     10, // hlsl_private
     11, // hlsl_device
     7,  // hlsl_input
+    13, // hlsl_push_constant
     // Wasm address space values for this target are dummy values,
     // as it is only enabled for Wasm targets.
     20, // wasm_funcref
diff --git a/clang/lib/Basic/Targets/SystemZ.h b/clang/lib/Basic/Targets/SystemZ.h
index 4e15d5af1cde6..4ce515b31a001 100644
--- a/clang/lib/Basic/Targets/SystemZ.h
+++ b/clang/lib/Basic/Targets/SystemZ.h
@@ -46,6 +46,7 @@ static const unsigned ZOSAddressMap[] = {
     0, // hlsl_private
     0, // hlsl_device
     0, // hlsl_input
+    0, // hlsl_push_constant
     0  // wasm_funcref
 };
 
diff --git a/clang/lib/Basic/Targets/TCE.h b/clang/lib/Basic/Targets/TCE.h
index 005cab9819472..161025378c471 100644
--- a/clang/lib/Basic/Targets/TCE.h
+++ b/clang/lib/Basic/Targets/TCE.h
@@ -55,6 +55,7 @@ static const unsigned TCEOpenCLAddrSpaceMap[] = {
     0, // hlsl_private
     0, // hlsl_device
     0, // hlsl_input
+    0, // hlsl_push_constant
     // Wasm address space values for this target are dummy values,
     // as it is only enabled for Wasm targets.
     20, // wasm_funcref
diff --git a/clang/lib/Basic/Targets/WebAssembly.h b/clang/lib/Basic/Targets/WebAssembly.h
index 4de6ce6bb5a21..c8065843aeb42 100644
--- a/clang/lib/Basic/Targets/WebAssembly.h
+++ b/clang/lib/Basic/Targets/WebAssembly.h
@@ -46,6 +46,7 @@ static const unsigned WebAssemblyAddrSpaceMap[] = {
     0,  // hlsl_private
     0,  // hlsl_device
     0,  // hlsl_input
+    0,  // hlsl_push_constant
     20, // wasm_funcref
 };
 
diff --git a/clang/lib/Basic/Targets/X86.h b/clang/lib/Basic/Targets/X86.h
index e7da2622e78b5..7b88ac70e234f 100644
--- a/clang/lib/Basic/Targets/X86.h
+++ b/clang/lib/Basic/Targets/X86.h
@@ -50,6 +50,7 @@ static const unsigned X86AddrSpaceMap[] = {
     0,   // hlsl_private
     0,   // hlsl_device
     0,   // hlsl_input
+    0,   // hlsl_push_constant
     // Wasm address space values for this target are dummy values,
     // as it is only enabled for Wasm targets.
     20, // wasm_funcref
diff --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp
index 0fea57b2e1799..94d9a677661ac 100644
--- a/clang/lib/CodeGen/CodeGenModule.cpp
+++ b/clang/lib/CodeGen/CodeGenModule.cpp
@@ -6042,7 +6042,9 @@ void CodeGenModule::EmitGlobalVarDefinition(const VarDecl *D,
     getCUDARuntime().handleVarRegistration(D, *GV);
   }
 
-  if (LangOpts.HLSL && GetGlobalVarAddressSpace(D) == LangAS::hlsl_input) {
+  if (LangOpts.HLSL &&
+      (GetGlobalVarAddressSpace(D) == LangAS::hlsl_input ||
+       GetGlobalVarAddressSpace(D) == LangAS::hlsl_push_constant)) {
     // HLSL Input variables are considered to be set by the driver/pipeline, but
     // only visible to a single thread/wave.
     GV->setExternallyInitialized(true);
@@ -6098,7 +6100,9 @@ void CodeGenModule::EmitGlobalVarDefinition(const VarDecl *D,
   // HLSL variables in the input address space maps like memory-mapped
   // variables. Even if they are 'static', they are externally initialized and
   // read/write by the hardware/driver/pipeline.
-  if (LangOpts.HLSL && GetGlobalVarAddressSpace(D) == LangAS::hlsl_input)
+  if (LangOpts.HLSL &&
+      (GetGlobalVarAddressSpace(D) == LangAS::hlsl_input ||
+       GetGlobalVarAddressSpace(D) == LangAS::hlsl_push_constant))
     Linkage = llvm::GlobalValue::ExternalLinkage;
 
   GV->setLinkage(Linkage);
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index fc3aabf5741ca..6633c1a3cf226 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -14552,6 +14552,9 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) {
     if (getLangOpts().HLSL &&
         Var->getType().getAddressSpace() == LangAS::hlsl_input)
       return;
+    if (getLangOpts().HLSL &&
+        Var->getType().getAddressSpace() == LangAS::hlsl_push_constant)
+      return;
 
     // C++03 [dcl.init]p9:
     //   If no initializer is specified for an object, and the
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index a9e7b44ac9d73..0396155bd6a9d 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7614,6 +7614,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
   case ParsedAttr::AT_HLSLVkExtBuiltinInput:
     S.HLSL().handleVkExtBuiltinInputAttr(D, AL);
     break;
+  case ParsedAttr::AT_HLSLVkPushConstant:
+    S.HLSL().handleVkPushConstantAttr(D, AL);
+    break;
   case ParsedAttr::AT_HLSLVkConstantId:
     S.HLSL().handleVkConstantIdAttr(D, AL);
     break;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index a06c57b15c585..045f2559c32e5 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1685,6 +1685,11 @@ void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) {
                  HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
 }
 
+void SemaHLSL::handleVkPushConstantAttr(Decl *D, const ParsedAttr &AL) {
+  D->addAttr(::new (getASTContext())
+                 HLSLVkPushConstantAttr(getASTContext(), AL));
+}
+
 void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
   uint32_t Id;
   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
@@ -3846,6 +3851,7 @@ static bool IsDefaultBufferConstantDecl(VarDecl *VD) {
          QT.getAddressSpace() == LangAS::Default &&
          VD->getStorageClass() != SC_Static &&
          !VD->hasAttr<HLSLVkConstantIdAttr>() &&
+         !VD->hasAttr<HLSLVkPushConstantAttr>() &&
          !isInvalidConstantBufferLeafElementType(QT.getTypePtr());
 }
 
@@ -3866,6 +3872,13 @@ void SemaHLSL::deduceAddressSpace(VarDecl *Decl) {
     return;
   }
 
+  if (Decl->hasAttr<HLSLVkPushConstantAttr>()) {
+    LangAS ImplAS = LangAS::hlsl_push_constant;
+    Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
+    Decl->setType(Type);
+    return;
+  }
+
   if (Type->isSamplerT() || Type->isVoidType())
     return;
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 53074ea3b2597..38be5b450f998 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -87,14 +87,15 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   const LLT p10 = LLT::pointer(10, PSize); // Private
   const LLT p11 = LLT::pointer(11, PSize); // StorageBuffer
   const LLT p12 = LLT::pointer(12, PSize); // Uniform
+  const LLT p13 = LLT::pointer(13, PSize); // PushConstant
 
   // TODO: remove copy-pasting here by using concatenation in some way.
   auto allPtrsScalarsAndVectors = {
-      p0,    p1,    p2,    p3,     p4,     p5,    p6,    p7,    p8,
-      p10,   p11,   p12,   s1,     s8,     s16,   s32,   s64,   v2s1,
-      v2s8,  v2s16, v2s32, v2s64,  v3s1,   v3s8,  v3s16, v3s32, v3s64,
-      v4s1,  v4s8,  v4s16, v4s32,  v4s64,  v8s1,  v8s8,  v8s16, v8s32,
-      v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
+      p0,    p1,    p2,    p3,    p4,     p5,     p6,    p7,    p8,
+      p10,   p11,   p12,   p13,   s1,     s8,     s16,   s32,   s64,
+      v2s1,  v2s8,  v2s16, v2s32, v2s64,  v3s1,   v3s8,  v3s16, v3s32,
+      v3s64, v4s1,  v4s8,  v4s16, v4s32,  v4s64,  v8s1,  v8s8,  v8s16,
+      v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
 
   auto allVectors = {v2s1,  v2s8,   v2s16,  v2s32, v2s64, v3s1,  v3s8,
                      v3s16, v3s32,  v3s64,  v4s1,  v4s8,  v4s16, v4s32,
@@ -121,10 +122,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       s16,   s32,   s64,   v2s16, v2s32, v2s64, v3s16,  v3s32,  v3s64,
       v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
 
-  auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1,  p2,  p3,
-                                       p4, p5,  p6,  p7,  p8, p10, p11, p12};
+  auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0,  p1,  p2,  p3, p4,
+                                       p5, p6,  p7,  p8,  p10, p11, p12, p13};
 
-  auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12};
+  auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12, p13};
 
   bool IsExtendedInts =
       ST.canUseExtension(
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 8f2fc01da476f..8202db2380a68 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -290,6 +290,8 @@ addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
     return SPIRV::StorageClass::StorageBuffer;
   case 12:
     return SPIRV::StorageClass::Uniform;
+  case 13:
+    return SPIRV::StorageClass::PushConstant;
   default:
     report_fatal_error("Unknown address space");
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 99d9d403ea70c..32e744c18a5f7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -264,6 +264,8 @@ storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC) {
     return 11;
   case SPIRV::StorageClass::Uniform:
     return 12;
+  case SPIRV::StorageClass::PushConstant:
+    return 13;
   default:
     report_fatal_error("Unable to get address space id");
   }

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to