https://github.com/SamTebbs33 created https://github.com/llvm/llvm-project/pull/74064
This PR extends the streaming-mode builtin checks to SME builtins: a warning is
emitted when a streaming or non-streaming SME builtin is called from a function
with an incompatible streaming mode, or when a builtin that uses ZA is called
from a function without active ZA state.

Uses work by Kerry McLaughlin.

>From f6a990a000b555d7f8ef0b2a99e3fea98420e899 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.te...@arm.com>
Date: Thu, 30 Nov 2023 13:42:50 +0000
Subject: [PATCH] [AArch64][SME] Warn when using a streaming builtin from a
 non-streaming function

This PR extends the streaming-mode builtin checks to SME builtins: a warning is
emitted when a streaming or non-streaming SME builtin is called from a function
with an incompatible streaming mode, or when a builtin that uses ZA is called
from a function without active ZA state.

Uses work by Kerry McLaughlin.
---
 clang/include/clang/Basic/CMakeLists.txt     |   6 +
 .../clang/Basic/DiagnosticSemaKinds.td       |   3 +
 clang/include/clang/Sema/Sema.h              |   3 +
 clang/lib/Sema/SemaChecking.cpp              | 190 ++++++++++++++++++
 .../Sema/aarch64-incompat-sm-builtin-calls.c |  26 +++
 clang/utils/TableGen/SveEmitter.cpp          |  90 +++++++++
 clang/utils/TableGen/TableGen.cpp            |  12 ++
 clang/utils/TableGen/TableGenBackends.h      |   2 +
 8 files changed, 332 insertions(+)

diff --git a/clang/include/clang/Basic/CMakeLists.txt b/clang/include/clang/Basic/CMakeLists.txt
index 085e316fcc671df..bdd72d1d63c431b 100644
--- a/clang/include/clang/Basic/CMakeLists.txt
+++ b/clang/include/clang/Basic/CMakeLists.txt
@@ -97,6 +97,12 @@ clang_tablegen(arm_sme_builtin_cg.inc -gen-arm-sme-builtin-codegen
 clang_tablegen(arm_sme_sema_rangechecks.inc -gen-arm-sme-sema-rangechecks
   SOURCE arm_sme.td
   TARGET ClangARMSmeSemaRangeChecks)
+clang_tablegen(arm_sme_streaming_attrs.inc -gen-arm-sme-streaming-attrs
+  SOURCE arm_sme.td
+  TARGET ClangARMSmeStreamingAttrs)
+clang_tablegen(arm_sme_builtins_za_state.inc -gen-arm-sme-builtin-za-state
+  SOURCE arm_sme.td
+  TARGET ClangARMSmeBuiltinsZAState)
 clang_tablegen(arm_cde_builtins.inc -gen-arm-cde-builtin-def
   SOURCE arm_cde.td
   TARGET ClangARMCdeBuiltinsDef)
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 6dfb2d7195203a3..c7036fc881a13e1 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -3151,6 +3151,9 @@ def err_attribute_arm_feature_sve_bits_unsupported : Error<
 def warn_attribute_arm_sm_incompat_builtin : Warning<
   "builtin call has undefined behaviour when called from a %0 function">,
   InGroup<DiagGroup<"undefined-arm-streaming">>;
+def warn_attribute_arm_za_builtin_no_za_state : Warning<
+  "builtin call is not valid when calling from a function without active ZA state">,
+  InGroup<DiagGroup<"undefined-arm-za">>;
 def err_sve_vector_in_non_sve_target : Error<
   "SVE vector type %0 cannot be used in a target without sve">;
 def err_attribute_riscv_rvv_bits_unsupported : Error<
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 6de1a098e067a38..c13c6942c219700 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -13845,7 +13845,10 @@ class Sema final {
   bool CheckNeonBuiltinFunctionCall(const TargetInfo &TI,
                                     unsigned BuiltinID, CallExpr *TheCall);
   bool CheckMVEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
+  bool ParseSVEImmChecks(CallExpr *TheCall,
+                         ArrayRef<std::tuple<int, int, int>> ImmChecks);
   bool CheckSVEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
+  bool CheckSMEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
   bool CheckCDEBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
                                    CallExpr *TheCall);
   bool CheckARMCoprocessorImmediate(const TargetInfo &TI, const Expr *CoprocArg,
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 77c8334f3ca25d3..f27eb8ad95cc703 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2995,6 +2995,134 @@ static QualType getNeonEltType(NeonTypeFlags Flags, ASTContext &Context,
 
 enum ArmStreamingType { ArmNonStreaming, ArmStreaming, ArmStreamingCompatible };
 
+bool Sema::ParseSVEImmChecks(CallExpr *TheCall,
+                             ArrayRef<std::tuple<int, int, int>> ImmChecks) {
+  // Perform all the immediate checks for this builtin call.
+  bool HasError = false;
+  for (const auto &I : ImmChecks) {
+    int ArgNum, CheckTy, ElementSizeInBits;
+    std::tie(ArgNum, CheckTy, ElementSizeInBits) = I;
+
+    typedef bool (*OptionSetCheckFnTy)(int64_t Value);
+
+    // Function that checks whether the operand (ArgNum) is an immediate
+    // that is one of the predefined values.
+    auto CheckImmediateInSet = [&](OptionSetCheckFnTy CheckImm,
+                                   int ErrDiag) -> bool {
+      // We can't check the value of a dependent argument.
+      Expr *Arg = TheCall->getArg(ArgNum);
+      if (Arg->isTypeDependent() || Arg->isValueDependent())
+        return false;
+
+      // Check constant-ness first.
+      llvm::APSInt Imm;
+      if (SemaBuiltinConstantArg(TheCall, ArgNum, Imm))
+        return true;
+
+      if (!CheckImm(Imm.getSExtValue()))
+        return Diag(TheCall->getBeginLoc(), ErrDiag) << Arg->getSourceRange();
+      return false;
+    };
+
+    switch ((SVETypeFlags::ImmCheckType)CheckTy) {
+    case SVETypeFlags::ImmCheck0_31:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 31))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck0_13:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 13))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck1_16:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 1, 16))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck0_7:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 7))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheckExtract:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
+                                      (2048 / ElementSizeInBits) - 1))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheckShiftRight:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 1, ElementSizeInBits))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheckShiftRightNarrow:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 1,
+                                      ElementSizeInBits / 2))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheckShiftLeft:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
+                                      ElementSizeInBits - 1))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheckLaneIndex:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
+                                      (128 / (1 * ElementSizeInBits)) - 1))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheckLaneIndexCompRotate:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
+                                      (128 / (2 * ElementSizeInBits)) - 1))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheckLaneIndexDot:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
+                                      (128 / (4 * ElementSizeInBits)) - 1))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheckComplexRot90_270:
+      if (CheckImmediateInSet([](int64_t V) { return V == 90 || V == 270; },
+                              diag::err_rotation_argument_to_cadd))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheckComplexRotAll90:
+      if (CheckImmediateInSet(
+              [](int64_t V) {
+                return V == 0 || V == 90 || V == 180 || V == 270;
+              },
+              diag::err_rotation_argument_to_cmla))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck0_1:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 1))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck0_2:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 2))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck0_3:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 3))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck0_0:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 0))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck0_15:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 15))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck0_255:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 255))
+        HasError = true;
+      break;
+    case SVETypeFlags::ImmCheck2_4_Mul2:
+      if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 2, 4) ||
+          SemaBuiltinConstantArgMultiple(TheCall, ArgNum, 2))
+        HasError = true;
+      break;
+    }
+  }
+
+  return HasError;
+}
+
 static ArmStreamingType getArmStreamingFnType(const FunctionDecl *FD) {
   if (FD->hasAttr<ArmLocallyStreamingAttr>())
     return ArmStreaming;
@@ -3023,6 +3151,65 @@ static void checkArmStreamingBuiltin(Sema &S, CallExpr *TheCall,
         << TheCall->getSourceRange() << "streaming compatible";
     return;
   }
+
+  if (FnType == ArmNonStreaming && BuiltinType == ArmStreaming) {
+    S.Diag(TheCall->getBeginLoc(), diag::warn_attribute_arm_sm_incompat_builtin)
+        << TheCall->getSourceRange() << "non-streaming";
+  }
+}
+
+static bool hasSMEZAState(const FunctionDecl *FD) {
+  if (FD->hasAttr<ArmNewZAAttr>())
+    return true;
+  if (const auto *T = FD->getType()->getAs<FunctionProtoType>())
+    if (T->getAArch64SMEAttributes() & FunctionType::SME_PStateZASharedMask)
+      return true;
+  return false;
+}
+
+static bool hasSMEZAState(unsigned BuiltinID) {
+  switch (BuiltinID) {
+  default:
+    return false;
+#define GET_SME_BUILTIN_HAS_ZA_STATE
+#include "clang/Basic/arm_sme_builtins_za_state.inc"
+#undef GET_SME_BUILTIN_HAS_ZA_STATE
+  }
+}
+
+bool Sema::CheckSMEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
+  if (const FunctionDecl *FD = getCurFunctionDecl()) {
+    std::optional<ArmStreamingType> BuiltinType;
+
+    switch (BuiltinID) {
+    default:
+      break;
+#define GET_SME_STREAMING_ATTRS
+#include "clang/Basic/arm_sme_streaming_attrs.inc"
+#undef GET_SME_STREAMING_ATTRS
+    }
+
+    if (BuiltinType)
+      checkArmStreamingBuiltin(*this, TheCall, FD, *BuiltinType);
+
+    if (hasSMEZAState(BuiltinID) && !hasSMEZAState(FD))
+      Diag(TheCall->getBeginLoc(),
+           diag::warn_attribute_arm_za_builtin_no_za_state)
+          << TheCall->getSourceRange();
+  }
+
+  // Range check SME intrinsics that take immediate values.
+  SmallVector<std::tuple<int, int, int>, 3> ImmChecks;
+
+  switch (BuiltinID) {
+  default:
+    return false;
+#define GET_SME_IMMEDIATE_CHECK
+#include "clang/Basic/arm_sme_sema_rangechecks.inc"
+#undef GET_SME_IMMEDIATE_CHECK
+  }
+
+  return ParseSVEImmChecks(TheCall, ImmChecks);
 }
 
 bool Sema::CheckSVEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
@@ -3559,6 +3746,9 @@ bool Sema::CheckAArch64BuiltinFunctionCall(const TargetInfo &TI,
   if (CheckSVEBuiltinFunctionCall(BuiltinID, TheCall))
     return true;
 
+  if (CheckSMEBuiltinFunctionCall(BuiltinID, TheCall))
+    return true;
+
   // For intrinsics which take an immediate value as part of the instruction,
   // range check them here.
   unsigned i = 0, l = 0, u = 0;
diff --git a/clang/test/Sema/aarch64-incompat-sm-builtin-calls.c b/clang/test/Sema/aarch64-incompat-sm-builtin-calls.c
index e77e09c4435188d..8f33075c7a9b51f 100644
--- a/clang/test/Sema/aarch64-incompat-sm-builtin-calls.c
+++ b/clang/test/Sema/aarch64-incompat-sm-builtin-calls.c
@@ -5,6 +5,7 @@
 // REQUIRES: aarch64-registered-target
 
 #include "arm_neon.h"
+#include "arm_sme_draft_spec_subject_to_change.h"
 
 int16x8_t incompat_neon_sm(int16x8_t splat) __arm_streaming {
   // expected-warning@+1 {{builtin call has undefined behaviour when called from a streaming function}}
@@ -20,3 +21,28 @@ int16x8_t incompat_neon_smc(int16x8_t splat) __arm_streaming_compatible {
   // expected-warning@+1 {{builtin call has undefined behaviour when called from a streaming compatible function}}
   return (int16x8_t)__builtin_neon_vqaddq_v((int8x16_t)splat, (int8x16_t)splat, 33);
 }
+
+void incompat_sme_norm(svbool_t pg, void const *ptr) __arm_shared_za {
+  // expected-warning@+1 {{builtin call has undefined behaviour when called from a non-streaming function}}
+  __builtin_sme_svld1_hor_za128(0, 0, pg, ptr);
+}
+
+void incompat_sme_smc(svbool_t pg, void const *ptr) __arm_streaming_compatible __arm_shared_za {
+  // expected-warning@+1 {{builtin call has undefined behaviour when called from a streaming compatible function}}
+  __builtin_sme_svld1_hor_za128(0, 0, pg, ptr);
+}
+
+void incompat_sme_sm(svbool_t pn, svbool_t pm, svfloat32_t zn, svfloat32_t zm) __arm_shared_za {
+  // expected-warning@+1 {{builtin call has undefined behaviour when called from a non-streaming function}}
+  svmops_za32_f32_m(0, pn, pm, zn, zm);
+}
+
+void incompat_sme_no_za(svbool_t pg, void const *ptr) __arm_streaming {
+  // expected-warning@+1 {{builtin call is not valid when calling from a function without active ZA state}}
+  __builtin_sme_svld1_hor_za128(0, 0, pg, ptr);
+}
+
+svbool_t streaming_caller_ptrue(void) __arm_streaming {
+  // No diagnostic expected.
+  return svand_z(svptrue_b16(), svptrue_pat_b16(SV_ALL), svptrue_pat_b16(SV_VL4));
+}
diff --git a/clang/utils/TableGen/SveEmitter.cpp b/clang/utils/TableGen/SveEmitter.cpp
index b380bd9dfe6643a..cb02ef09e130eb3 100644
--- a/clang/utils/TableGen/SveEmitter.cpp
+++ b/clang/utils/TableGen/SveEmitter.cpp
@@ -378,6 +378,12 @@ class SVEEmitter {
   /// Emit all the information needed to map builtin -> LLVM IR intrinsic.
   void createSMECodeGenMap(raw_ostream &o);
 
+  /// Create a table for a builtin's requirement for PSTATE.SM.
+  void createStreamingAttrs(raw_ostream &o, ACLEKind Kind);
+
+  /// Create a table of the SME builtins that require ZA state.
+  void createBuiltinZAState(raw_ostream &o);
+
   /// Emit all the range checks for the immediates.
   void createSMERangeChecks(raw_ostream &o);
 
@@ -1688,6 +1694,82 @@ void SVEEmitter::createSMERangeChecks(raw_ostream &OS) {
   OS << "#endif\n\n";
 }
 
+void SVEEmitter::createStreamingAttrs(raw_ostream &OS, ACLEKind Kind) {
+  std::vector<Record *> RV = Records.getAllDerivedDefinitions("Inst");
+  SmallVector<std::unique_ptr<Intrinsic>, 128> Defs;
+  for (auto *R : RV)
+    createIntrinsic(R, Defs);
+
+  // Emit the cases sorted by mangled name.
+  llvm::sort(Defs, [](const std::unique_ptr<Intrinsic> &A,
+                      const std::unique_ptr<Intrinsic> &B) {
+    return A->getMangledName() < B->getMangledName();
+  });
+
+  switch (Kind) {
+  case ACLEKind::SME:
+    OS << "#ifdef GET_SME_STREAMING_ATTRS\n";
+    break;
+  case ACLEKind::SVE:
+    OS << "#ifdef GET_SVE_STREAMING_ATTRS\n";
+    break;
+  }
+
+  // Ensure these are only emitted once.
+  std::set<std::string> Emitted;
+
+  uint64_t IsStreamingFlag = getEnumValueForFlag("IsStreaming");
+  uint64_t IsStreamingCompatibleFlag =
+      getEnumValueForFlag("IsStreamingCompatible");
+  for (auto &Def : Defs) {
+    if (Emitted.find(Def->getMangledName()) != Emitted.end())
+      continue;
+
+    switch (Kind) {
+    case ACLEKind::SME:
+      OS << "case SME::BI__builtin_sme_";
+      break;
+    case ACLEKind::SVE:
+      OS << "case SVE::BI__builtin_sve_";
+      break;
+    }
+    OS << Def->getMangledName() << ":\n";
+
+    if (Def->isFlagSet(IsStreamingFlag))
+      OS << "  BuiltinType = ArmStreaming;\n";
+    else if (Def->isFlagSet(IsStreamingCompatibleFlag))
+      OS << "  BuiltinType = ArmStreamingCompatible;\n";
+    else
+      OS << "  BuiltinType = ArmNonStreaming;\n";
+    OS << "  break;\n";
+
+    Emitted.insert(Def->getMangledName());
+  }
+
+  OS << "#endif\n\n";
+}
+
+void SVEEmitter::createBuiltinZAState(raw_ostream &OS) {
+  std::vector<Record *> RV = Records.getAllDerivedDefinitions("Inst");
+  SmallVector<std::unique_ptr<Intrinsic>, 128> Defs;
+  for (auto *R : RV)
+    createIntrinsic(R, Defs);
+
+  // Collect the (deduplicated, sorted) names of the builtins that use ZA.
+  std::set<std::string> ZABuiltins;
+  uint64_t IsSharedZAFlag = getEnumValueForFlag("IsSharedZA");
+  for (auto &Def : Defs)
+    if (Def->isFlagSet(IsSharedZAFlag))
+      ZABuiltins.insert(Def->getMangledName());
+
+  OS << "#ifdef GET_SME_BUILTIN_HAS_ZA_STATE\n";
+  for (const std::string &Name : ZABuiltins)
+    OS << "case SME::BI__builtin_sme_" << Name << ":\n";
+  if (!ZABuiltins.empty())
+    OS << "  return true;\n";
+  OS << "#endif\n\n";
+}
+
 namespace clang {
 void EmitSveHeader(RecordKeeper &Records, raw_ostream &OS) {
   SVEEmitter(Records).createHeader(OS);
@@ -1724,4 +1806,12 @@ void EmitSmeBuiltinCG(RecordKeeper &Records, raw_ostream &OS) {
 void EmitSmeRangeChecks(RecordKeeper &Records, raw_ostream &OS) {
   SVEEmitter(Records).createSMERangeChecks(OS);
 }
+
+void EmitSmeStreamingAttrs(RecordKeeper &Records, raw_ostream &OS) {
+  SVEEmitter(Records).createStreamingAttrs(OS, ACLEKind::SME);
+}
+
+void EmitSmeBuiltinZAState(RecordKeeper &Records, raw_ostream &OS) {
+  SVEEmitter(Records).createBuiltinZAState(OS);
+}
 } // End namespace clang
diff --git a/clang/utils/TableGen/TableGen.cpp b/clang/utils/TableGen/TableGen.cpp
index 7efb6c731d3e5ee..9ba2fb07f1380da 100644
--- a/clang/utils/TableGen/TableGen.cpp
+++ b/clang/utils/TableGen/TableGen.cpp
@@ -89,6 +89,8 @@ enum ActionType {
   GenArmSmeBuiltins,
   GenArmSmeBuiltinCG,
   GenArmSmeRangeChecks,
+  GenArmSmeStreamingAttrs,
+  GenArmSmeBuiltinZAState,
   GenArmCdeHeader,
   GenArmCdeBuiltinDef,
   GenArmCdeBuiltinSema,
@@ -251,6 +253,10 @@ cl::opt<ActionType> Action(
                    "Generate arm_sme_builtin_cg_map.inc for clang"),
         clEnumValN(GenArmSmeRangeChecks, "gen-arm-sme-sema-rangechecks",
                    "Generate arm_sme_sema_rangechecks.inc for clang"),
+        clEnumValN(GenArmSmeStreamingAttrs, "gen-arm-sme-streaming-attrs",
+                   "Generate arm_sme_streaming_attrs.inc for clang"),
+        clEnumValN(GenArmSmeBuiltinZAState, "gen-arm-sme-builtin-za-state",
+                   "Generate arm_sme_builtins_za_state.inc for clang"),
         clEnumValN(GenArmMveHeader, "gen-arm-mve-header",
                    "Generate arm_mve.h for clang"),
         clEnumValN(GenArmMveBuiltinDef, "gen-arm-mve-builtin-def",
@@ -500,6 +506,12 @@ bool ClangTableGenMain(raw_ostream &OS, RecordKeeper &Records) {
   case GenArmSmeRangeChecks:
     EmitSmeRangeChecks(Records, OS);
     break;
+  case GenArmSmeStreamingAttrs:
+    EmitSmeStreamingAttrs(Records, OS);
+    break;
+  case GenArmSmeBuiltinZAState:
+    EmitSmeBuiltinZAState(Records, OS);
+    break;
   case GenArmCdeHeader:
     EmitCdeHeader(Records, OS);
     break;
diff --git a/clang/utils/TableGen/TableGenBackends.h b/clang/utils/TableGen/TableGenBackends.h
index d8f447069376bca..2f1c96bfa59640c 100644
--- a/clang/utils/TableGen/TableGenBackends.h
+++ b/clang/utils/TableGen/TableGenBackends.h
@@ -109,6 +109,8 @@ void EmitSmeHeader(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
 void EmitSmeBuiltins(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
 void EmitSmeBuiltinCG(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
 void EmitSmeRangeChecks(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
+void EmitSmeStreamingAttrs(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
+void EmitSmeBuiltinZAState(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
 
 void EmitMveHeader(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
 void EmitMveBuiltinDef(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits