llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang-analysis Author: Florian Mayer (fmayer) <details> <summary>Changes</summary> --- Full diff: https://github.com/llvm/llvm-project/pull/163871.diff 2 Files Affected: - (modified) clang/lib/Analysis/FlowSensitive/Models/UncheckedStatusOrAccessModel.cpp (+122) - (modified) clang/unittests/Analysis/FlowSensitive/UncheckedStatusOrAccessModelTestFixture.cpp (+228) ``````````diff diff --git a/clang/lib/Analysis/FlowSensitive/Models/UncheckedStatusOrAccessModel.cpp b/clang/lib/Analysis/FlowSensitive/Models/UncheckedStatusOrAccessModel.cpp index 11c4ad90293d9..4ebf3e4251dd6 100644 --- a/clang/lib/Analysis/FlowSensitive/Models/UncheckedStatusOrAccessModel.cpp +++ b/clang/lib/Analysis/FlowSensitive/Models/UncheckedStatusOrAccessModel.cpp @@ -115,6 +115,20 @@ static auto valueOperatorCall() { isStatusOrOperatorCallWithName("->"))); } +static clang::ast_matchers::TypeMatcher statusType() { + using namespace ::clang::ast_matchers; // NOLINT: Too many names + return hasCanonicalType(qualType(hasDeclaration(statusClass()))); +} + + +static auto isComparisonOperatorCall(llvm::StringRef operator_name) { + using namespace ::clang::ast_matchers; // NOLINT: Too many names + return cxxOperatorCallExpr( + hasOverloadedOperatorName(operator_name), argumentCountIs(2), + hasArgument(0, anyOf(hasType(statusType()), hasType(statusOrType()))), + hasArgument(1, anyOf(hasType(statusType()), hasType(statusOrType())))); +} + static auto buildDiagnoseMatchSwitch(const UncheckedStatusOrAccessModelOptions &Options) { return CFGMatchSwitchBuilder<const Environment, @@ -304,6 +318,100 @@ static void transferStatusUpdateCall(const CXXMemberCallExpr *Expr, State.Env.setValue(locForOk(*ThisLoc), NewVal); } +static BoolValue* evaluateStatusEquality(RecordStorageLocation& LhsStatusLoc, + RecordStorageLocation& RhsStatusLoc, + Environment& Env) { + auto& A = Env.arena(); + // Logically, a Status object is composed of an error code that could take one + 
// of multiple possible values, including the "ok" value. We track whether a + // Status object has an "ok" value and represent this as an `ok` bit. Equality + // of Status objects compares their error codes. Therefore, merely comparing + // the `ok` bits isn't sufficient: when two Status objects are assigned non-ok + // error codes the equality of their respective error codes matters. Since we + // only track the `ok` bits, we can't make any conclusions about equality when + // we know that two Status objects have non-ok values. + + auto& LhsOkVal = valForOk(LhsStatusLoc, Env); + auto& RhsOkVal = valForOk(RhsStatusLoc, Env); + + auto& Res = Env.makeAtomicBoolValue(); + + // lhs && rhs => res (a.k.a. !res => !lhs || !rhs) + Env.assume(A.makeImplies(A.makeAnd(LhsOkVal.formula(), RhsOkVal.formula()), + Res.formula())); + // res => (lhs == rhs) + Env.assume(A.makeImplies( + Res.formula(), A.makeEquals(LhsOkVal.formula(), RhsOkVal.formula()))); + + return &Res; +} + +static BoolValue* evaluateStatusOrEquality( + RecordStorageLocation& LhsStatusOrLoc, + RecordStorageLocation& RhsStatusOrLoc, Environment& Env) { + auto& A = Env.arena(); + // Logically, a StatusOr<T> object is composed of two values - a Status and a + // value of type T. Equality of StatusOr objects compares both values. + // Therefore, merely comparing the `ok` bits of the Status values isn't + // sufficient. When two StatusOr objects are engaged, the equality of their + // respective values of type T matters. Similarly, when two StatusOr objects + // have Status values that have non-ok error codes, the equality of the error + // codes matters. Since we only track the `ok` bits of the Status values, we + // can't make any conclusions about equality when we know that two StatusOr + // objects are engaged or when their Status values contain non-ok error codes. 
+ auto& LhsOkVal = valForOk(locForStatus(LhsStatusOrLoc), Env); + auto& RhsOkVal = valForOk(locForStatus(RhsStatusOrLoc), Env); + auto& res = Env.makeAtomicBoolValue(); + + // res => (lhs == rhs) + Env.assume(A.makeImplies( + res.formula(), A.makeEquals(LhsOkVal.formula(), RhsOkVal.formula()))); + return &res; +} + + +static BoolValue* evaluateEquality(const Expr* LhsExpr, const Expr* RhsExpr, + Environment& Env) { + // Check the type of both sides in case an operator== is added that admits + // different types. + if (isStatusOrType(LhsExpr->getType()) && + isStatusOrType(RhsExpr->getType())) { + auto* LhsStatusOrLoc = Env.get<RecordStorageLocation>(*LhsExpr); + if (LhsStatusOrLoc == nullptr) return nullptr; + auto* RhsStatusOrLoc = Env.get<RecordStorageLocation>(*RhsExpr); + if (RhsStatusOrLoc == nullptr) return nullptr; + + return evaluateStatusOrEquality(*LhsStatusOrLoc, *RhsStatusOrLoc, Env); + + // Check the type of both sides in case an operator== is added that admits + // different types. 
+ } + if (isStatusType(LhsExpr->getType()) && isStatusType(RhsExpr->getType())) { + auto* LhsStatusLoc = Env.get<RecordStorageLocation>(*LhsExpr); + if (LhsStatusLoc == nullptr) return nullptr; + + auto* RhsStatusLoc = Env.get<RecordStorageLocation>(*RhsExpr); + if (RhsStatusLoc == nullptr) return nullptr; + + return evaluateStatusEquality(*LhsStatusLoc, *RhsStatusLoc, Env); + } + return nullptr; +} + +static void transferComparisonOperator(const CXXOperatorCallExpr* Expr, + LatticeTransferState& State, + bool IsNegative) { + auto* LhsAndRhsVal = + evaluateEquality(Expr->getArg(0), Expr->getArg(1), State.Env); + if (LhsAndRhsVal == nullptr) return; + + if (IsNegative) + State.Env.setValue(*Expr, State.Env.makeNot(*LhsAndRhsVal)); + else + State.Env.setValue(*Expr, *LhsAndRhsVal); +} + + CFGMatchSwitch<LatticeTransferState> buildTransferMatchSwitch(ASTContext &Ctx, CFGMatchSwitchBuilder<LatticeTransferState> Builder) { @@ -317,6 +425,20 @@ buildTransferMatchSwitch(ASTContext &Ctx, transferStatusOkCall) .CaseOfCFGStmt<CXXMemberCallExpr>(isStatusMemberCallWithName("Update"), transferStatusUpdateCall) + .CaseOfCFGStmt<CXXOperatorCallExpr>( + isComparisonOperatorCall("=="), + [](const CXXOperatorCallExpr* Expr, const MatchFinder::MatchResult&, + LatticeTransferState& State) { + transferComparisonOperator(Expr, State, + /*IsNegative=*/false); + }) + .CaseOfCFGStmt<CXXOperatorCallExpr>( + isComparisonOperatorCall("!="), + [](const CXXOperatorCallExpr* Expr, const MatchFinder::MatchResult&, + LatticeTransferState& State) { + transferComparisonOperator(Expr, State, + /*IsNegative=*/true); + }) .Build(); } diff --git a/clang/unittests/Analysis/FlowSensitive/UncheckedStatusOrAccessModelTestFixture.cpp b/clang/unittests/Analysis/FlowSensitive/UncheckedStatusOrAccessModelTestFixture.cpp index d6e326cae11fd..99f04cc8fe7e7 100644 --- a/clang/unittests/Analysis/FlowSensitive/UncheckedStatusOrAccessModelTestFixture.cpp +++ 
b/clang/unittests/Analysis/FlowSensitive/UncheckedStatusOrAccessModelTestFixture.cpp @@ -2601,6 +2601,234 @@ TEST_P(UncheckedStatusOrAccessModelTest, StatusUpdate) { )cc"); } +TEST_P(UncheckedStatusOrAccessModelTest, EqualityCheck) { + ExpectDiagnosticsFor( + R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUSOR_INT x, STATUSOR_INT y) { + if (x.ok()) { + if (x == y) + y.value(); + else + y.value(); // [[unsafe]] + } + } + )cc"); + ExpectDiagnosticsFor( + R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUSOR_INT x, STATUSOR_INT y) { + if (x.ok()) { + if (y == x) + y.value(); + else + y.value(); // [[unsafe]] + } + } + )cc"); + ExpectDiagnosticsFor(R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUSOR_INT x, STATUSOR_INT y) { + if (x.ok()) { + if (x != y) + y.value(); // [[unsafe]] + else + y.value(); + } + } + )cc"); + ExpectDiagnosticsFor(R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUSOR_INT x, STATUSOR_INT y) { + if (x.ok()) { + if (y != x) + y.value(); // [[unsafe]] + else + y.value(); + } + } + )cc"); + ExpectDiagnosticsFor(R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUSOR_INT x, STATUSOR_INT y) { + if (x.ok()) { + if (!(x == y)) + y.value(); // [[unsafe]] + else + y.value(); + } + } + )cc"); + ExpectDiagnosticsFor( + R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUSOR_INT x, STATUSOR_INT y) { + if (x.ok()) { + if (!(x != y)) + y.value(); + else + y.value(); // [[unsafe]] + } + } + )cc"); + ExpectDiagnosticsFor(R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUSOR_INT x, STATUSOR_INT y) { + if (x == y) + if (x.ok()) y.value(); + } + )cc"); + ExpectDiagnosticsFor( + R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUSOR_INT x, STATUSOR_INT y) { + if (x.ok()) { + if (x.status() == y.status()) + y.value(); + else + y.value(); // [[unsafe]] + } 
+ } + )cc"); + ExpectDiagnosticsFor(R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUSOR_INT x, STATUSOR_INT y) { + if (x.ok()) { + if (x.status() != y.status()) + y.value(); // [[unsafe]] + else + y.value(); + } + } + )cc"); + ExpectDiagnosticsFor( + R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUSOR_INT x, STATUSOR_INT y) { + if (x.ok()) { + if (x.ok() == y.ok()) + y.value(); + else + y.value(); // [[unsafe]] + } + } + )cc"); + ExpectDiagnosticsFor(R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUSOR_INT x, STATUSOR_INT y) { + if (x.ok()) { + if (x.ok() != y.ok()) + y.value(); // [[unsafe]] + else + y.value(); + } + } + )cc"); + ExpectDiagnosticsFor( + R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUSOR_INT x, STATUSOR_INT y) { + if (x.ok()) { + if (x.status().ok() == y.status().ok()) + y.value(); + else + y.value(); // [[unsafe]] + } + } + )cc"); + ExpectDiagnosticsFor(R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUSOR_INT x, STATUSOR_INT y) { + if (x.ok()) { + if (x.status().ok() != y.status().ok()) + y.value(); // [[unsafe]] + else + y.value(); + } + } + )cc"); + ExpectDiagnosticsFor( + R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUSOR_INT x, STATUSOR_INT y) { + if (x.ok()) { + if (x.status().ok() == y.ok()) + y.value(); + else + y.value(); // [[unsafe]] + } + } + )cc"); + ExpectDiagnosticsFor(R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUSOR_INT x, STATUSOR_INT y) { + if (x.ok()) { + if (x.status().ok() != y.ok()) + y.value(); // [[unsafe]] + else + y.value(); + } + } + )cc"); + ExpectDiagnosticsFor(R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(bool b, STATUSOR_INT sor) { + if (sor.ok() == b) { + if (b) sor.value(); + } + } + )cc"); + ExpectDiagnosticsFor(R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void 
target(STATUSOR_INT sor) { + if (sor.ok() == true) sor.value(); + } + )cc"); + ExpectDiagnosticsFor(R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUSOR_INT sor) { + if (sor.ok() == false) sor.value(); // [[unsafe]] + } + )cc"); + ExpectDiagnosticsFor(R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(bool b) { + STATUSOR_INT sor1; + STATUSOR_INT sor2 = Make<STATUSOR_INT>(); + if (sor1 == sor2) sor2.value(); // [[unsafe]] + } + )cc"); + ExpectDiagnosticsFor(R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(bool b) { + STATUSOR_INT sor1 = Make<STATUSOR_INT>(); + STATUSOR_INT sor2; + if (sor1 == sor2) sor1.value(); // [[unsafe]] + } + )cc"); +} + + } // namespace std::string `````````` </details> https://github.com/llvm/llvm-project/pull/163871 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
