#pragma once #include #include #include #include #include "absl/strings/str_cat.h" #include "swift/extractor/trap/TrapDomain.h" #include "swift/extractor/infra/SwiftTagTraits.h" #include "swift/extractor/trap/generated/TrapClasses.h" #include "swift/extractor/infra/SwiftLocationExtractor.h" #include "swift/extractor/infra/SwiftBodyEmissionStrategy.h" #include "swift/extractor/infra/SwiftMangledName.h" #include "swift/extractor/config/SwiftExtractorState.h" #include "swift/logging/SwiftAssert.h" namespace codeql { // The main responsibilities of the SwiftDispatcher are as follows: // * redirect specific AST node emission to a corresponding visitor (statements, expressions, etc.) // * storing TRAP labels for emitted AST nodes (in the TrapLabelStore) to avoid re-emission // Since SwiftDispatcher sees all the AST nodes, it also attaches a location to every 'locatable' // node (AST nodes that are not types: declarations, statements, expressions, etc.). class SwiftDispatcher { // types to be supported by assignNewLabel/fetchLabel need to be listed here using Handle = std::variant; public: // all references and pointers passed as parameters to this constructor are supposed to outlive // the SwiftDispatcher SwiftDispatcher(const swift::SourceManager& sourceManager, SwiftExtractorState& state, TrapDomain& trap, SwiftLocationExtractor& locationExtractor, SwiftBodyEmissionStrategy& bodyEmissionStrategy) : sourceManager{sourceManager}, state{state}, trap{trap}, locationExtractor{locationExtractor}, bodyEmissionStrategy{bodyEmissionStrategy} {} const std::unordered_set getEncounteredModules() && { return std::move(encounteredModules); } template void emit(Entry&& entry) { bool valid = true; entry.forEachLabel([&valid, &entry, this](const char* field, int index, auto& label) { using Label = std::remove_reference_t; if (!label.valid()) { const char* action; if constexpr (std::derived_from) { action = "replacing with unspecified element"; label = emitUnspecified(idOf(entry), field, index); } else { action = "skipping emission"; valid = false; } LOG_ERROR("{} has undefined field {}{}, {}", entry.NAME, field, index >= 0 ? absl::StrCat("[", index, "]") : "", action); } }); if (valid) { trap.emit(entry, /* check */ false); } } template void emit(std::optional&& entry) { if (entry) { emit(std::move(*entry)); } } template void emit(std::variant&& entry) { std::visit([this](auto&& e) { this->emit(std::move(e)); }, std::move(entry)); } // This is a helper method to emit TRAP entries for AST nodes that we don't fully support yet. template void emitUnknown(E* entity) { auto label = assignNewLabel(entity); using Trap = BindingTrapOf; static_assert(sizeof(Trap) == sizeof(label), "Binding traps of unknown entities must only have the `id` field (the class " "should be empty in schema.yml)"); emit(Trap{label}); emit(ElementIsUnknownTrap{label}); } TrapLabel emitUnspecified(std::optional>&& parent, const char* property, int index) { UnspecifiedElement entry{trap.createTypedLabel()}; entry.error = "element was unspecified by the extractor"; entry.parent = std::move(parent); entry.property = property; if (index >= 0) { entry.index = index; } trap.emit(entry); return entry.id; } template std::optional> idOf(const E& entry) { if constexpr (requires { entry.id; }) { return entry.id; } else { return std::nullopt; } } // This method gives a TRAP label for already emitted AST node. // If the AST node was not emitted yet, then the emission is dispatched to a corresponding // visitor (see `visit(T *)` methods below). // clang-format off template requires std::constructible_from TrapLabelOf fetchLabel(const E* e, swift::Type type = {}) { // clang-format on if (!e) { // this will be treated on emission return undefined_label; } if constexpr (std::derived_from) { // canonicalize all VarDecls. For details, see doc of getCanonicalVarDecl if (auto var = llvm::dyn_cast(e)) { e = var->getCanonicalVarDecl(); } } auto& stored = store[e]; if (!stored.valid()) { auto inserted = fetching.insert(e); CODEQL_ASSERT(inserted.second, "detected infinite fetchLabel recursion"); stored = createLabel(e, type); fetching.erase(e); } return TrapLabelOf::unsafeCreateFromUntyped(stored); } // convenience `fetchLabel` overload for `swift::Type` (which is just a wrapper for // `swift::TypeBase*`) TrapLabel fetchLabel(swift::Type t) { return fetchLabel(t.getPointer()); } TrapLabel fetchLabel(swift::ASTNode node) { auto ret = fetchLabelFromUnion(node); if (!ret.valid()) { // TODO to be more useful, we need a generic way of attaching original source location info // to logs, this will come in upcoming work LOG_ERROR("Unable to fetch label for ASTNode"); } return ret; } // clang-format off template requires std::constructible_from TrapLabelOf fetchLabel(const E& e) { // clang-format on return fetchLabel(&e); } // convenience methods for structured C++ creation template auto createEntry(const E& e) { auto found = store.find(&e); CODEQL_ASSERT(found != store.end(), "createEntry called on non-fetched label"); using Tag = ConcreteTrapTagOf; auto label = TrapLabel::unsafeCreateFromUntyped(found->second); if constexpr (IsLocatable) { locationExtractor.attachLocation(sourceManager, e, label); } return TrapClassOf{label}; } // used to create a new entry for entities that should not be cached // an example is swift::Argument, that are created on the fly and thus have no stable pointer template auto createUncachedEntry(const E& e) { using Tag = TrapTagOf; auto label = trap.createTypedLabel(); locationExtractor.attachLocation(sourceManager, &e, label); return TrapClassOf{label}; } // return `std::optional(fetchLabel(arg))` if arg converts to true, otherwise std::nullopt // universal reference `Arg&&` is used to catch both temporary and non-const references, not // for perfect forwarding template auto fetchOptionalLabel(Arg&& arg, Args&&... args) -> std::optional { if (arg) { return fetchLabel(arg, std::forward(args)...); } return std::nullopt; } // map `fetchLabel` on the iterable `arg` // universal reference `Arg&&` is used to catch both temporary and non-const references, not // for perfect forwarding template auto fetchRepeatedLabels(Iterable&& arg) { using Label = decltype(fetchLabel(*arg.begin())); TrapLabelVectorWrapper ret; if constexpr (requires { arg.size(); }) { ret.data.reserve(arg.size()); } for (auto&& e : arg) { ret.data.push_back(fetchLabel(e)); } return ret; } template void emitDebugInfo(const Args&... args) { trap.debug(args...); } void emitComment(swift::Token& comment) { CommentsTrap entry{trap.createTypedLabel(), comment.getRawText().str()}; trap.emit(entry); locationExtractor.attachLocation(sourceManager, comment, entry.id); } protected: void visitPending() { while (!toBeVisited.empty()) { auto [next, type] = toBeVisited.back(); toBeVisited.pop_back(); // TODO: add tracing logs for visited stuff, maybe within the translators? std::visit([this, type = type](const auto* e) { visit(e, type); }, next); } } private: template UntypedTrapLabel createLabel(const E& e, swift::Type type) { if constexpr (requires { name(e); }) { if (auto mangledName = name(e)) { if (shouldVisit(e)) { toBeVisited.emplace_back(e, type); } return trap.createLabel(mangledName); } } // we always need to visit unnamed things toBeVisited.emplace_back(e, type); return trap.createLabel(); } template bool shouldVisit(const E& e) { if constexpr (std::convertible_to) { encounteredModules.insert(e->getModuleContext()); if (bodyEmissionStrategy.shouldEmitDeclBody(*e)) { extractedDeclaration(e); return true; } skippedDeclaration(e); return false; } return true; } void extractedDeclaration(const swift::Decl* decl) { if (isLazyDeclaration(decl)) { state.emittedDeclarations.insert(decl); } } void skippedDeclaration(const swift::Decl* decl) { if (isLazyDeclaration(decl)) { state.pendingDeclarations.insert(decl); } } static bool isLazyDeclaration(const swift::Decl* decl) { swift::ModuleDecl* module = decl->getModuleContext(); return module->isBuiltinModule() || module->getName().str() == "__ObjC" || module->isNonSwiftModule(); } template TrapLabel fetchLabelFromUnion(const llvm::PointerUnion u) { TrapLabel ret{}; // with logical op short-circuiting, this will stop trying on the first successful fetch bool unionCaseFound = (... || fetchLabelFromUnionCase(u, ret)); if (!unionCaseFound) { return undefined_label; } return ret; } template bool fetchLabelFromUnionCase(const llvm::PointerUnion u, TrapLabel& output) { // we rely on the fact that when we extract `ASTNode` instances (which only happens // on `BraceStmt`/`IfConfigDecl` elements), we cannot encounter a standalone `TypeRepr` there, // so we skip this case; extracting `TypeRepr`s here would be problematic as we would not be // able to provide the corresponding type if constexpr (!std::same_as) { if (auto e = u.template dyn_cast()) { output = fetchLabel(e); return true; } } return false; } virtual SwiftMangledName name(const swift::Decl* decl) = 0; virtual SwiftMangledName name(const swift::TypeBase* type) = 0; virtual void visit(const swift::Decl* decl) = 0; virtual void visit(const swift::Stmt* stmt) = 0; virtual void visit(const swift::StmtCondition* cond) = 0; virtual void visit(const swift::StmtConditionElement* cond) = 0; virtual void visit(const swift::PoundAvailableInfo* availability) = 0; virtual void visit(const swift::AvailabilitySpec* spec) = 0; virtual void visit(const swift::CaseLabelItem* item) = 0; virtual void visit(const swift::Expr* expr) = 0; virtual void visit(const swift::Pattern* pattern) = 0; virtual void visit(const swift::TypeRepr* typeRepr, swift::Type type) = 0; virtual void visit(const swift::TypeBase* type) = 0; virtual void visit(const swift::CapturedValue* capture) = 0; virtual void visit(const swift::MacroRoleAttr* attr) = 0; template requires(!std::derived_from) void visit(const T* e, swift::Type) { visit(e); } const swift::SourceManager& sourceManager; SwiftExtractorState& state; TrapDomain& trap; std::unordered_map store; std::unordered_set fetching; std::vector> toBeVisited; SwiftLocationExtractor& locationExtractor; SwiftBodyEmissionStrategy& bodyEmissionStrategy; std::unordered_set encounteredModules; Logger& logger() { static Logger ret{"dispatcher"}; return ret; } }; } // namespace codeql