//===- StdVariantChecker.cpp -------------------------------------*- C++ -*-==// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "clang/AST/Type.h" #include "clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h" #include "clang/StaticAnalyzer/Core/BugReporter/BugType.h" #include "clang/StaticAnalyzer/Core/Checker.h" #include "clang/StaticAnalyzer/Core/CheckerManager.h" #include "clang/StaticAnalyzer/Core/PathSensitive/CallDescription.h" #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h" #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h" #include "clang/StaticAnalyzer/Core/PathSensitive/SVals.h" #include "llvm/ADT/FoldingSet.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include #include #include "TaggedUnionModeling.h" using namespace clang; using namespace ento; using namespace tagged_union_modeling; REGISTER_MAP_WITH_PROGRAMSTATE(VariantHeldTypeMap, const MemRegion *, QualType) namespace clang::ento::tagged_union_modeling { const CXXConstructorDecl * getConstructorDeclarationForCall(const CallEvent &Call) { const auto *ConstructorCall = dyn_cast(&Call); if (!ConstructorCall) return nullptr; return ConstructorCall->getDecl(); } bool isCopyConstructorCall(const CallEvent &Call) { if (const CXXConstructorDecl *ConstructorDecl = getConstructorDeclarationForCall(Call)) return ConstructorDecl->isCopyConstructor(); return false; } bool isCopyAssignmentCall(const CallEvent &Call) { const Decl *CopyAssignmentDecl = Call.getDecl(); if (const auto *AsMethodDecl = dyn_cast_or_null(CopyAssignmentDecl)) return AsMethodDecl->isCopyAssignmentOperator(); return false; } bool isMoveConstructorCall(const CallEvent &Call) { const CXXConstructorDecl *ConstructorDecl = getConstructorDeclarationForCall(Call); if (!ConstructorDecl) return false; return ConstructorDecl->isMoveConstructor(); } bool isMoveAssignmentCall(const CallEvent &Call) { const Decl *CopyAssignmentDecl = Call.getDecl(); const auto *AsMethodDecl = dyn_cast_or_null(CopyAssignmentDecl); if (!AsMethodDecl) return false; return AsMethodDecl->isMoveAssignmentOperator(); } bool isStdType(const Type *Type, llvm::StringRef TypeName) { auto *Decl = Type->getAsRecordDecl(); if (!Decl) return false; return (Decl->getName() == TypeName) && Decl->isInStdNamespace(); } bool isStdVariant(const Type *Type) { return isStdType(Type, llvm::StringLiteral("variant")); } } // end of namespace clang::ento::tagged_union_modeling static std::optional> getTemplateArgsFromVariant(const Type *VariantType) { const auto *TempSpecType = VariantType->getAs(); if (!TempSpecType) return {}; return TempSpecType->template_arguments(); } static std::optional getNthTemplateTypeArgFromVariant(const Type *varType, unsigned i) { std::optional> VariantTemplates = getTemplateArgsFromVariant(varType); if (!VariantTemplates) return {}; return (*VariantTemplates)[i].getAsType(); } static bool isVowel(char a) { switch (a) { case 'a': case 'e': case 'i': case 'o': case 'u': return true; default: return false; } } static llvm::StringRef indefiniteArticleBasedOnVowel(char a) { if (isVowel(a)) return "an"; return "a"; } class StdVariantChecker : public Checker { // Call descriptors to find relevant calls CallDescription VariantConstructor{{"std", "variant", "variant"}}; CallDescription VariantAssignmentOperator{{"std", "variant", "operator="}}; CallDescription StdGet{{"std", "get"}, 1, 1}; BugType BadVariantType{this, "BadVariantType", "BadVariantType"}; public: ProgramStateRef checkRegionChanges(ProgramStateRef State, const InvalidatedSymbols *, ArrayRef, ArrayRef Regions, const LocationContext *, const CallEvent *Call) const { if (!Call) return State; return removeInformationStoredForDeadInstances( *Call, State, Regions); } bool evalCall(const CallEvent &Call, CheckerContext &C) const { // Check if the call was not made from a system header. If it was then // we do an early return because it is part of the implementation. if (Call.isCalledFromSystemHeader()) return false; if (StdGet.matches(Call)) return handleStdGetCall(Call, C); // First check if a constructor call is happening. If it is a // constructor call, check if it is an std::variant constructor call. bool IsVariantConstructor = isa(Call) && VariantConstructor.matches(Call); bool IsVariantAssignmentOperatorCall = isa(Call) && VariantAssignmentOperator.matches(Call); if (IsVariantConstructor || IsVariantAssignmentOperatorCall) { if (Call.getNumArgs() == 0 && IsVariantConstructor) { handleDefaultConstructor(cast(&Call), C); return true; } // FIXME Later this checker should be extended to handle constructors // with multiple arguments. if (Call.getNumArgs() != 1) return false; SVal ThisSVal; if (IsVariantConstructor) { const auto &AsConstructorCall = cast(Call); ThisSVal = AsConstructorCall.getCXXThisVal(); } else if (IsVariantAssignmentOperatorCall) { const auto &AsMemberOpCall = cast(Call); ThisSVal = AsMemberOpCall.getCXXThisVal(); } else { return false; } handleConstructorAndAssignment(Call, C, ThisSVal); return true; } return false; } private: // The default constructed std::variant must be handled separately // by default the std::variant is going to hold a default constructed instance // of the first type of the possible types void handleDefaultConstructor(const CXXConstructorCall *ConstructorCall, CheckerContext &C) const { SVal ThisSVal = ConstructorCall->getCXXThisVal(); const auto *const ThisMemRegion = ThisSVal.getAsRegion(); if (!ThisMemRegion) return; std::optional DefaultType = getNthTemplateTypeArgFromVariant( ThisSVal.getType(C.getASTContext())->getPointeeType().getTypePtr(), 0); if (!DefaultType) return; ProgramStateRef State = ConstructorCall->getState(); State = State->set(ThisMemRegion, *DefaultType); C.addTransition(State); } bool handleStdGetCall(const CallEvent &Call, CheckerContext &C) const { ProgramStateRef State = Call.getState(); const auto &ArgType = Call.getArgSVal(0) .getType(C.getASTContext()) ->getPointeeType() .getTypePtr(); // We have to make sure that the argument is an std::variant. // There is another std::get with std::pair argument if (!isStdVariant(ArgType)) return false; // Get the mem region of the argument std::variant and look up the type // information that we know about it. const MemRegion *ArgMemRegion = Call.getArgSVal(0).getAsRegion(); const QualType *StoredType = State->get(ArgMemRegion); if (!StoredType) return false; const CallExpr *CE = cast(Call.getOriginExpr()); const FunctionDecl *FD = CE->getDirectCallee(); if (FD->getTemplateSpecializationArgs()->size() < 1) return false; const auto &TypeOut = FD->getTemplateSpecializationArgs()->asArray()[0]; // std::get's first template parameter can be the type we want to get // out of the std::variant or a natural number which is the position of // the requested type in the argument type list of the std::variant's // argument. QualType RetrievedType; switch (TypeOut.getKind()) { case TemplateArgument::ArgKind::Type: RetrievedType = TypeOut.getAsType(); break; case TemplateArgument::ArgKind::Integral: // In the natural number case we look up which type corresponds to the // number. if (std::optional NthTemplate = getNthTemplateTypeArgFromVariant( ArgType, TypeOut.getAsIntegral().getSExtValue())) { RetrievedType = *NthTemplate; break; } [[fallthrough]]; default: return false; } QualType RetrievedCanonicalType = RetrievedType.getCanonicalType(); QualType StoredCanonicalType = StoredType->getCanonicalType(); if (RetrievedCanonicalType == StoredCanonicalType) return true; ExplodedNode *ErrNode = C.generateNonFatalErrorNode(); if (!ErrNode) return false; llvm::SmallString<128> Str; llvm::raw_svector_ostream OS(Str); std::string StoredTypeName = StoredType->getAsString(); std::string RetrievedTypeName = RetrievedType.getAsString(); OS << "std::variant " << ArgMemRegion->getDescriptiveName() << " held " << indefiniteArticleBasedOnVowel(StoredTypeName[0]) << " \'" << StoredTypeName << "\', not " << indefiniteArticleBasedOnVowel(RetrievedTypeName[0]) << " \'" << RetrievedTypeName << "\'"; auto R = std::make_unique(BadVariantType, OS.str(), ErrNode); C.emitReport(std::move(R)); return true; } }; bool clang::ento::shouldRegisterStdVariantChecker( clang::ento::CheckerManager const &mgr) { return true; } void clang::ento::registerStdVariantChecker(clang::ento::CheckerManager &mgr) { mgr.registerChecker(); }