https://github.com/SamTebbs33 created 
https://github.com/llvm/llvm-project/pull/74064

This PR adds a warning that's emitted when a streaming builtin (i.e. one that
is neither non-streaming nor streaming-compatible) is called from a
non-streaming function.

Uses work by Kerry McLaughlin.

>From f6a990a000b555d7f8ef0b2a99e3fea98420e899 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.te...@arm.com>
Date: Thu, 30 Nov 2023 13:42:50 +0000
Subject: [PATCH] [AArch64][SME] Warn when using a streaming builtin from a
 non-streaming function

This PR adds a warning that's emitted when a streaming builtin (i.e. one that
is neither non-streaming nor streaming-compatible) is called from a
non-streaming function.

Uses work by Kerry McLaughlin.
---
 clang/include/clang/Basic/CMakeLists.txt      |   6 +
 .../clang/Basic/DiagnosticSemaKinds.td        |   3 +
 clang/include/clang/Sema/Sema.h               |   3 +
 clang/lib/Sema/SemaChecking.cpp               | 191 ++++++++++++++++++
 .../Sema/aarch64-incompat-sm-builtin-calls.c  |  21 ++
 clang/utils/TableGen/SveEmitter.cpp           |  68 +++++++
 clang/utils/TableGen/TableGen.cpp             |   9 +
 clang/utils/TableGen/TableGenBackends.h       |   1 +
 8 files changed, 302 insertions(+)

diff --git a/clang/include/clang/Basic/CMakeLists.txt 
b/clang/include/clang/Basic/CMakeLists.txt
index 085e316fcc671df..bdd72d1d63c431b 100644
--- a/clang/include/clang/Basic/CMakeLists.txt
+++ b/clang/include/clang/Basic/CMakeLists.txt
@@ -97,6 +97,12 @@ clang_tablegen(arm_sme_builtin_cg.inc 
-gen-arm-sme-builtin-codegen
 clang_tablegen(arm_sme_sema_rangechecks.inc -gen-arm-sme-sema-rangechecks
   SOURCE arm_sme.td
   TARGET ClangARMSmeSemaRangeChecks)
+clang_tablegen(arm_sme_streaming_attrs.inc -gen-arm-sme-streaming-attrs
+  SOURCE arm_sme.td
+  TARGET ClangARMSmeStreamingAttrs)
+clang_tablegen(arm_sme_builtins_za_state.inc -gen-arm-sme-builtin-za-state
+  SOURCE arm_sme.td
+  TARGET ClangARMSmeBuiltinsZAState)
 clang_tablegen(arm_cde_builtins.inc -gen-arm-cde-builtin-def
   SOURCE arm_cde.td
   TARGET ClangARMCdeBuiltinsDef)
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td 
b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 6dfb2d7195203a3..c7036fc881a13e1 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -3151,6 +3151,9 @@ def err_attribute_arm_feature_sve_bits_unsupported : 
Error<
 def warn_attribute_arm_sm_incompat_builtin : Warning<
   "builtin call has undefined behaviour when called from a %0 function">,
   InGroup<DiagGroup<"undefined-arm-streaming">>;
+def warn_attribute_arm_za_builtin_no_za_state : Warning<
+  "builtin call is not valid when calling from a function without active ZA 
state">,
+  InGroup<DiagGroup<"undefined-arm-za">>;
 def err_sve_vector_in_non_sve_target : Error<
   "SVE vector type %0 cannot be used in a target without sve">;
 def err_attribute_riscv_rvv_bits_unsupported : Error<
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 6de1a098e067a38..c13c6942c219700 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -13845,7 +13845,10 @@ class Sema final {
   bool CheckNeonBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
                                     CallExpr *TheCall);
   bool CheckMVEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
+  bool ParseSVEImmChecks(CallExpr *TheCall,
+                         SmallVector<std::tuple<int, int, int>, 3> &ImmChecks);
   bool CheckSVEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
+  bool CheckSMEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
   bool CheckCDEBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
                                    CallExpr *TheCall);
   bool CheckARMCoprocessorImmediate(const TargetInfo &TI, const Expr 
*CoprocArg,
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 77c8334f3ca25d3..f27eb8ad95cc703 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2995,6 +2995,134 @@ static QualType getNeonEltType(NeonTypeFlags Flags, 
ASTContext &Context,
 
 enum ArmStreamingType { ArmNonStreaming, ArmStreaming, ArmStreamingCompatible 
};
 
+bool Sema::ParseSVEImmChecks(
+    CallExpr *TheCall, SmallVector<std::tuple<int, int, int>, 3> &ImmChecks) {
+  // Perform all the immediate checks for this builtin call.
+  bool HasError = false;
+  for (auto &I : ImmChecks) {
+    int ArgNum, CheckTy, ElementSizeInBits;
+    std::tie(ArgNum, CheckTy, ElementSizeInBits) = I;
+
+    typedef bool (*OptionSetCheckFnTy)(int64_t Value);
+
+    // Function that checks whether the operand (ArgNum) is an immediate
+    // that is one of the predefined values.
+    auto CheckImmediateInSet = [&](OptionSetCheckFnTy CheckImm,
+                                   int ErrDiag) -> bool {
+      // We can't check the value of a dependent argument.
+      Expr *Arg = TheCall->getArg(ArgNum);
+      if (Arg->isTypeDependent() || Arg->isValueDependent())
+        return false;
+
+      // Check constant-ness first.
+      llvm::APSInt Imm;
+      if (SemaBuiltinConstantArg(TheCall, ArgNum, Imm))
+        return true;
+
+      if (!CheckImm(Imm.getSExtValue()))
+        return Diag(TheCall->getBeginLoc(), ErrDiag) << Arg->getSourceRange();
+      return false;
+    };
+
+    switch ((SVETypeFlags::ImmCheckType)CheckTy) {
+    case SVETypeFlags::ImmCheck0_31:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 31))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck0_13:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 13))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck1_16:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 1, 16))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck0_7:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 7))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheckExtract:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
+                                      (2048 / ElementSizeInBits) - 1))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheckShiftRight:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 1, ElementSizeInBits))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheckShiftRightNarrow:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 1,
+                                      ElementSizeInBits / 2))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheckShiftLeft:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
+                                      ElementSizeInBits - 1))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheckLaneIndex:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
+                                      (128 / (1 * ElementSizeInBits)) - 1))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheckLaneIndexCompRotate:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
+                                      (128 / (2 * ElementSizeInBits)) - 1))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheckLaneIndexDot:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
+                                      (128 / (4 * ElementSizeInBits)) - 1))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheckComplexRot90_270:
+      if (CheckImmediateInSet([](int64_t V) { return V == 90 || V == 270; },
+                              diag::err_rotation_argument_to_cadd))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheckComplexRotAll90:
+      if (CheckImmediateInSet(
+              [](int64_t V) {
+                return V == 0 || V == 90 || V == 180 || V == 270;
+              },
+              diag::err_rotation_argument_to_cmla))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck0_1:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 1))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck0_2:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 2))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck0_3:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 3))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck0_0:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 0))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck0_15:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 15))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck0_255:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 255))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck2_4_Mul2:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 2, 4) ||
+          SemaBuiltinConstantArgMultiple(TheCall, ArgNum, 2))
+        HasError = true;
+      break;
+    }
+  }
+
+  return HasError;
+}
+
 static ArmStreamingType getArmStreamingFnType(const FunctionDecl *FD) {
   if (FD->hasAttr<ArmLocallyStreamingAttr>())
     return ArmStreaming;
@@ -3023,6 +3151,66 @@ static void checkArmStreamingBuiltin(Sema &S, CallExpr 
*TheCall,
         << TheCall->getSourceRange() << "streaming compatible";
     return;
   }
+
+  if (FnType == ArmNonStreaming && BuiltinType == ArmStreaming) {
+    S.Diag(TheCall->getBeginLoc(), 
diag::warn_attribute_arm_sm_incompat_builtin)
+        << TheCall->getSourceRange() << "non-streaming";
+  }
+}
+
+static bool hasSMEZAState(const FunctionDecl *FD) {
+  if (FD->hasAttr<ArmNewZAAttr>())
+    return true;
+  if (const auto *T = FD->getType()->getAs<FunctionProtoType>())
+    if (T->getAArch64SMEAttributes() & FunctionType::SME_PStateZASharedMask)
+      return true;
+  return false;
+}
+
+static bool hasSMEZAState(unsigned BuiltinID) {
+  switch (BuiltinID) {
+  default:
+    return false;
+#define GET_SME_BUILTIN_HAS_ZA_STATE
+#include "clang/Basic/arm_sme_builtins_za_state.inc"
+#undef GET_SME_BUILTIN_HAS_ZA_STATE
+  }
+}
+
+bool Sema::CheckSMEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
+  if (const FunctionDecl *FD = getCurFunctionDecl()) {
+    bool debug = FD->getDeclName().getAsString() == "incompat_sve_sm";
+    std::optional<ArmStreamingType> BuiltinType;
+
+    switch (BuiltinID) {
+    default:
+      break;
+#define GET_SME_STREAMING_ATTRS
+#include "clang/Basic/arm_sme_streaming_attrs.inc"
+#undef GET_SME_STREAMING_ATTRS
+    }
+
+    if (BuiltinType)
+      checkArmStreamingBuiltin(*this, TheCall, FD, *BuiltinType);
+
+    if (hasSMEZAState(BuiltinID) && !hasSMEZAState(FD))
+      Diag(TheCall->getBeginLoc(),
+           diag::warn_attribute_arm_za_builtin_no_za_state)
+          << TheCall->getSourceRange();
+  }
+
+  // Range check SME intrinsics that take immediate values.
+  SmallVector<std::tuple<int, int, int>, 3> ImmChecks;
+
+  switch (BuiltinID) {
+  default:
+    return false;
+#define GET_SME_IMMEDIATE_CHECK
+#include "clang/Basic/arm_sme_sema_rangechecks.inc"
+#undef GET_SME_IMMEDIATE_CHECK
+  }
+
+  return ParseSVEImmChecks(TheCall, ImmChecks);
 }
 
 bool Sema::CheckSVEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
@@ -3559,6 +3747,9 @@ bool Sema::CheckAArch64BuiltinFunctionCall(const 
TargetInfo &TI,
   if (CheckSVEBuiltinFunctionCall(BuiltinID, TheCall))
     return true;
 
+  if (CheckSMEBuiltinFunctionCall(BuiltinID, TheCall))
+    return true;
+
   // For intrinsics which take an immediate value as part of the instruction,
   // range check them here.
   unsigned i = 0, l = 0, u = 0;
diff --git a/clang/test/Sema/aarch64-incompat-sm-builtin-calls.c 
b/clang/test/Sema/aarch64-incompat-sm-builtin-calls.c
index e77e09c4435188d..8f33075c7a9b51f 100644
--- a/clang/test/Sema/aarch64-incompat-sm-builtin-calls.c
+++ b/clang/test/Sema/aarch64-incompat-sm-builtin-calls.c
@@ -5,6 +5,7 @@
 // REQUIRES: aarch64-registered-target
 
 #include "arm_neon.h"
+#include "arm_sme_draft_spec_subject_to_change.h"
 
 int16x8_t incompat_neon_sm(int16x8_t splat) __arm_streaming {
   // expected-warning@+1 {{builtin call has undefined behaviour when called 
from a streaming function}}
@@ -20,3 +21,23 @@ int16x8_t incompat_neon_smc(int16x8_t splat) 
__arm_streaming_compatible {
   // expected-warning@+1 {{builtin call has undefined behaviour when called 
from a streaming compatible function}}
   return (int16x8_t)__builtin_neon_vqaddq_v((int8x16_t)splat, 
(int8x16_t)splat, 33);
 }
+
+void incompat_sme_norm(svbool_t pg, void const *ptr) __arm_shared_za {
+  // expected-warning@+1 {{builtin call has undefined behaviour when called 
from a non-streaming function}}
+  return __builtin_sme_svld1_hor_za128(0, 0, pg, ptr);
+}
+
+void incompat_sme_smc(svbool_t pg, void const *ptr) __arm_streaming_compatible 
__arm_shared_za {
+  // expected-warning@+1 {{builtin call has undefined behaviour when called 
from a streaming compatible function}}
+  return __builtin_sme_svld1_hor_za128(0, 0, pg, ptr);
+}
+
+void incompat_sme_sm(svbool_t pn, svbool_t pm, svfloat32_t zn, svfloat32_t zm) 
__arm_shared_za {
+  // expected-warning@+1 {{builtin call has undefined behaviour when called 
from a non-streaming function}}
+  svmops_za32_f32_m(0, pn, pm, zn, zm);
+}
+
+svbool_t streaming_caller_ptrue(void) __arm_streaming {
+  // expected-no-warning
+  return svand_z(svptrue_b16(), svptrue_pat_b16(SV_ALL), 
svptrue_pat_b16(SV_VL4));
+}
diff --git a/clang/utils/TableGen/SveEmitter.cpp 
b/clang/utils/TableGen/SveEmitter.cpp
index b380bd9dfe6643a..cb02ef09e130eb3 100644
--- a/clang/utils/TableGen/SveEmitter.cpp
+++ b/clang/utils/TableGen/SveEmitter.cpp
@@ -378,6 +378,9 @@ class SVEEmitter {
   /// Emit all the information needed to map builtin -> LLVM IR intrinsic.
   void createSMECodeGenMap(raw_ostream &o);
 
+  /// Create a table for a builtin's requirement for PSTATE.SM.
+  void createStreamingAttrs(raw_ostream &o, ACLEKind Kind);
+
   /// Emit all the range checks for the immediates.
   void createSMERangeChecks(raw_ostream &o);
 
@@ -1369,6 +1372,12 @@ void SVEEmitter::createHeader(raw_ostream &OS) {
   OS << "#define __aio static __inline__ __attribute__((__always_inline__, "
         "__nodebug__, __overloadable__))\n\n";
 
+  OS << "#ifdef __ARM_FEATURE_SME\n";
+  OS << "#define __asc __attribute__((arm_streaming_compatible))\n";
+  OS << "#else\n";
+  OS << "#define __asc\n";
+  OS << "#endif\n\n";
+
   // Add reinterpret functions.
   for (auto [N, Suffix] :
        std::initializer_list<std::pair<unsigned, const char *>>{
@@ -1688,6 +1697,61 @@ void SVEEmitter::createSMERangeChecks(raw_ostream &OS) {
   OS << "#endif\n\n";
 }
 
+void SVEEmitter::createStreamingAttrs(raw_ostream &OS, ACLEKind Kind) {
+  std::vector<Record *> RV = Records.getAllDerivedDefinitions("Inst");
+  SmallVector<std::unique_ptr<Intrinsic>, 128> Defs;
+  for (auto *R : RV)
+    createIntrinsic(R, Defs);
+
+  // The mappings must be sorted based on BuiltinID.
+  llvm::sort(Defs, [](const std::unique_ptr<Intrinsic> &A,
+                      const std::unique_ptr<Intrinsic> &B) {
+    return A->getMangledName() < B->getMangledName();
+  });
+
+  switch (Kind) {
+  case ACLEKind::SME:
+    OS << "#ifdef GET_SME_STREAMING_ATTRS\n";
+    break;
+  case ACLEKind::SVE:
+    OS << "#ifdef GET_SVE_STREAMING_ATTRS\n";
+    break;
+  }
+
+  // Ensure these are only emitted once.
+  std::set<std::string> Emitted;
+
+  uint64_t IsStreamingFlag = getEnumValueForFlag("IsStreaming");
+  uint64_t IsStreamingCompatibleFlag =
+      getEnumValueForFlag("IsStreamingCompatible");
+  for (auto &Def : Defs) {
+    if (Emitted.find(Def->getMangledName()) != Emitted.end())
+      continue;
+
+    switch (Kind) {
+    case ACLEKind::SME:
+      OS << "case SME::BI__builtin_sme_";
+      break;
+    case ACLEKind::SVE:
+      OS << "case SVE::BI__builtin_sve_";
+      break;
+    }
+    OS << Def->getMangledName() << ":\n";
+
+    if (Def->isFlagSet(IsStreamingFlag))
+      OS << "  BuiltinType = ArmStreaming;\n";
+    else if (Def->isFlagSet(IsStreamingCompatibleFlag))
+      OS << "  BuiltinType = ArmStreamingCompatible;\n";
+    else
+      OS << "  BuiltinType = ArmNonStreaming;\n";
+    OS << "  break;\n";
+
+    Emitted.insert(Def->getMangledName());
+  }
+
+  OS << "#endif\n\n";
+}
+
 namespace clang {
 void EmitSveHeader(RecordKeeper &Records, raw_ostream &OS) {
   SVEEmitter(Records).createHeader(OS);
@@ -1724,4 +1788,8 @@ void EmitSmeBuiltinCG(RecordKeeper &Records, raw_ostream 
&OS) {
 void EmitSmeRangeChecks(RecordKeeper &Records, raw_ostream &OS) {
   SVEEmitter(Records).createSMERangeChecks(OS);
 }
+
+void EmitSmeStreamingAttrs(RecordKeeper &Records, raw_ostream &OS) {
+  SVEEmitter(Records).createStreamingAttrs(OS, ACLEKind::SME);
+}
 } // End namespace clang
diff --git a/clang/utils/TableGen/TableGen.cpp 
b/clang/utils/TableGen/TableGen.cpp
index 7efb6c731d3e5ee..9ba2fb07f1380da 100644
--- a/clang/utils/TableGen/TableGen.cpp
+++ b/clang/utils/TableGen/TableGen.cpp
@@ -89,6 +89,8 @@ enum ActionType {
   GenArmSmeBuiltins,
   GenArmSmeBuiltinCG,
   GenArmSmeRangeChecks,
+  GenArmSmeStreamingAttrs,
+  GenArmSmeBuiltinZAState,
   GenArmCdeHeader,
   GenArmCdeBuiltinDef,
   GenArmCdeBuiltinSema,
@@ -251,6 +253,10 @@ cl::opt<ActionType> Action(
                    "Generate arm_sme_builtin_cg_map.inc for clang"),
         clEnumValN(GenArmSmeRangeChecks, "gen-arm-sme-sema-rangechecks",
                    "Generate arm_sme_sema_rangechecks.inc for clang"),
+        clEnumValN(GenArmSmeStreamingAttrs, "gen-arm-sme-streaming-attrs",
+                   "Generate arm_sme_streaming_attrs.inc for clang"),
+        clEnumValN(GenArmSmeBuiltinZAState, "gen-arm-sme-builtin-za-state",
+                   "Generate arm_sme_builtins_za_state.inc for clang"),
         clEnumValN(GenArmMveHeader, "gen-arm-mve-header",
                    "Generate arm_mve.h for clang"),
         clEnumValN(GenArmMveBuiltinDef, "gen-arm-mve-builtin-def",
@@ -500,6 +506,9 @@ bool ClangTableGenMain(raw_ostream &OS, RecordKeeper 
&Records) {
   case GenArmSmeRangeChecks:
     EmitSmeRangeChecks(Records, OS);
     break;
+  case GenArmSmeStreamingAttrs:
+    EmitSmeStreamingAttrs(Records, OS);
+    break;
   case GenArmCdeHeader:
     EmitCdeHeader(Records, OS);
     break;
diff --git a/clang/utils/TableGen/TableGenBackends.h 
b/clang/utils/TableGen/TableGenBackends.h
index d8f447069376bca..2f1c96bfa59640c 100644
--- a/clang/utils/TableGen/TableGenBackends.h
+++ b/clang/utils/TableGen/TableGenBackends.h
@@ -109,6 +109,7 @@ void EmitSmeHeader(llvm::RecordKeeper &Records, 
llvm::raw_ostream &OS);
 void EmitSmeBuiltins(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
 void EmitSmeBuiltinCG(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
 void EmitSmeRangeChecks(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
+void EmitSmeStreamingAttrs(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
 
 void EmitMveHeader(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
 void EmitMveBuiltinDef(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to