================ @@ -0,0 +1,268 @@ +//===- UnsafeBufferUsageExtractor.cpp -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "clang/Analysis/Scalable/Analyses/UnsafeBufferUsage/UnsafeBufferUsageExtractor.h" +#include "clang/AST/ASTConsumer.h" +#include "clang/AST/ASTContext.h" +#include "clang/AST/Decl.h" +#include "clang/AST/DynamicRecursiveASTVisitor.h" +#include "clang/AST/StmtVisitor.h" +#include "clang/Analysis/Analyses/UnsafeBufferUsage.h" +#include "clang/Analysis/Scalable/ASTEntityMapping.h" +#include "clang/Analysis/Scalable/Analyses/UnsafeBufferUsage/UnsafeBufferUsage.h" +#include "clang/Analysis/Scalable/Analyses/UnsafeBufferUsage/UnsafeBufferUsageBuilder.h" +#include "clang/Analysis/Scalable/Model/EntityId.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Support/Error.h" +#include <memory> + +namespace { +using namespace clang; +using namespace ssaf; + +static bool hasPointerType(const Expr *E) { + auto Ty = E->getType(); + return !Ty.isNull() && !Ty->isFunctionPointerType() && + (Ty->isPointerType() || Ty->isArrayType()); +} + +constexpr inline auto buildEntityPointerLevel = + UnsafeBufferUsageTUSummaryBuilder::buildEntityPointerLevel; + +// Translate a pointer type expression 'E' to a (set of) EntityPointerLevel(s) +// associated with the declared type of the base address of `E`. +// +// The translation is a process of stripping off the pointer 'E' until its base +// address can be represented by an entity, with the number of dereferences +// tracked by incrementing the pointer level. Naturally, taking address of, as +// the inverse operation of dereference, is tracked by decrementing the pointer +// level. 
+// +// For example, suppose there are pointers and arrays declared as +// int *ptr, **p1, **p2; +// int arr[10][10]; +// , the translation of expressions involving these base addresses will be: +// Translate(ptr + 5) -> {(ptr, 1)} +// Translate(arr[5]) -> {(arr, 2)} +// Translate(cond ? p1[5] : p2) -> {(p1, 2), (p2, 1)} +// Translate(&arr[5]) -> {(arr, 1)} +class EntityPointerLevelTranslator + : ConstStmtVisitor<EntityPointerLevelTranslator, + Expected<EntityPointerLevelSet>> { + friend class StmtVisitorBase; + + // Fallback method for all unsupported expression kind: + llvm::Error fallback(const Stmt *E) { + return llvm::createStringError( + "unsupported expression kind for translation to " + "EntityPointerLevel: %s", + E->getStmtClassName()); + } + + UnsafeBufferUsageTUSummaryBuilder &Builder; + + static EntityPointerLevel incrementPointerLevel(const EntityPointerLevel &E) { + return buildEntityPointerLevel(E.getEntity(), E.getPointerLevel() + 1); + } + + static EntityPointerLevel decrementPointerLevel(const EntityPointerLevel &E) { + assert(E.getPointerLevel() > 0); + return buildEntityPointerLevel(E.getEntity(), E.getPointerLevel() - 1); + } + + EntityPointerLevel createEntityPointerLevelFor(const EntityName &Name) { + return buildEntityPointerLevel(Builder.addEntity(Name), 1); + } + + // The common helper function for Translate(*base): + // Translate(*base) -> Translate(base) with .pointerLevel + 1 + Expected<EntityPointerLevelSet> translateDereferencePointer(const Expr *Ptr) { + assert(hasPointerType(Ptr)); + + Expected<EntityPointerLevelSet> SubResult = Visit(Ptr); + if (!SubResult) + return SubResult.takeError(); + + auto Incremented = llvm::map_range(*SubResult, incrementPointerLevel); + return EntityPointerLevelSet{Incremented.begin(), Incremented.end()}; + } + +public: + EntityPointerLevelTranslator(UnsafeBufferUsageTUSummaryBuilder &Builder) + : Builder(Builder) {} + + Expected<EntityPointerLevelSet> translate(const Expr *E) { return Visit(E); } + 
+private: + Expected<EntityPointerLevelSet> VisitStmt(const Stmt *E) { + return fallback(E); + } + + // Translate(base + x) -> Translate(base) + // Translate(x + base) -> Translate(base) + // Translate(base - x) -> Translate(base) + // Translate(base {+=, -=, =} x) -> Translate(base) + // Translate(x, base) -> Translate(base) + Expected<EntityPointerLevelSet> VisitBinaryOperator(const BinaryOperator *E) { + switch (E->getOpcode()) { + case clang::BO_Add: + if (hasPointerType(E->getLHS())) + return Visit(E->getLHS()); + return Visit(E->getRHS()); + case clang::BO_Sub: + case clang::BO_AddAssign: + case clang::BO_SubAssign: + case clang::BO_Assign: + return Visit(E->getLHS()); + case clang::BO_Comma: + return Visit(E->getRHS()); + default: + return fallback(E); + } + } + + // Translate({++, --}base) -> Translate(base) + // Translate(base{++, --}) -> Translate(base) + // Translate(*base) -> Translate(base) with .pointerLevel += 1 + // Translate(&base) -> {}, if Translate(base) is {} + // -> Translate(base) with .pointerLevel -= 1 + Expected<EntityPointerLevelSet> VisitUnaryOperator(const UnaryOperator *E) { + switch (E->getOpcode()) { + case clang::UO_PostInc: + case clang::UO_PostDec: + case clang::UO_PreInc: + case clang::UO_PreDec: + return Visit(E->getSubExpr()); + case clang::UO_AddrOf: { + Expected<EntityPointerLevelSet> SubResult = Visit(E->getSubExpr()); + if (!SubResult) + return SubResult.takeError(); + + auto Decremented = llvm::map_range(*SubResult, decrementPointerLevel); + return EntityPointerLevelSet{Decremented.begin(), Decremented.end()}; + } + case clang::UO_Deref: + return translateDereferencePointer(E->getSubExpr()); + default: + return fallback(E); + } + } + + // Translate((T*)base) -> Translate(p) if p has pointer type + // -> {} otherwise + Expected<EntityPointerLevelSet> VisitCastExpr(const CastExpr *E) { + if (hasPointerType(E->getSubExpr())) + return Visit(E->getSubExpr()); + return EntityPointerLevelSet{}; + } + + // Translate(f(...)) -> {} 
if it is an indirect call + // -> {(f_return, 1)}, otherwise + Expected<EntityPointerLevelSet> VisitCallExpr(const CallExpr *E) { + if (auto *FD = E->getDirectCallee()) + if (auto FDEntityName = getEntityNameForReturn(FD)) + return EntityPointerLevelSet{ + createEntityPointerLevelFor(*FDEntityName)}; + return EntityPointerLevelSet{}; + } + + // Translate(base[x]) -> Translate(*base) + Expected<EntityPointerLevelSet> + VisitArraySubscriptExpr(const ArraySubscriptExpr *E) { + // Translate(ptr[x]) := Translate(*ptr) ---------------- ziqingluo-90 wrote:
It means `ptr[x]` is translated exactly like `*ptr`: the index `x` is irrelevant to the translation (it is abstracted away). https://github.com/llvm/llvm-project/pull/182941 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
