AlexVlx updated this revision to Diff 122323.
AlexVlx marked 7 inline comments as done.

https://reviews.llvm.org/D39857

Files:
  include/clang/Basic/Attr.td
  lib/CodeGen/TargetInfo.cpp
  lib/Sema/SemaDeclAttr.cpp

Index: lib/Sema/SemaDeclAttr.cpp
===================================================================
--- lib/Sema/SemaDeclAttr.cpp
+++ lib/Sema/SemaDeclAttr.cpp
@@ -5469,16 +5469,33 @@
   }
 }
 
+/// Reject the first argument whose (non-dependent) type is not integral/enum.
+static bool checkAllAreIntegral(Sema &S, const AttributeList &Attr) {
+  for (unsigned I = 0, N = Attr.getNumArgs(); I != N; ++I) {
+    Expr *E = Attr.getArgAsExpr(I);
+    if (!E || E->isTypeDependent() ||
+        E->getType()->isIntegralOrEnumerationType())
+      continue;
+    S.Diag(getAttrLoc(Attr), diag::err_attribute_argument_n_type)
+        << getAttrName(Attr) << I << AANT_ArgumentIntegerConstant
+        << E->getSourceRange();
+    return false;
+  }
+  return true;
+}
+
 static void handleAMDGPUFlatWorkGroupSizeAttr(Sema &S, Decl *D,
                                               const AttributeList &Attr) {
   uint32_t Min = 0;
   Expr *MinExpr = Attr.getArgAsExpr(0);
-  if (!checkUInt32Argument(S, Attr, MinExpr, Min))
+  if (!MinExpr->isValueDependent() &&
+      !checkUInt32Argument(S, Attr, MinExpr, Min))
     return;
 
   uint32_t Max = 0;
   Expr *MaxExpr = Attr.getArgAsExpr(1);
-  if (!checkUInt32Argument(S, Attr, MaxExpr, Max))
+  if (!MaxExpr->isValueDependent() &&
+      !checkUInt32Argument(S, Attr, MaxExpr, Max))
     return;
 
   if (Min == 0 && Max != 0) {
@@ -5493,21 +5510,28 @@
   }
 
   D->addAttr(::new (S.Context)
-             AMDGPUFlatWorkGroupSizeAttr(Attr.getLoc(), S.Context, Min, Max,
-                                         Attr.getAttributeSpellingListIndex()));
+             AMDGPUFlatWorkGroupSizeAttr(
+               Attr.getLoc(), S.Context, MinExpr, MaxExpr,
+               Attr.getAttributeSpellingListIndex()));
 }
 
 static void handleAMDGPUWavesPerEUAttr(Sema &S, Decl *D,
                                        const AttributeList &Attr) {
+  if (!checkAllAreIntegral(S, Attr))
+    return;
+
   uint32_t Min = 0;
   Expr *MinExpr = Attr.getArgAsExpr(0);
-  if (!checkUInt32Argument(S, Attr, MinExpr, Min))
+  if (!MinExpr->isValueDependent() &&
+      !checkUInt32Argument(S, Attr, MinExpr, Min))
     return;
 
   uint32_t Max = 0;
+  Expr *MaxExpr = MinExpr;
   if (Attr.getNumArgs() == 2) {
-    Expr *MaxExpr = Attr.getArgAsExpr(1);
-    if (!checkUInt32Argument(S, Attr, MaxExpr, Max))
+    MaxExpr = Attr.getArgAsExpr(1);
+    if (!MaxExpr->isValueDependent() &&
+        !checkUInt32Argument(S, Attr, MaxExpr, Max))
       return;
   }
 
@@ -5523,31 +5547,39 @@
   }
 
   D->addAttr(::new (S.Context)
-             AMDGPUWavesPerEUAttr(Attr.getLoc(), S.Context, Min, Max,
+             AMDGPUWavesPerEUAttr(Attr.getLoc(), S.Context, MinExpr, MaxExpr,
                                   Attr.getAttributeSpellingListIndex()));
 }
 
 static void handleAMDGPUNumSGPRAttr(Sema &S, Decl *D,
                                     const AttributeList &Attr) {
+  if (!checkAllAreIntegral(S, Attr))
+    return;
+
   uint32_t NumSGPR = 0;
   Expr *NumSGPRExpr = Attr.getArgAsExpr(0);
-  if (!checkUInt32Argument(S, Attr, NumSGPRExpr, NumSGPR))
+  if (!NumSGPRExpr->isValueDependent() &&
+      !checkUInt32Argument(S, Attr, NumSGPRExpr, NumSGPR))
     return;
 
   D->addAttr(::new (S.Context)
-             AMDGPUNumSGPRAttr(Attr.getLoc(), S.Context, NumSGPR,
+             AMDGPUNumSGPRAttr(Attr.getLoc(), S.Context, NumSGPRExpr,
                                Attr.getAttributeSpellingListIndex()));
 }
 
 static void handleAMDGPUNumVGPRAttr(Sema &S, Decl *D,
                                     const AttributeList &Attr) {
+  if (!checkAllAreIntegral(S, Attr))
+    return;
+
   uint32_t NumVGPR = 0;
   Expr *NumVGPRExpr = Attr.getArgAsExpr(0);
-  if (!checkUInt32Argument(S, Attr, NumVGPRExpr, NumVGPR))
+  if (!NumVGPRExpr->isValueDependent() &&
+      !checkUInt32Argument(S, Attr, NumVGPRExpr, NumVGPR))
     return;
 
   D->addAttr(::new (S.Context)
-             AMDGPUNumVGPRAttr(Attr.getLoc(), S.Context, NumVGPR,
+             AMDGPUNumVGPRAttr(Attr.getLoc(), S.Context, NumVGPRExpr,
                                Attr.getAttributeSpellingListIndex()));
 }
 
Index: lib/CodeGen/TargetInfo.cpp
===================================================================
--- lib/CodeGen/TargetInfo.cpp
+++ lib/CodeGen/TargetInfo.cpp
@@ -7661,6 +7661,16 @@
 };
 }
 
+/// Constant-fold \p E to a 32-bit unsigned value (zero if it does not fold).
+static llvm::APSInt getConstexprInt(const Expr *E, const ASTContext &Ctx) {
+  assert(E && "expected an attribute argument expression");
+  llvm::APSInt Folded(32); // Stays zero if EvaluateAsInt fails.
+  E->EvaluateAsInt(Folded, Ctx);
+  // Fix width and signedness so callers can compare results with each other;
+  // APSInt comparisons assert when either differs between operands.
+  return llvm::APSInt(Folded.zextOrTrunc(32), /*isUnsigned=*/true);
+}
+
 void AMDGPUTargetCodeGenInfo::setTargetAttributes(
     const Decl *D, llvm::GlobalValue *GV, CodeGen::CodeGenModule &M,
     ForDefinition_t IsForDefinition) const {
@@ -7676,8 +7686,11 @@
     FD->getAttr<ReqdWorkGroupSizeAttr>() : nullptr;
   const auto *FlatWGS = FD->getAttr<AMDGPUFlatWorkGroupSizeAttr>();
   if (ReqdWGS || FlatWGS) {
-    unsigned Min = FlatWGS ? FlatWGS->getMin() : 0;
-    unsigned Max = FlatWGS ? FlatWGS->getMax() : 0;
+    const ASTContext &Ctx = FD->getASTContext(); // FlatWGS may be null here.
+    unsigned Min =
+        FlatWGS ? getConstexprInt(FlatWGS->getMin(), Ctx).getZExtValue() : 0;
+    unsigned Max = std::max(Min, static_cast<unsigned>(FlatWGS ?
+        getConstexprInt(FlatWGS->getMax(), Ctx).getZExtValue() : 0));
     if (ReqdWGS && Min == 0 && Max == 0)
       Min = Max = ReqdWGS->getXDim() * ReqdWGS->getYDim() * ReqdWGS->getZDim();
 
@@ -7691,8 +7704,11 @@
   }
 
   if (const auto *Attr = FD->getAttr<AMDGPUWavesPerEUAttr>()) {
-    unsigned Min = Attr->getMin();
-    unsigned Max = Attr->getMax();
+    const ASTContext &Ctx = FD->getASTContext();
+    unsigned Min = getConstexprInt(Attr->getMin(), Ctx).getZExtValue();
+    // Max as unsigned: APSInt comparison asserts on width/sign mismatch.
+    unsigned Max = std::max(Min, static_cast<unsigned>(
+        getConstexprInt(Attr->getMax(), Ctx).getZExtValue()));
 
     if (Min != 0) {
       assert((Max == 0 || Min <= Max) && "Min must be less than or equal Max");
@@ -7706,14 +7722,18 @@
   }
 
   if (const auto *Attr = FD->getAttr<AMDGPUNumSGPRAttr>()) {
-    unsigned NumSGPR = Attr->getNumSGPR();
+    const unsigned NumSGPR =
+        getConstexprInt(Attr->getNumSGPR(), FD->getASTContext())
+            .getZExtValue();
 
     if (NumSGPR != 0)
       F->addFnAttr("amdgpu-num-sgpr", llvm::utostr(NumSGPR));
   }
 
   if (const auto *Attr = FD->getAttr<AMDGPUNumVGPRAttr>()) {
-    uint32_t NumVGPR = Attr->getNumVGPR();
+    const unsigned NumVGPR =
+        getConstexprInt(Attr->getNumVGPR(), FD->getASTContext())
+            .getZExtValue();
 
     if (NumVGPR != 0)
       F->addFnAttr("amdgpu-num-vgpr", llvm::utostr(NumVGPR));
Index: include/clang/Basic/Attr.td
===================================================================
--- include/clang/Basic/Attr.td
+++ include/clang/Basic/Attr.td
@@ -1309,30 +1309,34 @@
 
 def AMDGPUFlatWorkGroupSize : InheritableAttr {
   let Spellings = [GNU<"amdgpu_flat_work_group_size">];
-  let Args = [UnsignedArgument<"Min">, UnsignedArgument<"Max">];
+  let Args = [ExprArgument<"Min">, ExprArgument<"Max">]; // Both required.
   let Documentation = [AMDGPUFlatWorkGroupSizeDocs];
   let Subjects = SubjectList<[Function], ErrorDiag, "ExpectedKernelFunction">;
+  let TemplateDependent = 1; // Argument exprs are instantiated with templates.
 }
 
 def AMDGPUWavesPerEU : InheritableAttr {
   let Spellings = [GNU<"amdgpu_waves_per_eu">];
-  let Args = [UnsignedArgument<"Min">, UnsignedArgument<"Max", 1>];
+  let Args = [ExprArgument<"Min">, ExprArgument<"Max", 1>]; // Max optional.
   let Documentation = [AMDGPUWavesPerEUDocs];
   let Subjects = SubjectList<[Function], ErrorDiag, "ExpectedKernelFunction">;
+  let TemplateDependent = 1; // Argument exprs are instantiated with templates.
 }
 
 def AMDGPUNumSGPR : InheritableAttr {
   let Spellings = [GNU<"amdgpu_num_sgpr">];
-  let Args = [UnsignedArgument<"NumSGPR">];
+  let Args = [ExprArgument<"NumSGPR">];
   let Documentation = [AMDGPUNumSGPRNumVGPRDocs];
   let Subjects = SubjectList<[Function], ErrorDiag, "ExpectedKernelFunction">;
+  let TemplateDependent = 1; // Argument expr is instantiated with templates.
 }
 
 def AMDGPUNumVGPR : InheritableAttr {
   let Spellings = [GNU<"amdgpu_num_vgpr">];
-  let Args = [UnsignedArgument<"NumVGPR">];
+  let Args = [ExprArgument<"NumVGPR">];
   let Documentation = [AMDGPUNumSGPRNumVGPRDocs];
   let Subjects = SubjectList<[Function], ErrorDiag, "ExpectedKernelFunction">;
+  let TemplateDependent = 1; // Argument expr is instantiated with templates.
 }
 
 def NoSplitStack : InheritableAttr {
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to