Author: Chris Bieneman
Date: 2022-03-29T17:17:19-05:00
New Revision: 94189b42cc51b5fa8355957a976f0d8b4f8c312b

URL: 
https://github.com/llvm/llvm-project/commit/94189b42cc51b5fa8355957a976f0d8b4f8c312b
DIFF: 
https://github.com/llvm/llvm-project/commit/94189b42cc51b5fa8355957a976f0d8b4f8c312b.diff

LOG: [HLSL] Fix MSFT Attribute parsing, add numthreads

HLSL uses Microsoft-style attributes `[attr]`, which clang mostly
ignores. For HLSL we need to handle known Microsoft attributes; to keep
C/C++ behavior unchanged, unknown attributes are still ignored.

To utilize this new code path, this change adds the HLSL `numthreads`
attribute.

Reviewed By: rnk

Differential Revision: https://reviews.llvm.org/D122627

Added: 
    clang/test/SemaHLSL/lit.local.cfg
    clang/test/SemaHLSL/num_threads.hlsl

Modified: 
    clang/include/clang/Basic/Attr.td
    clang/include/clang/Basic/AttrDocs.td
    clang/include/clang/Basic/DiagnosticSemaKinds.td
    clang/include/clang/Parse/Parser.h
    clang/lib/Parse/ParseDeclCXX.cpp
    clang/lib/Sema/SemaDecl.cpp
    clang/lib/Sema/SemaDeclAttr.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index f7ef3bf42ec2f..408ea11388c07 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -336,6 +336,8 @@ def ObjCAutoRefCount : LangOpt<"ObjCAutoRefCount">;
 def ObjCNonFragileRuntime
     : LangOpt<"", "LangOpts.ObjCRuntime.allowsClassStubs()">;
 
+def HLSL : LangOpt<"HLSL">;
+
 // Language option for CMSE extensions
 def Cmse : LangOpt<"Cmse">;
 
@@ -3937,3 +3939,11 @@ def Error : InheritableAttr {
   let Subjects = SubjectList<[Function], ErrorDiag>;
   let Documentation = [ErrorAttrDocs];
 }
+
+def HLSLNumThreads : InheritableAttr {
+  let Spellings = [Microsoft<"numthreads">];
+  let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">];
+  let Subjects = SubjectList<[Function]>;
+  let LangOpts = [HLSL];
+  let Documentation = [NumThreadsDocs];
+}

diff  --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index ebf0725a0a392..89db454f7dac4 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -6368,3 +6368,14 @@ flag.
 .. _Return-Oriented Programming: https://en.wikipedia.org/wiki/Return-oriented_programming
   }];
 }
+
+def NumThreadsDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+The ``numthreads`` attribute applies to HLSL shaders where explicit thread counts
+are required. The ``X``, ``Y``, and ``Z`` values provided to the attribute
+give the dimensions of a thread group, which executes ``X * Y * Z`` threads.
+
+The full documentation is available here: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-attributes-numthreads
+  }];
+}

diff  --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 172a10a65c8c0..d3055fed20828 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11564,4 +11564,11 @@ def err_std_source_location_impl_not_found : Error<
   "'std::source_location::__impl' was not found; it must be defined before '__builtin_source_location' is called">;
 def err_std_source_location_impl_malformed : Error<
   "'std::source_location::__impl' must be standard-layout and have only two 'const char *' fields '_M_file_name' and '_M_function_name', and two integral fields '_M_line' and '_M_column'">;
+
+// HLSL Diagnostics
+def err_hlsl_attr_unsupported_in_stage : Error<"attribute %0 is unsupported in %select{Pixel|Vertex|Geometry|Hull|Domain|Compute|Library|RayGeneration|Intersection|AnyHit|ClosestHit|Miss|Callable|Mesh|Amplification|Invalid}1 shaders, requires %2">;
+
+def err_hlsl_numthreads_argument_oor : Error<"argument '%select{X|Y|Z}0' to numthreads attribute cannot exceed %1">;
+def err_hlsl_numthreads_invalid : Error<"total number of threads cannot exceed %0">;
+
 } // end of sema component.

diff  --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
index 1af53a151f8c9..3241931345297 100644
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -2783,7 +2783,8 @@ class Parser : public CodeCompletionHandler {
       const IdentifierInfo *EnclosingScope = nullptr);
 
   void MaybeParseMicrosoftAttributes(ParsedAttributes &Attrs) {
-    if (getLangOpts().MicrosoftExt && Tok.is(tok::l_square)) {
+    if ((getLangOpts().MicrosoftExt || getLangOpts().HLSL) &&
+        Tok.is(tok::l_square)) {
       ParsedAttributes AttrsWithRange(AttrFactory);
       ParseMicrosoftAttributes(AttrsWithRange);
       Attrs.takeAllFrom(AttrsWithRange);

diff  --git a/clang/lib/Parse/ParseDeclCXX.cpp b/clang/lib/Parse/ParseDeclCXX.cpp
index 342ee896bee18..553dcba94fed6 100644
--- a/clang/lib/Parse/ParseDeclCXX.cpp
+++ b/clang/lib/Parse/ParseDeclCXX.cpp
@@ -4302,10 +4302,18 @@ bool Parser::ParseCXX11AttributeArgs(IdentifierInfo *AttrName,
   ParsedAttr::Syntax Syntax =
       LO.CPlusPlus ? ParsedAttr::AS_CXX11 : ParsedAttr::AS_C2x;
 
+  // Try parsing Microsoft attributes
+  if (getLangOpts().MicrosoftExt || getLangOpts().HLSL) {
+    if (hasAttribute(AttrSyntax::Microsoft, ScopeName, AttrName,
+                     getTargetInfo(), getLangOpts()))
+      Syntax = ParsedAttr::AS_Microsoft;
+  }
+
   // If the attribute isn't known, we will not attempt to parse any
   // arguments.
-  if (!hasAttribute(LO.CPlusPlus ? AttrSyntax::CXX : AttrSyntax::C, ScopeName,
+  if (Syntax != ParsedAttr::AS_Microsoft &&
+      !hasAttribute(LO.CPlusPlus ? AttrSyntax::CXX : AttrSyntax::C, ScopeName,
                     AttrName, getTargetInfo(), getLangOpts())) {
     // Eat the left paren, then skip to the ending right paren.
     ConsumeParen();
     SkipUntil(tok::r_paren);
@@ -4688,8 +4696,17 @@ void Parser::ParseMicrosoftAttributes(ParsedAttributes &Attrs) {
         break;
       if (Tok.getIdentifierInfo()->getName() == "uuid")
         ParseMicrosoftUuidAttributeArgs(Attrs);
-      else
+      else {
+        IdentifierInfo *II = Tok.getIdentifierInfo();
+        SourceLocation NameLoc = Tok.getLocation();
         ConsumeToken();
+        if (Tok.is(tok::l_paren)) {
+          CachedTokens OpenMPTokens;
+          ParseCXX11AttributeArgs(II, NameLoc, Attrs, &EndLoc, nullptr,
+                                  SourceLocation(), OpenMPTokens);
+          ReplayOpenMPAttributeTokens(OpenMPTokens);
+        } // FIXME: handle attributes that don't have arguments
+      }
     }
 
     T.consumeClose();

diff  --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 298d4fc17617b..b913f805bc877 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -11323,6 +11323,11 @@ void Sema::CheckMain(FunctionDecl* FD, const DeclSpec& DS) {
     return;
   }
 
+  // Functions named main in HLSL are default entry points, but are not
+  // required to conform to any specific signature.
+  if (getLangOpts().HLSL)
+    return;
+
   QualType T = FD->getType();
   assert(T->isFunctionType() && "function decl is not of function type");
   const FunctionType* FT = T->castAs<FunctionType>();

diff  --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 6788c92ab9828..87e16635f3021 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -24,6 +24,7 @@
 #include "clang/AST/Type.h"
 #include "clang/Basic/CharInfo.h"
 #include "clang/Basic/DarwinSDKInfo.h"
+#include "clang/Basic/LangOptions.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Basic/SourceManager.h"
 #include "clang/Basic/TargetBuiltins.h"
@@ -6836,6 +6837,69 @@ static void handleUuidAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
     D->addAttr(UA);
 }
 
+/// Handle [numthreads(X, Y, Z)]: diagnose use outside compute, mesh,
+/// amplification and library shaders, then check each dimension and the
+/// total X * Y * Z against the limits for the target shader model.
+static void handleHLSLNumThreadsAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
+  using llvm::Triple;
+  Triple Target = S.Context.getTargetInfo().getTriple();
+  if (!llvm::is_contained({Triple::Compute, Triple::Mesh, Triple::Amplification,
+                           Triple::Library},
+                          Target.getEnvironment())) {
+    // Index of this stage in err_hlsl_attr_unsupported_in_stage's %select
+    // list; assumes that list mirrors Triple's environment order from Pixel.
+    uint32_t Pipeline =
+        (uint32_t)Target.getEnvironment() -
+        (uint32_t)Triple::Pixel;
+    S.Diag(AL.getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
+        << AL << Pipeline << "Compute, Amplification, Mesh or Library";
+    return;
+  }
+
+  llvm::VersionTuple SMVersion = Target.getOSVersion();
+  uint32_t ZMax = 1024;
+  uint32_t ThreadMax = 1024;
+  if (SMVersion.getMajor() <= 4) {
+    ZMax = 1;
+    ThreadMax = 768;
+  } else if (SMVersion.getMajor() == 5) {
+    ZMax = 64;
+    ThreadMax = 1024;
+  }
+
+  uint32_t X;
+  if (!checkUInt32Argument(S, AL, AL.getArgAsExpr(0), X))
+    return;
+  if (X > 1024) {
+    S.Diag(AL.getArgAsExpr(0)->getExprLoc(),
+           diag::err_hlsl_numthreads_argument_oor) << 0 << 1024;
+    return;
+  }
+  uint32_t Y;
+  if (!checkUInt32Argument(S, AL, AL.getArgAsExpr(1), Y))
+    return;
+  if (Y > 1024) {
+    S.Diag(AL.getArgAsExpr(1)->getExprLoc(),
+           diag::err_hlsl_numthreads_argument_oor) << 1 << 1024;
+    return;
+  }
+  uint32_t Z;
+  if (!checkUInt32Argument(S, AL, AL.getArgAsExpr(2), Z))
+    return;
+  if (Z > ZMax) {
+    S.Diag(AL.getArgAsExpr(2)->getExprLoc(),
+           diag::err_hlsl_numthreads_argument_oor) << 2 << ZMax;
+    return;
+  }
+
+  if (X * Y * Z > ThreadMax) {
+    S.Diag(AL.getLoc(), diag::err_hlsl_numthreads_invalid) << ThreadMax;
+    return;
+  }
+
+  D->addAttr(::new (S.Context) HLSLNumThreadsAttr(S.Context, AL, X, Y, Z));
+}
+
 static void handleMSInheritanceAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   if (!S.LangOpts.CPlusPlus) {
     S.Diag(AL.getLoc(), diag::err_attribute_not_supported_in_lang)
@@ -8697,6 +8761,11 @@ static void ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D,
   case ParsedAttr::AT_Thread:
     handleDeclspecThreadAttr(S, D, AL);
     break;
+
+  // HLSL attributes:
+  case ParsedAttr::AT_HLSLNumThreads:
+    handleHLSLNumThreadsAttr(S, D, AL);
+    break;
 
   case ParsedAttr::AT_AbiTag:
     handleAbiTagAttr(S, D, AL);

diff  --git a/clang/test/SemaHLSL/lit.local.cfg b/clang/test/SemaHLSL/lit.local.cfg
new file mode 100644
index 0000000000000..d637d6d68030d
--- /dev/null
+++ b/clang/test/SemaHLSL/lit.local.cfg
@@ -0,0 +1 @@
+config.suffixes = ['.hlsl']

diff  --git a/clang/test/SemaHLSL/num_threads.hlsl b/clang/test/SemaHLSL/num_threads.hlsl
new file mode 100644
index 0000000000000..46e4fa131aa75
--- /dev/null
+++ b/clang/test/SemaHLSL/num_threads.hlsl
@@ -0,0 +1,47 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -x hlsl -ast-dump -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-amplification -x hlsl -ast-dump -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-pixel -x hlsl -ast-dump -o - %s -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-vertex -x hlsl -ast-dump -o - %s -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-hull -x hlsl -ast-dump -o - %s -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-domain -x hlsl -ast-dump -o - %s -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s -DFAIL -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel5.0-compute -x hlsl -ast-dump -o - %s -DFAIL -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel4.0-compute -x hlsl -ast-dump -o - %s -DFAIL -verify
+
+#if __SHADER_TARGET_STAGE == __SHADER_STAGE_COMPUTE || __SHADER_TARGET_STAGE == __SHADER_STAGE_MESH || __SHADER_TARGET_STAGE == __SHADER_STAGE_AMPLIFICATION || __SHADER_TARGET_STAGE == __SHADER_STAGE_LIBRARY
+#ifdef FAIL
+#if __SHADER_TARGET_MAJOR == 6
+// expected-error@+1 {{'numthreads' attribute requires an integer constant}}
+[numthreads("1",2,3)]
+// expected-error@+1 {{argument 'X' to numthreads attribute cannot exceed 1024}}
+[numthreads(-1,2,3)]
+// expected-error@+1 {{argument 'Y' to numthreads attribute cannot exceed 1024}}
+[numthreads(1,-2,3)]
+// expected-error@+1 {{argument 'Z' to numthreads attribute cannot exceed 1024}}
+[numthreads(1,2,-3)]
+// expected-error@+1 {{total number of threads cannot exceed 1024}}
+[numthreads(1024,1024,1024)]
+#elif __SHADER_TARGET_MAJOR == 5
+// expected-error@+1 {{argument 'Z' to numthreads attribute cannot exceed 64}}
+[numthreads(1,2,68)]
+#else
+// expected-error@+1 {{argument 'Z' to numthreads attribute cannot exceed 1}}
+[numthreads(1,2,2)]
+// expected-error@+1 {{total number of threads cannot exceed 768}}
+[numthreads(1024,1,1)]
+#endif
+#endif
+// CHECK: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} <line:{{[0-9]+}}:2, col:18> 1 2 1
+[numthreads(1,2,1)]
+int entry() {
+ return 1;
+}
+#else
+// expected-error-re@+1 {{attribute 'numthreads' is unsupported in {{[A-Za-z]+}} shaders, requires Compute, Amplification, Mesh or Library}}
+[numthreads(1,1,1)]
+int main() {
+ return 1;
+}
+#endif


        
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to