Merge branch 'main' into redsun82/rust-extract-libs

This commit is contained in:
Paolo Tranquilli
2025-06-02 15:33:53 +02:00
27 changed files with 2790 additions and 1204 deletions

View File

@@ -66,6 +66,6 @@ jobs:
# Update existing stubs in the repo with the freshly generated ones # Update existing stubs in the repo with the freshly generated ones
mv "$STUBS_PATH/output/stubs/_frameworks" ql/test/resources/stubs/ mv "$STUBS_PATH/output/stubs/_frameworks" ql/test/resources/stubs/
git status git status
codeql test run --threads=0 --search-path "${{ github.workspace }}" --check-databases --check-undefined-labels --check-repeated-labels --check-redefined-labels --consistency-queries ql/consistency-queries -- ql/test/library-tests/dataflow/flowsources/aspremote codeql test run --threads=0 --search-path "${{ github.workspace }}" --check-databases --check-diff-informed --check-undefined-labels --check-repeated-labels --check-redefined-labels --consistency-queries ql/consistency-queries -- ql/test/library-tests/dataflow/flowsources/aspremote
env: env:
GITHUB_TOKEN: ${{ github.token }} GITHUB_TOKEN: ${{ github.token }}

View File

@@ -35,6 +35,6 @@ jobs:
key: ruby-qltest key: ruby-qltest
- name: Run QL tests - name: Run QL tests
run: | run: |
codeql test run --dynamic-join-order-mode=all --threads=0 --ram 50000 --search-path "${{ github.workspace }}" --check-databases --check-undefined-labels --check-unused-labels --check-repeated-labels --check-redefined-labels --check-use-before-definition --consistency-queries ql/consistency-queries ql/test --compilation-cache "${{ steps.query-cache.outputs.cache-dir }}" codeql test run --dynamic-join-order-mode=all --threads=0 --ram 50000 --search-path "${{ github.workspace }}" --check-databases --check-diff-informed --check-undefined-labels --check-unused-labels --check-repeated-labels --check-redefined-labels --check-use-before-definition --consistency-queries ql/consistency-queries ql/test --compilation-cache "${{ steps.query-cache.outputs.cache-dir }}"
env: env:
GITHUB_TOKEN: ${{ github.token }} GITHUB_TOKEN: ${{ github.token }}

View File

@@ -68,6 +68,6 @@ jobs:
key: ruby-qltest key: ruby-qltest
- name: Run QL tests - name: Run QL tests
run: | run: |
codeql test run --threads=0 --ram 50000 --search-path "${{ github.workspace }}" --check-databases --check-undefined-labels --check-unused-labels --check-repeated-labels --check-redefined-labels --check-use-before-definition --consistency-queries ql/consistency-queries ql/test --compilation-cache "${{ steps.query-cache.outputs.cache-dir }}" codeql test run --threads=0 --ram 50000 --search-path "${{ github.workspace }}" --check-databases --check-diff-informed --check-undefined-labels --check-unused-labels --check-repeated-labels --check-redefined-labels --check-use-before-definition --consistency-queries ql/consistency-queries ql/test --compilation-cache "${{ steps.query-cache.outputs.cache-dir }}"
env: env:
GITHUB_TOKEN: ${{ github.token }} GITHUB_TOKEN: ${{ github.token }}

View File

@@ -0,0 +1,9 @@
{
"strategy": "dca",
"language": "cpp",
"targets": [
{ "name": "openssl", "with-sources": false, "with-sinks": false },
{ "name": "sqlite", "with-sources": false, "with-sinks": false }
],
"destination": "cpp/ql/lib/ext/generated"
}

View File

@@ -17,7 +17,7 @@ dependencies:
codeql/xml: ${workspace} codeql/xml: ${workspace}
dataExtensions: dataExtensions:
- ext/*.model.yml - ext/*.model.yml
- ext/generated/*.model.yml - ext/generated/**/*.model.yml
- ext/deallocation/*.model.yml - ext/deallocation/*.model.yml
- ext/allocation/*.model.yml - ext/allocation/*.model.yml
warnOnImplicitThis: true warnOnImplicitThis: true

View File

@@ -54,9 +54,9 @@ ql/lib/go.dbscheme.stats: ql/lib/go.dbscheme build/stats/src.stamp extractor
codeql dataset measure -o $@ build/stats/database/db-go codeql dataset measure -o $@ build/stats/database/db-go
test: all build/testdb/check-upgrade-path test: all build/testdb/check-upgrade-path
codeql test run -j0 ql/test --search-path .. --consistency-queries ql/test/consistency --compilation-cache=$(cache) --dynamic-join-order-mode=$(rtjo) --check-databases --fail-on-trap-errors --check-undefined-labels --check-unused-labels --check-repeated-labels --check-redefined-labels --check-use-before-definition codeql test run -j0 ql/test --search-path .. --check-diff-informed --consistency-queries ql/test/consistency --compilation-cache=$(cache) --dynamic-join-order-mode=$(rtjo) --check-databases --fail-on-trap-errors --check-undefined-labels --check-unused-labels --check-repeated-labels --check-redefined-labels --check-use-before-definition
# use GOOS=linux because GOOS=darwin GOARCH=386 is no longer supported # use GOOS=linux because GOOS=darwin GOARCH=386 is no longer supported
env GOOS=linux GOARCH=386 codeql$(EXE) test run -j0 ql/test/query-tests/Security/CWE-681 --search-path .. --consistency-queries ql/test/consistency --compilation-cache=$(cache) --dynamic-join-order-mode=$(rtjo) env GOOS=linux GOARCH=386 codeql$(EXE) test run -j0 ql/test/query-tests/Security/CWE-681 --search-path .. --check-diff-informed --consistency-queries ql/test/consistency --compilation-cache=$(cache) --dynamic-join-order-mode=$(rtjo)
cd extractor; $(BAZEL) test ... cd extractor; $(BAZEL) test ...
bash extractor-smoke-test/test.sh || (echo "Extractor smoke test FAILED"; exit 1) bash extractor-smoke-test/test.sh || (echo "Extractor smoke test FAILED"; exit 1)

View File

@@ -0,0 +1,500 @@
"""
Experimental script for bulk generation of MaD models based on a list of projects.
Note: This file must be formatted using the Black Python formatter.
"""
import os.path
import subprocess
import sys
from typing import NotRequired, TypedDict, List
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
import argparse
import json
import requests
import zipfile
import tarfile
from functools import cmp_to_key
import generate_mad as mad
# Absolute path to the repository root, resolved via git so the script can be
# invoked from any working directory inside the checkout.
gitroot = (
    subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
    .decode("utf-8")
    .strip()
)

# Scratch directory (under the repo root) where projects are cloned and
# CodeQL databases are built or downloaded.
build_dir = os.path.join(gitroot, "mad-generation-build")
# A project to generate models for
class Project(TypedDict):
"""
Type definition for projects (acquired via a GitHub repo) to model.
Attributes:
name: The name of the project
git_repo: URL to the git repository
git_tag: Optional Git tag to check out
"""
name: str
git_repo: NotRequired[str]
git_tag: NotRequired[str]
with_sinks: NotRequired[bool]
with_sinks: NotRequired[bool]
with_summaries: NotRequired[bool]
def should_generate_sinks(project: Project) -> bool:
    """Whether sink models should be generated for `project`; defaults to True."""
    # Project entries loaded from config use the hyphenated key "with-sinks".
    if "with-sinks" in project:
        return project["with-sinks"]
    return True
def should_generate_sources(project: Project) -> bool:
    """Whether source models should be generated for `project`; defaults to True."""
    # Project entries loaded from config use the hyphenated key "with-sources".
    if "with-sources" in project:
        return project["with-sources"]
    return True
def should_generate_summaries(project: Project) -> bool:
    """Whether summary models should be generated for `project`; defaults to True."""
    # Project entries loaded from config use the hyphenated key "with-summaries".
    if "with-summaries" in project:
        return project["with-summaries"]
    return True
def clone_project(project: Project) -> str:
    """
    Shallow clone a project into the build directory.

    Args:
        project: Project entry with 'name', 'git_repo', and optional 'git_tag' keys.

    Returns:
        The path to the cloned project directory.
    """
    name = project["name"]
    repo_url = project["git_repo"]
    git_tag = project.get("git_tag")

    # The checkout lives directly under the shared build directory.
    target_dir = os.path.join(build_dir, name)

    if os.path.exists(target_dir):
        # Re-use an existing checkout; nothing to do.
        print(f"Skipping cloning {name} as it already exists at {target_dir}")
        return target_dir

    if git_tag:
        print(f"Cloning {name} from {repo_url} at tag {git_tag}")
    else:
        print(f"Cloning {name} from {repo_url}")

    # Shallow clone to keep downloads small; pin to the tag when one is given.
    command = ["git", "clone", "--quiet", "--depth", "1"]
    if git_tag:
        command += ["--branch", git_tag]
    command += [repo_url, target_dir]
    subprocess.check_call(command)

    print(f"Completed cloning {name}")
    return target_dir
def clone_projects(projects: List[Project]) -> List[tuple[Project, str]]:
    """
    Clone all projects in parallel.

    Args:
        projects: List of projects to clone

    Returns:
        List of (project, project_dir) pairs in the same order as the input projects
    """
    start_time = time.time()
    max_workers = min(8, len(projects))  # Use at most 8 threads
    project_dirs_map = {}  # Map to store results by project name
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Start cloning tasks and keep track of them
        future_to_project = {
            executor.submit(clone_project, project): project for project in projects
        }
        # Process results as they complete
        for future in as_completed(future_to_project):
            project = future_to_project[future]
            try:
                project_dir = future.result()
                project_dirs_map[project["name"]] = (project, project_dir)
            except Exception as e:
                # A failed clone is recorded by its absence from the map and
                # handled collectively below.
                print(f"ERROR: Failed to clone {project['name']}: {e}")
    # Any missing entry means a clone failed: report and abort, since later
    # phases assume every project has a checkout.
    if len(project_dirs_map) != len(projects):
        failed_projects = [
            project["name"]
            for project in projects
            if project["name"] not in project_dirs_map
        ]
        print(
            f"ERROR: Only {len(project_dirs_map)} out of {len(projects)} projects were cloned successfully. Failed projects: {', '.join(failed_projects)}"
        )
        sys.exit(1)
    # Rebuild the result list in the caller's original order (as_completed
    # yields results in completion order, not submission order).
    project_dirs = [project_dirs_map[project["name"]] for project in projects]
    clone_time = time.time() - start_time
    print(f"Cloning completed in {clone_time:.2f} seconds")
    return project_dirs
def build_database(
    language: str, extractor_options, project: Project, project_dir: str
) -> str | None:
    """
    Build a CodeQL database for a project.

    Args:
        language: The language for which to build the database (e.g., "rust").
        extractor_options: Additional options for the extractor.
        project: A dictionary containing project information with 'name' and 'git_repo' keys.
        project_dir: The directory containing the project's source code.

    Returns:
        The path to the created database directory, or None if the build failed.
    """
    name = project["name"]
    # Create database directory path
    database_dir = os.path.join(build_dir, f"{name}-db")
    # Only build the database if it doesn't already exist
    if not os.path.exists(database_dir):
        print(f"Building CodeQL database for {name}...")
        # Each configured extractor option becomes its own "-O <option>" pair
        # on the codeql command line.
        extractor_options = [option for x in extractor_options for option in ("-O", x)]
        try:
            subprocess.check_call(
                [
                    "codeql",
                    "database",
                    "create",
                    f"--language={language}",
                    "--source-root=" + project_dir,
                    "--overwrite",
                    *extractor_options,
                    "--",
                    database_dir,
                ]
            )
            print(f"Successfully created database at {database_dir}")
        except subprocess.CalledProcessError as e:
            # A failed build is reported but not fatal here; the caller decides
            # how to handle missing databases.
            print(f"Failed to create database for {name}: {e}")
            return None
    else:
        print(
            f"Skipping database creation for {name} as it already exists at {database_dir}"
        )
    return database_dir
def generate_models(config, project: Project, database_dir: str) -> None:
    """
    Generate models for a project.

    Args:
        config: Configuration dictionary; its "language" key selects the generator.
        project: The project to generate models for.
        database_dir: Path to the CodeQL database.
    """
    name = project["name"]
    language = config["language"]
    generator = mad.Generator(language)
    # Note: The argument parser converts with-sinks to with_sinks, etc.
    generator.generateSinks = should_generate_sinks(project)
    generator.generateSources = should_generate_sources(project)
    generator.generateSummaries = should_generate_summaries(project)
    generator.setenvironment(database=database_dir, folder=name)
    generator.run()
def build_databases_from_projects(
    language: str, extractor_options, projects: List[Project]
) -> List[tuple[Project, str | None]]:
    """
    Build databases for all projects.

    Args:
        language: The language for which to build the databases (e.g., "rust").
        extractor_options: Additional options for the extractor.
        projects: List of projects to build databases for.

    Returns:
        List of (project, database_dir) pairs, where database_dir is None if the build failed.
    """
    # Clone projects in parallel
    print("=== Cloning projects ===")
    project_dirs = clone_projects(projects)
    # Build databases for all projects, one at a time (unlike the clones,
    # database builds are resource-intensive).
    print("\n=== Building databases ===")
    database_results = [
        (
            project,
            build_database(language, extractor_options, project, project_dir),
        )
        for project, project_dir in project_dirs
    ]
    return database_results
def get_json_from_github(
    url: str, pat: str, extra_headers: dict[str, str] = {}
) -> dict:
    """
    Download a JSON document from GitHub, authenticating with a personal
    access token (PAT).

    Args:
        url: The URL to download the JSON file from.
        pat: Personal Access Token for GitHub API authentication.
        extra_headers: Additional headers to include in the request.

    Returns:
        The JSON response as a dictionary; exits the process on any
        non-200 response.
    """
    request_headers = {"Authorization": f"token {pat}"} | extra_headers
    response = requests.get(url, headers=request_headers)
    if response.status_code == 200:
        return response.json()
    print(f"Failed to download JSON: {response.status_code} {response.text}")
    sys.exit(1)
def download_artifact(url: str, artifact_name: str, pat: str) -> str:
    """
    Download a GitHub Actions artifact from a given URL.

    Args:
        url: The URL to download the artifact from.
        artifact_name: The name of the artifact (used for naming the downloaded file).
        pat: Personal Access Token for GitHub API authentication.

    Returns:
        The path to the downloaded artifact file; exits the process on a
        non-200 response.
    """
    headers = {"Authorization": f"token {pat}", "Accept": "application/vnd.github+json"}
    # Stream the response so large artifacts are not buffered in memory.
    response = requests.get(url, stream=True, headers=headers)
    zipName = artifact_name + ".zip"
    if response.status_code != 200:
        print(f"Failed to download file. Status code: {response.status_code}")
        sys.exit(1)
    target_zip = os.path.join(build_dir, zipName)
    with open(target_zip, "wb") as file:
        # Write the artifact to disk in 8 KiB chunks.
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)
    print(f"Download complete: {target_zip}")
    return target_zip
def remove_extension(filename: str) -> str:
    """Strip every extension from `filename` (e.g. "db.tar.gz" -> "db")."""
    root = filename
    # os.path.splitext removes one trailing extension per call; repeat until
    # no dot remains so compound extensions like ".tar.gz" are fully removed.
    while "." in root:
        root, _ = os.path.splitext(root)
    return root
def pretty_name_from_artifact_name(artifact_name: str) -> str:
    """Extract a project's pretty name: the second "___"-separated field of the artifact name."""
    parts = artifact_name.split("___")
    return parts[1]
def download_dca_databases(
    experiment_name: str, pat: str, projects: List[Project]
) -> List[tuple[Project, str | None]]:
    """
    Download databases from a DCA experiment.

    Args:
        experiment_name: The name of the DCA experiment to download databases from.
        pat: Personal Access Token for GitHub API authentication.
        projects: List of projects to download databases for.

    Returns:
        List of (project, database_dir) pairs in the order of `projects`.
    """
    database_results = {}
    print("\n=== Finding projects ===")
    # The experiment's downloads.json report lists, per target, where its
    # analyzed-database artifact can be fetched from.
    response = get_json_from_github(
        f"https://raw.githubusercontent.com/github/codeql-dca-main/data/{experiment_name}/reports/downloads.json",
        pat,
    )
    targets = response["targets"]
    project_map = {project["name"]: project for project in projects}
    for data in targets.values():
        downloads = data["downloads"]
        analyzed_database = downloads["analyzed_database"]
        artifact_name = analyzed_database["artifact_name"]
        pretty_name = pretty_name_from_artifact_name(artifact_name)
        # Ignore experiment targets that are not in our configuration.
        if not pretty_name in project_map:
            print(f"Skipping {pretty_name} as it is not in the list of projects")
            continue
        repository = analyzed_database["repository"]
        run_id = analyzed_database["run_id"]
        print(f"=== Finding artifact: {artifact_name} ===")
        # Resolve the artifact's download URL via the Actions API for the run
        # that produced it.
        response = get_json_from_github(
            f"https://api.github.com/repos/{repository}/actions/runs/{run_id}/artifacts",
            pat,
            {"Accept": "application/vnd.github+json"},
        )
        artifacts = response["artifacts"]
        artifact_map = {artifact["name"]: artifact for artifact in artifacts}
        print(f"=== Downloading artifact: {artifact_name} ===")
        archive_download_url = artifact_map[artifact_name]["archive_download_url"]
        artifact_zip_location = download_artifact(
            archive_download_url, artifact_name, pat
        )
        print(f"=== Extracting artifact: {artifact_name} ===")
        # The database is in a zip file, which contains a tar.gz file with the DB
        # First we open the zip file
        with zipfile.ZipFile(artifact_zip_location, "r") as zip_ref:
            artifact_unzipped_location = os.path.join(build_dir, artifact_name)
            # And then we extract it to build_dir/artifact_name
            zip_ref.extractall(artifact_unzipped_location)
            # And then we iterate over the contents of the extracted directory
            # and extract the tar.gz files inside it
            for entry in os.listdir(artifact_unzipped_location):
                artifact_tar_location = os.path.join(artifact_unzipped_location, entry)
                with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
                    # And we just untar it to the same directory as the zip file
                    tar_ref.extractall(artifact_unzipped_location)
                database_results[pretty_name] = os.path.join(
                    artifact_unzipped_location, remove_extension(entry)
                )
    print(f"\n=== Extracted {len(database_results)} databases ===")
    # NOTE(review): this assumes the experiment produced a database for every
    # configured project; a project missing from `database_results` raises
    # KeyError here — confirm this is the intended failure mode.
    return [(project, database_results[project["name"]]) for project in projects]
def get_mad_destination_for_project(config, name: str) -> str:
    """Return the directory under the configured "destination" where models for `name` are written."""
    destination = config["destination"]
    return os.path.join(destination, name)
def get_strategy(config) -> str:
    """Return the configured generation strategy, lower-cased (e.g. "repo" or "dca")."""
    strategy: str = config["strategy"]
    return strategy.lower()
def main(config, args) -> None:
"""
Main function to handle the bulk generation of MaD models.
Args:
config: Configuration dictionary containing project details and other settings.
args: Command line arguments passed to this script.
"""
projects = config["targets"]
if not "language" in config:
print("ERROR: 'language' key is missing in the configuration file.")
sys.exit(1)
language = config["language"]
# Create build directory if it doesn't exist
if not os.path.exists(build_dir):
os.makedirs(build_dir)
# Check if any of the MaD directories contain working directory changes in git
for project in projects:
mad_dir = get_mad_destination_for_project(config, project["name"])
if os.path.exists(mad_dir):
git_status_output = subprocess.check_output(
["git", "status", "-s", mad_dir], text=True
).strip()
if git_status_output:
print(
f"""ERROR: Working directory changes detected in {mad_dir}.
Before generating new models, the existing models are deleted.
To avoid loss of data, please commit your changes."""
)
sys.exit(1)
database_results = []
match get_strategy(config):
case "repo":
extractor_options = config.get("extractor_options", [])
database_results = build_databases_from_projects(
language, extractor_options, projects
)
case "dca":
experiment_name = args.dca
if experiment_name is None:
print("ERROR: --dca argument is required for DCA strategy")
sys.exit(1)
if args.pat is None:
print("ERROR: --pat argument is required for DCA strategy")
sys.exit(1)
if not os.path.exists(args.pat):
print(f"ERROR: Personal Access Token file '{pat}' does not exist.")
sys.exit(1)
with open(args.pat, "r") as f:
pat = f.read().strip()
database_results = download_dca_databases(
experiment_name, pat, projects
)
# Generate models for all projects
print("\n=== Generating models ===")
failed_builds = [
project["name"] for project, db_dir in database_results if db_dir is None
]
if failed_builds:
print(
f"ERROR: {len(failed_builds)} database builds failed: {', '.join(failed_builds)}"
)
sys.exit(1)
# Delete the MaD directory for each project
for project, database_dir in database_results:
mad_dir = get_mad_destination_for_project(config, project["name"])
if os.path.exists(mad_dir):
print(f"Deleting existing MaD directory at {mad_dir}")
subprocess.check_call(["rm", "-rf", mad_dir])
for project, database_dir in database_results:
if database_dir is not None:
generate_models(config, project, database_dir)
# Entry point: parse CLI arguments, load and validate the JSON configuration
# file, then dispatch to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, help="Path to the configuration file.", required=True
    )
    parser.add_argument(
        "--dca",
        type=str,
        help="Name of a DCA run that built all the projects",
        required=False,
    )
    parser.add_argument(
        "--pat",
        type=str,
        help="Path to a file containing the PAT token required to grab DCA databases (the same as the one you use for DCA)",
        required=False,
    )
    args = parser.parse_args()

    # Load config file
    config = {}
    if not os.path.exists(args.config):
        print(f"ERROR: Config file '{args.config}' does not exist.")
        sys.exit(1)
    try:
        with open(args.config, "r") as f:
            config = json.load(f)
    except json.JSONDecodeError as e:
        # A malformed config is a user error; report it instead of tracebacking.
        print(f"ERROR: Failed to parse JSON file {args.config}: {e}")
        sys.exit(1)
    main(config, args)

View File

@@ -1,335 +0,0 @@
"""
Experimental script for bulk generation of MaD models based on a list of projects.
Currently the script only targets Rust.
"""
import os.path
import subprocess
import sys
from typing import NotRequired, TypedDict, List
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
import generate_mad as mad
gitroot = (
subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
.decode("utf-8")
.strip()
)
build_dir = os.path.join(gitroot, "mad-generation-build")
def path_to_mad_directory(language: str, name: str) -> str:
return os.path.join(gitroot, f"{language}/ql/lib/ext/generated/{name}")
# A project to generate models for
class Project(TypedDict):
"""
Type definition for Rust projects to model.
Attributes:
name: The name of the project
git_repo: URL to the git repository
git_tag: Optional Git tag to check out
"""
name: str
git_repo: str
git_tag: NotRequired[str]
# List of Rust projects to generate models for.
projects: List[Project] = [
{
"name": "libc",
"git_repo": "https://github.com/rust-lang/libc",
"git_tag": "0.2.172",
},
{
"name": "log",
"git_repo": "https://github.com/rust-lang/log",
"git_tag": "0.4.27",
},
{
"name": "memchr",
"git_repo": "https://github.com/BurntSushi/memchr",
"git_tag": "2.7.4",
},
{
"name": "once_cell",
"git_repo": "https://github.com/matklad/once_cell",
"git_tag": "v1.21.3",
},
{
"name": "rand",
"git_repo": "https://github.com/rust-random/rand",
"git_tag": "0.9.1",
},
{
"name": "smallvec",
"git_repo": "https://github.com/servo/rust-smallvec",
"git_tag": "v1.15.0",
},
{
"name": "serde",
"git_repo": "https://github.com/serde-rs/serde",
"git_tag": "v1.0.219",
},
{
"name": "tokio",
"git_repo": "https://github.com/tokio-rs/tokio",
"git_tag": "tokio-1.45.0",
},
{
"name": "reqwest",
"git_repo": "https://github.com/seanmonstar/reqwest",
"git_tag": "v0.12.15",
},
{
"name": "rocket",
"git_repo": "https://github.com/SergioBenitez/Rocket",
"git_tag": "v0.5.1",
},
{
"name": "actix-web",
"git_repo": "https://github.com/actix/actix-web",
"git_tag": "web-v4.11.0",
},
{
"name": "hyper",
"git_repo": "https://github.com/hyperium/hyper",
"git_tag": "v1.6.0",
},
{
"name": "clap",
"git_repo": "https://github.com/clap-rs/clap",
"git_tag": "v4.5.38",
},
]
def clone_project(project: Project) -> str:
"""
Shallow clone a project into the build directory.
Args:
project: A dictionary containing project information with 'name', 'git_repo', and optional 'git_tag' keys.
Returns:
The path to the cloned project directory.
"""
name = project["name"]
repo_url = project["git_repo"]
git_tag = project.get("git_tag")
# Determine target directory
target_dir = os.path.join(build_dir, name)
# Clone only if directory doesn't already exist
if not os.path.exists(target_dir):
if git_tag:
print(f"Cloning {name} from {repo_url} at tag {git_tag}")
else:
print(f"Cloning {name} from {repo_url}")
subprocess.check_call(
[
"git",
"clone",
"--quiet",
"--depth",
"1", # Shallow clone
*(
["--branch", git_tag] if git_tag else []
), # Add branch if tag is provided
repo_url,
target_dir,
]
)
print(f"Completed cloning {name}")
else:
print(f"Skipping cloning {name} as it already exists at {target_dir}")
return target_dir
def clone_projects(projects: List[Project]) -> List[tuple[Project, str]]:
"""
Clone all projects in parallel.
Args:
projects: List of projects to clone
Returns:
List of (project, project_dir) pairs in the same order as the input projects
"""
start_time = time.time()
max_workers = min(8, len(projects)) # Use at most 8 threads
project_dirs_map = {} # Map to store results by project name
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Start cloning tasks and keep track of them
future_to_project = {
executor.submit(clone_project, project): project for project in projects
}
# Process results as they complete
for future in as_completed(future_to_project):
project = future_to_project[future]
try:
project_dir = future.result()
project_dirs_map[project["name"]] = (project, project_dir)
except Exception as e:
print(f"ERROR: Failed to clone {project['name']}: {e}")
if len(project_dirs_map) != len(projects):
failed_projects = [
project["name"]
for project in projects
if project["name"] not in project_dirs_map
]
print(
f"ERROR: Only {len(project_dirs_map)} out of {len(projects)} projects were cloned successfully. Failed projects: {', '.join(failed_projects)}"
)
sys.exit(1)
project_dirs = [project_dirs_map[project["name"]] for project in projects]
clone_time = time.time() - start_time
print(f"Cloning completed in {clone_time:.2f} seconds")
return project_dirs
def build_database(project: Project, project_dir: str) -> str | None:
"""
Build a CodeQL database for a project.
Args:
project: A dictionary containing project information with 'name' and 'git_repo' keys.
project_dir: The directory containing the project source code.
Returns:
The path to the created database directory.
"""
name = project["name"]
# Create database directory path
database_dir = os.path.join(build_dir, f"{name}-db")
# Only build the database if it doesn't already exist
if not os.path.exists(database_dir):
print(f"Building CodeQL database for {name}...")
try:
subprocess.check_call(
[
"codeql",
"database",
"create",
"--language=rust",
"--source-root=" + project_dir,
"--overwrite",
"-O",
"cargo_features='*'",
"--",
database_dir,
]
)
print(f"Successfully created database at {database_dir}")
except subprocess.CalledProcessError as e:
print(f"Failed to create database for {name}: {e}")
return None
else:
print(
f"Skipping database creation for {name} as it already exists at {database_dir}"
)
return database_dir
def generate_models(project: Project, database_dir: str) -> None:
"""
Generate models for a project.
Args:
project: A dictionary containing project information with 'name' and 'git_repo' keys.
project_dir: The directory containing the project source code.
"""
name = project["name"]
generator = mad.Generator("rust")
generator.generateSinks = True
generator.generateSources = True
generator.generateSummaries = True
generator.setenvironment(database=database_dir, folder=name)
generator.run()
def main() -> None:
"""
Process all projects in three distinct phases:
1. Clone projects (in parallel)
2. Build databases for projects
3. Generate models for successful database builds
"""
# Create build directory if it doesn't exist
if not os.path.exists(build_dir):
os.makedirs(build_dir)
# Check if any of the MaD directories contain working directory changes in git
for project in projects:
mad_dir = path_to_mad_directory("rust", project["name"])
if os.path.exists(mad_dir):
git_status_output = subprocess.check_output(
["git", "status", "-s", mad_dir], text=True
).strip()
if git_status_output:
print(
f"""ERROR: Working directory changes detected in {mad_dir}.
Before generating new models, the existing models are deleted.
To avoid loss of data, please commit your changes."""
)
sys.exit(1)
# Phase 1: Clone projects in parallel
print("=== Phase 1: Cloning projects ===")
project_dirs = clone_projects(projects)
# Phase 2: Build databases for all projects
print("\n=== Phase 2: Building databases ===")
database_results = [
(project, build_database(project, project_dir))
for project, project_dir in project_dirs
]
# Phase 3: Generate models for all projects
print("\n=== Phase 3: Generating models ===")
failed_builds = [
project["name"] for project, db_dir in database_results if db_dir is None
]
if failed_builds:
print(
f"ERROR: {len(failed_builds)} database builds failed: {', '.join(failed_builds)}"
)
sys.exit(1)
# Delete the MaD directory for each project
for project, database_dir in database_results:
mad_dir = path_to_mad_directory("rust", project["name"])
if os.path.exists(mad_dir):
print(f"Deleting existing MaD directory at {mad_dir}")
subprocess.check_call(["rm", "-rf", mad_dir])
for project, database_dir in database_results:
if database_dir is not None:
generate_models(project, database_dir)
if __name__ == "__main__":
main()

View File

@@ -65,4 +65,4 @@ extractor: $(FILES) $(BIN_FILES)
cp ../target/release/codeql-extractor-ruby$(EXE) extractor-pack/tools/$(CODEQL_PLATFORM)/extractor$(EXE) cp ../target/release/codeql-extractor-ruby$(EXE) extractor-pack/tools/$(CODEQL_PLATFORM)/extractor$(EXE)
test: extractor dbscheme test: extractor dbscheme
codeql test run --check-databases --check-unused-labels --check-repeated-labels --check-redefined-labels --check-use-before-definition --search-path .. --consistency-queries ql/consistency-queries ql/test codeql test run --check-databases --check-diff-informed --check-unused-labels --check-repeated-labels --check-redefined-labels --check-use-before-definition --search-path .. --consistency-queries ql/consistency-queries ql/test

View File

@@ -255,9 +255,10 @@ fn get_additional_fields(node: &AstNodeSrc) -> Vec<FieldInfo> {
"WhileExpr" => vec![FieldInfo::optional("condition", "Expr")], "WhileExpr" => vec![FieldInfo::optional("condition", "Expr")],
"MatchGuard" => vec![FieldInfo::optional("condition", "Expr")], "MatchGuard" => vec![FieldInfo::optional("condition", "Expr")],
"MacroDef" => vec![ "MacroDef" => vec![
FieldInfo::optional("args", "TokenTree"), FieldInfo::body("args", "TokenTree"),
FieldInfo::optional("body", "TokenTree"), FieldInfo::body("body", "TokenTree"),
], ],
"MacroCall" => vec![FieldInfo::body("token_tree", "TokenTree")],
"FormatArgsExpr" => vec![FieldInfo::list("args", "FormatArgsArg")], "FormatArgsExpr" => vec![FieldInfo::list("args", "FormatArgsArg")],
"ArgList" => vec![FieldInfo::list("args", "Expr")], "ArgList" => vec![FieldInfo::list("args", "Expr")],
"Fn" => vec![FieldInfo::body("body", "BlockExpr")], "Fn" => vec![FieldInfo::body("body", "BlockExpr")],
@@ -295,7 +296,7 @@ fn get_fields(node: &AstNodeSrc) -> Vec<FieldInfo> {
match (node.name.as_str(), name.as_str()) { match (node.name.as_str(), name.as_str()) {
("ArrayExpr", "expr") // The ArrayExpr type also has an 'exprs' field ("ArrayExpr", "expr") // The ArrayExpr type also has an 'exprs' field
| ("PathSegment", "ty" | "path_type") // these are broken, handling them manually | ("PathSegment", "ty" | "path_type") // these are broken, handling them manually
| ("Param", "pat") // handled manually to use `body` | ("Param", "pat") | ("MacroCall", "token_tree") // handled manually to use `body`
=> continue, => continue,
_ => {} _ => {}
} }

View File

@@ -45,9 +45,10 @@ options:
cargo_features: cargo_features:
title: Cargo features to turn on title: Cargo features to turn on
description: > description: >
Comma-separated list of features to turn on. If any value is `*` all features Comma-separated list of features to turn on. By default all features are enabled.
are turned on. By default only default cargo features are enabled. Can be If any features are specified, then only those features are enabled. The `default`
repeated. feature must be explicitly specified if only default features are desired.
Can be repeated.
type: array type: array
cargo_cfg_overrides: cargo_cfg_overrides:
title: Cargo cfg overrides title: Cargo cfg overrides

View File

@@ -129,6 +129,23 @@ impl Config {
} }
} }
fn cargo_features(&self) -> CargoFeatures {
// '*' is to be considered deprecated but still kept in for backward compatibility
if self.cargo_features.is_empty() || self.cargo_features.iter().any(|f| f == "*") {
CargoFeatures::All
} else {
CargoFeatures::Selected {
features: self
.cargo_features
.iter()
.filter(|f| *f != "default")
.cloned()
.collect(),
no_default_features: !self.cargo_features.iter().any(|f| f == "default"),
}
}
}
pub fn to_cargo_config(&self, dir: &AbsPath) -> (CargoConfig, LoadCargoConfig) { pub fn to_cargo_config(&self, dir: &AbsPath) -> (CargoConfig, LoadCargoConfig) {
let sysroot = self.sysroot(dir); let sysroot = self.sysroot(dir);
( (
@@ -159,16 +176,7 @@ impl Config {
.unwrap_or_else(|| self.scratch_dir.join("target")), .unwrap_or_else(|| self.scratch_dir.join("target")),
) )
.ok(), .ok(),
features: if self.cargo_features.is_empty() { features: self.cargo_features(),
Default::default()
} else if self.cargo_features.contains(&"*".to_string()) {
CargoFeatures::All
} else {
CargoFeatures::Selected {
features: self.cargo_features.clone(),
no_default_features: false,
}
},
target: self.cargo_target.clone(), target: self.cargo_target.clone(),
cfg_overrides: to_cfg_overrides(&self.cargo_cfg_overrides), cfg_overrides: to_cfg_overrides(&self.cargo_cfg_overrides),
wrap_rustc_in_build_scripts: false, wrap_rustc_in_build_scripts: false,

View File

@@ -317,7 +317,6 @@ fn main() -> anyhow::Result<()> {
.source_root(db) .source_root(db)
.is_library .is_library
{ {
tracing::info!("file: {}", file.display());
extractor.extract_with_semantics( extractor.extract_with_semantics(
file, file,
&semantics, &semantics,

View File

@@ -25,7 +25,9 @@ use ra_ap_syntax::{
#[macro_export] #[macro_export]
macro_rules! pre_emit { macro_rules! pre_emit {
(Item, $self:ident, $node:ident) => { (Item, $self:ident, $node:ident) => {
$self.setup_item_expansion($node); if let Some(label) = $self.prepare_item_expansion($node) {
return Some(label);
}
}; };
($($_:tt)*) => {}; ($($_:tt)*) => {};
} }
@@ -688,52 +690,75 @@ impl<'a> Translator<'a> {
} }
} }
pub(crate) fn setup_item_expansion(&mut self, node: &ast::Item) { pub(crate) fn prepare_item_expansion(
if self.semantics.is_some_and(|s| { &mut self,
let file = s.hir_file_for(node.syntax()); node: &ast::Item,
let node = InFile::new(file, node); ) -> Option<Label<generated::Item>> {
s.is_attr_macro_call(node) if self.source_kind == SourceKind::Library {
}) { // if the item expands via an attribute macro, we want to only emit the expansion
if let Some(expanded) = self.emit_attribute_macro_expansion(node) {
// we wrap it in a dummy MacroCall to get a single Item label that can replace
// the original Item
let label = self.trap.emit(generated::MacroCall {
id: TrapId::Star,
attrs: vec![],
path: None,
token_tree: None,
});
generated::MacroCall::emit_macro_call_expansion(
label,
expanded.into(),
&mut self.trap.writer,
);
return Some(label.into());
}
}
let semantics = self.semantics.as_ref()?;
let file = semantics.hir_file_for(node.syntax());
let node = InFile::new(file, node);
if semantics.is_attr_macro_call(node) {
self.macro_context_depth += 1; self.macro_context_depth += 1;
} }
None
}
fn emit_attribute_macro_expansion(
&mut self,
node: &ast::Item,
) -> Option<Label<generated::MacroItems>> {
let semantics = self.semantics?;
let file = semantics.hir_file_for(node.syntax());
let infile_node = InFile::new(file, node);
if !semantics.is_attr_macro_call(infile_node) {
return None;
}
self.macro_context_depth -= 1;
if self.macro_context_depth > 0 {
// only expand the outermost attribute macro
return None;
}
let ExpandResult {
value: expanded, ..
} = semantics.expand_attr_macro(node)?;
self.emit_macro_expansion_parse_errors(node, &expanded);
let macro_items = ast::MacroItems::cast(expanded).or_else(|| {
let message = "attribute macro expansion cannot be cast to MacroItems".to_owned();
let location = self.location_for_node(node);
self.emit_diagnostic(
DiagnosticSeverity::Warning,
"item_expansion".to_owned(),
message.clone(),
message,
location.unwrap_or(UNKNOWN_LOCATION),
);
None
})?;
self.emit_macro_items(&macro_items)
} }
pub(crate) fn emit_item_expansion(&mut self, node: &ast::Item, label: Label<generated::Item>) { pub(crate) fn emit_item_expansion(&mut self, node: &ast::Item, label: Label<generated::Item>) {
// TODO: remove this after fixing exponential expansion on libraries like funty-2.0.0 if let Some(expanded) = self.emit_attribute_macro_expansion(node) {
if self.source_kind == SourceKind::Library {
return;
}
(|| {
let semantics = self.semantics?;
let file = semantics.hir_file_for(node.syntax());
let infile_node = InFile::new(file, node);
if !semantics.is_attr_macro_call(infile_node) {
return None;
}
self.macro_context_depth -= 1;
if self.macro_context_depth > 0 {
// only expand the outermost attribute macro
return None;
}
let ExpandResult {
value: expanded, ..
} = semantics.expand_attr_macro(node)?;
self.emit_macro_expansion_parse_errors(node, &expanded);
let macro_items = ast::MacroItems::cast(expanded).or_else(|| {
let message = "attribute macro expansion cannot be cast to MacroItems".to_owned();
let location = self.location_for_node(node);
self.emit_diagnostic(
DiagnosticSeverity::Warning,
"item_expansion".to_owned(),
message.clone(),
message,
location.unwrap_or(UNKNOWN_LOCATION),
);
None
})?;
let expanded = self.emit_macro_items(&macro_items)?;
generated::Item::emit_attribute_macro_expansion(label, expanded, &mut self.trap.writer); generated::Item::emit_attribute_macro_expansion(label, expanded, &mut self.trap.writer);
Some(()) }
})();
} }
} }

View File

@@ -1627,7 +1627,11 @@ impl Translator<'_> {
} }
let attrs = node.attrs().filter_map(|x| self.emit_attr(&x)).collect(); let attrs = node.attrs().filter_map(|x| self.emit_attr(&x)).collect();
let path = node.path().and_then(|x| self.emit_path(&x)); let path = node.path().and_then(|x| self.emit_path(&x));
let token_tree = node.token_tree().and_then(|x| self.emit_token_tree(&x)); let token_tree = if self.should_skip_bodies() {
None
} else {
node.token_tree().and_then(|x| self.emit_token_tree(&x))
};
let label = self.trap.emit(generated::MacroCall { let label = self.trap.emit(generated::MacroCall {
id: TrapId::Star, id: TrapId::Star,
attrs, attrs,
@@ -1647,9 +1651,17 @@ impl Translator<'_> {
if self.should_be_excluded(node) { if self.should_be_excluded(node) {
return None; return None;
} }
let args = node.args().and_then(|x| self.emit_token_tree(&x)); let args = if self.should_skip_bodies() {
None
} else {
node.args().and_then(|x| self.emit_token_tree(&x))
};
let attrs = node.attrs().filter_map(|x| self.emit_attr(&x)).collect(); let attrs = node.attrs().filter_map(|x| self.emit_attr(&x)).collect();
let body = node.body().and_then(|x| self.emit_token_tree(&x)); let body = if self.should_skip_bodies() {
None
} else {
node.body().and_then(|x| self.emit_token_tree(&x))
};
let name = node.name().and_then(|x| self.emit_name(&x)); let name = node.name().and_then(|x| self.emit_name(&x));
let visibility = node.visibility().and_then(|x| self.emit_visibility(&x)); let visibility = node.visibility().and_then(|x| self.emit_visibility(&x));
let label = self.trap.emit(generated::MacroDef { let label = self.trap.emit(generated::MacroDef {

View File

@@ -0,0 +1,75 @@
{
"strategy": "repo",
"language": "rust",
"targets": [
{
"name": "libc",
"git_repo": "https://github.com/rust-lang/libc",
"git_tag": "0.2.172"
},
{
"name": "log",
"git_repo": "https://github.com/rust-lang/log",
"git_tag": "0.4.27"
},
{
"name": "memchr",
"git_repo": "https://github.com/BurntSushi/memchr",
"git_tag": "2.7.4"
},
{
"name": "once_cell",
"git_repo": "https://github.com/matklad/once_cell",
"git_tag": "v1.21.3"
},
{
"name": "rand",
"git_repo": "https://github.com/rust-random/rand",
"git_tag": "0.9.1"
},
{
"name": "smallvec",
"git_repo": "https://github.com/servo/rust-smallvec",
"git_tag": "v1.15.0"
},
{
"name": "serde",
"git_repo": "https://github.com/serde-rs/serde",
"git_tag": "v1.0.219"
},
{
"name": "tokio",
"git_repo": "https://github.com/tokio-rs/tokio",
"git_tag": "tokio-1.45.0"
},
{
"name": "reqwest",
"git_repo": "https://github.com/seanmonstar/reqwest",
"git_tag": "v0.12.15"
},
{
"name": "rocket",
"git_repo": "https://github.com/SergioBenitez/Rocket",
"git_tag": "v0.5.1"
},
{
"name": "actix-web",
"git_repo": "https://github.com/actix/actix-web",
"git_tag": "web-v4.11.0"
},
{
"name": "hyper",
"git_repo": "https://github.com/hyperium/hyper",
"git_tag": "v1.6.0"
},
{
"name": "clap",
"git_repo": "https://github.com/clap-rs/clap",
"git_tag": "v4.5.38"
}
],
"destination": "rust/ql/lib/ext/generated",
"extractor_options": [
"cargo_features='*'"
]
}

View File

@@ -1,5 +1,6 @@
import pytest import pytest
@pytest.mark.ql_test(expected=".all.expected")
def test_default(codeql, rust): def test_default(codeql, rust):
codeql.database.create() codeql.database.create()
@@ -8,10 +9,33 @@ def test_default(codeql, rust):
pytest.param(p, pytest.param(p,
marks=pytest.mark.ql_test(expected=f".{e}.expected")) marks=pytest.mark.ql_test(expected=f".{e}.expected"))
for p, e in ( for p, e in (
("default", "none"),
("foo", "foo"), ("foo", "foo"),
("bar", "bar"), ("bar", "bar"),
("*", "all"), ("*", "all"),
("foo,bar", "all")) ("foo,bar", "all"),
("default,foo", "foo"),
("default,bar", "bar"),
)
]) ])
def test_features(codeql, rust, features): def test_features(codeql, rust, features):
codeql.database.create(extractor_option=f"cargo_features={features}") codeql.database.create(extractor_option=f"cargo_features={features}")
@pytest.mark.parametrize("features",
[
pytest.param(p,
marks=pytest.mark.ql_test(expected=f".{e}.expected"))
for p, e in (
("default", "foo"),
("foo", "foo"),
("bar", "bar"),
("*", "all"),
("foo,bar", "all"),
("default,foo", "foo"),
("default,bar", "all"),
)
])
def test_features_with_default(codeql, rust, features):
with open("Cargo.toml", "a") as f:
print('default = ["foo"]', file=f)
codeql.database.create(extractor_option=f"cargo_features={features}")

View File

@@ -28,6 +28,10 @@ module Impl {
override string getOperatorName() { result = Generated::BinaryExpr.super.getOperatorName() } override string getOperatorName() { result = Generated::BinaryExpr.super.getOperatorName() }
override Expr getAnOperand() { result = [this.getLhs(), this.getRhs()] } override Expr getOperand(int n) {
n = 0 and result = this.getLhs()
or
n = 1 and result = this.getRhs()
}
} }
} }

View File

@@ -7,6 +7,78 @@
private import rust private import rust
private import codeql.rust.elements.internal.ExprImpl::Impl as ExprImpl private import codeql.rust.elements.internal.ExprImpl::Impl as ExprImpl
/**
* Holds if the operator `op` is overloaded to a trait with the canonical path
* `path` and the method name `method`.
*/
private predicate isOverloaded(string op, string path, string method) {
// Negation
op = "-" and path = "core::ops::arith::Neg" and method = "neg"
or
// Not
op = "!" and path = "core::ops::bit::Not" and method = "not"
or
// Dereference
op = "*" and path = "core::ops::Deref" and method = "deref"
or
// Comparison operators
op = "==" and path = "core::cmp::PartialEq" and method = "eq"
or
op = "!=" and path = "core::cmp::PartialEq" and method = "ne"
or
op = "<" and path = "core::cmp::PartialOrd" and method = "lt"
or
op = "<=" and path = "core::cmp::PartialOrd" and method = "le"
or
op = ">" and path = "core::cmp::PartialOrd" and method = "gt"
or
op = ">=" and path = "core::cmp::PartialOrd" and method = "ge"
or
// Arithmetic operators
op = "+" and path = "core::ops::arith::Add" and method = "add"
or
op = "-" and path = "core::ops::arith::Sub" and method = "sub"
or
op = "*" and path = "core::ops::arith::Mul" and method = "mul"
or
op = "/" and path = "core::ops::arith::Div" and method = "div"
or
op = "%" and path = "core::ops::arith::Rem" and method = "rem"
or
// Arithmetic assignment expressions
op = "+=" and path = "core::ops::arith::AddAssign" and method = "add_assign"
or
op = "-=" and path = "core::ops::arith::SubAssign" and method = "sub_assign"
or
op = "*=" and path = "core::ops::arith::MulAssign" and method = "mul_assign"
or
op = "/=" and path = "core::ops::arith::DivAssign" and method = "div_assign"
or
op = "%=" and path = "core::ops::arith::RemAssign" and method = "rem_assign"
or
// Bitwise operators
op = "&" and path = "core::ops::bit::BitAnd" and method = "bitand"
or
op = "|" and path = "core::ops::bit::BitOr" and method = "bitor"
or
op = "^" and path = "core::ops::bit::BitXor" and method = "bitxor"
or
op = "<<" and path = "core::ops::bit::Shl" and method = "shl"
or
op = ">>" and path = "core::ops::bit::Shr" and method = "shr"
or
// Bitwise assignment operators
op = "&=" and path = "core::ops::bit::BitAndAssign" and method = "bitand_assign"
or
op = "|=" and path = "core::ops::bit::BitOrAssign" and method = "bitor_assign"
or
op = "^=" and path = "core::ops::bit::BitXorAssign" and method = "bitxor_assign"
or
op = "<<=" and path = "core::ops::bit::ShlAssign" and method = "shl_assign"
or
op = ">>=" and path = "core::ops::bit::ShrAssign" and method = "shr_assign"
}
/** /**
* INTERNAL: This module contains the customizable definition of `Operation` and should not * INTERNAL: This module contains the customizable definition of `Operation` and should not
* be referenced directly. * be referenced directly.
@@ -16,14 +88,28 @@ module Impl {
* An operation, for example `&&`, `+=`, `!` or `*`. * An operation, for example `&&`, `+=`, `!` or `*`.
*/ */
abstract class Operation extends ExprImpl::Expr { abstract class Operation extends ExprImpl::Expr {
/** /** Gets the operator name of this operation, if it exists. */
* Gets the operator name of this operation, if it exists.
*/
abstract string getOperatorName(); abstract string getOperatorName();
/** Gets the `n`th operand of this operation, if any. */
abstract Expr getOperand(int n);
/** /**
* Gets an operand of this operation. * Gets the number of operands of this operation.
*
* This is either 1 for prefix operations, or 2 for binary operations.
*/ */
abstract Expr getAnOperand(); final int getNumberOfOperands() { result = strictcount(this.getAnOperand()) }
/** Gets an operand of this operation. */
Expr getAnOperand() { result = this.getOperand(_) }
/**
* Holds if this operation is overloaded to the method `methodName` of the
* trait `trait`.
*/
predicate isOverloaded(Trait trait, string methodName) {
isOverloaded(this.getOperatorName(), trait.getCanonicalPath(), methodName)
}
} }
} }

View File

@@ -26,6 +26,6 @@ module Impl {
override string getOperatorName() { result = Generated::PrefixExpr.super.getOperatorName() } override string getOperatorName() { result = Generated::PrefixExpr.super.getOperatorName() }
override Expr getAnOperand() { result = this.getExpr() } override Expr getOperand(int n) { n = 0 and result = this.getExpr() }
} }
} }

View File

@@ -29,7 +29,7 @@ module Impl {
override string getOperatorName() { result = "&" } override string getOperatorName() { result = "&" }
override Expr getAnOperand() { result = this.getExpr() } override Expr getOperand(int n) { n = 0 and result = this.getExpr() }
private string getSpecPart(int index) { private string getSpecPart(int index) {
index = 0 and this.isRaw() and result = "raw" index = 0 and this.isRaw() and result = "raw"

View File

@@ -643,12 +643,22 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl
class Access extends CallExprBase { abstract class Access extends Expr {
abstract Type getTypeArgument(TypeArgumentPosition apos, TypePath path);
abstract AstNode getNodeAt(AccessPosition apos);
abstract Type getInferredType(AccessPosition apos, TypePath path);
abstract Declaration getTarget();
}
private class CallExprBaseAccess extends Access instanceof CallExprBase {
private TypeMention getMethodTypeArg(int i) { private TypeMention getMethodTypeArg(int i) {
result = this.(MethodCallExpr).getGenericArgList().getTypeArg(i) result = this.(MethodCallExpr).getGenericArgList().getTypeArg(i)
} }
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
exists(TypeMention arg | result = arg.resolveTypeAt(path) | exists(TypeMention arg | result = arg.resolveTypeAt(path) |
arg = getExplicitTypeArgMention(CallExprImpl::getFunctionPath(this), apos.asTypeParam()) arg = getExplicitTypeArgMention(CallExprImpl::getFunctionPath(this), apos.asTypeParam())
or or
@@ -656,7 +666,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
) )
} }
AstNode getNodeAt(AccessPosition apos) { override AstNode getNodeAt(AccessPosition apos) {
exists(int p, boolean isMethodCall | exists(int p, boolean isMethodCall |
argPos(this, result, p, isMethodCall) and argPos(this, result, p, isMethodCall) and
apos = TPositionalAccessPosition(p, isMethodCall) apos = TPositionalAccessPosition(p, isMethodCall)
@@ -669,17 +679,42 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
apos = TReturnAccessPosition() apos = TReturnAccessPosition()
} }
Type getInferredType(AccessPosition apos, TypePath path) { override Type getInferredType(AccessPosition apos, TypePath path) {
result = inferType(this.getNodeAt(apos), path) result = inferType(this.getNodeAt(apos), path)
} }
Declaration getTarget() { override Declaration getTarget() {
result = CallExprImpl::getResolvedFunction(this) result = CallExprImpl::getResolvedFunction(this)
or or
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
} }
} }
private class OperationAccess extends Access instanceof Operation {
OperationAccess() { super.isOverloaded(_, _) }
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
// The syntax for operators does not allow type arguments.
none()
}
override AstNode getNodeAt(AccessPosition apos) {
result = super.getOperand(0) and apos = TSelfAccessPosition()
or
result = super.getOperand(1) and apos = TPositionalAccessPosition(0, true)
or
result = this and apos = TReturnAccessPosition()
}
override Type getInferredType(AccessPosition apos, TypePath path) {
result = inferType(this.getNodeAt(apos), path)
}
override Declaration getTarget() {
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
}
}
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) { predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
apos.isSelf() and apos.isSelf() and
dpos.isSelf() dpos.isSelf()
@@ -1059,6 +1094,26 @@ private module MethodCall {
pragma[nomagic] pragma[nomagic]
override Type getTypeAt(TypePath path) { result = inferType(receiver, path) } override Type getTypeAt(TypePath path) { result = inferType(receiver, path) }
} }
private class OperationMethodCall extends MethodCallImpl instanceof Operation {
TraitItemNode trait;
string methodName;
OperationMethodCall() { super.isOverloaded(trait, methodName) }
override string getMethodName() { result = methodName }
override int getArity() { result = this.(Operation).getNumberOfOperands() - 1 }
override Trait getTrait() { result = trait }
pragma[nomagic]
override Type getTypeAt(TypePath path) {
result = inferType(this.(BinaryExpr).getLhs(), path)
or
result = inferType(this.(PrefixExpr).getExpr(), path)
}
}
} }
import MethodCall import MethodCall

View File

@@ -765,11 +765,12 @@ mod method_supertraits {
} }
trait MyTrait2<Tr2>: MyTrait1<Tr2> { trait MyTrait2<Tr2>: MyTrait1<Tr2> {
#[rustfmt::skip]
fn m2(self) -> Tr2 fn m2(self) -> Tr2
where where
Self: Sized, Self: Sized,
{ {
if 1 + 1 > 2 { if 3 > 2 { // $ method=gt
self.m1() // $ method=MyTrait1::m1 self.m1() // $ method=MyTrait1::m1
} else { } else {
Self::m1(self) Self::m1(self)
@@ -778,11 +779,12 @@ mod method_supertraits {
} }
trait MyTrait3<Tr3>: MyTrait2<MyThing<Tr3>> { trait MyTrait3<Tr3>: MyTrait2<MyThing<Tr3>> {
#[rustfmt::skip]
fn m3(self) -> Tr3 fn m3(self) -> Tr3
where where
Self: Sized, Self: Sized,
{ {
if 1 + 1 > 2 { if 3 > 2 { // $ method=gt
self.m2().a // $ method=m2 $ fieldof=MyThing self.m2().a // $ method=m2 $ fieldof=MyThing
} else { } else {
Self::m2(self).a // $ fieldof=MyThing Self::m2(self).a // $ fieldof=MyThing
@@ -1024,21 +1026,24 @@ mod option_methods {
let x6 = MyOption::MySome(MyOption::<S>::MyNone()); let x6 = MyOption::MySome(MyOption::<S>::MyNone());
println!("{:?}", MyOption::<MyOption<S>>::flatten(x6)); println!("{:?}", MyOption::<MyOption<S>>::flatten(x6));
let from_if = if 1 + 1 > 2 { #[rustfmt::skip]
let from_if = if 3 > 2 { // $ method=gt
MyOption::MyNone() MyOption::MyNone()
} else { } else {
MyOption::MySome(S) MyOption::MySome(S)
}; };
println!("{:?}", from_if); println!("{:?}", from_if);
let from_match = match 1 + 1 > 2 { #[rustfmt::skip]
let from_match = match 3 > 2 { // $ method=gt
true => MyOption::MyNone(), true => MyOption::MyNone(),
false => MyOption::MySome(S), false => MyOption::MySome(S),
}; };
println!("{:?}", from_match); println!("{:?}", from_match);
#[rustfmt::skip]
let from_loop = loop { let from_loop = loop {
if 1 + 1 > 2 { if 3 > 2 { // $ method=gt
break MyOption::MyNone(); break MyOption::MyNone();
} }
break MyOption::MySome(S); break MyOption::MySome(S);
@@ -1240,7 +1245,7 @@ mod builtins {
pub fn f() { pub fn f() {
let x: i32 = 1; // $ type=x:i32 let x: i32 = 1; // $ type=x:i32
let y = 2; // $ type=y:i32 let y = 2; // $ type=y:i32
let z = x + y; // $ MISSING: type=z:i32 let z = x + y; // $ type=z:i32 method=add
let z = x.abs(); // $ method=abs $ type=z:i32 let z = x.abs(); // $ method=abs $ type=z:i32
let c = 'c'; // $ type=c:char let c = 'c'; // $ type=c:char
let hello = "Hello"; // $ type=hello:str let hello = "Hello"; // $ type=hello:str
@@ -1250,13 +1255,15 @@ mod builtins {
} }
} }
// Tests for non-overloaded operators.
mod operators { mod operators {
pub fn f() { pub fn f() {
let x = true && false; // $ type=x:bool let x = true && false; // $ type=x:bool
let y = true || false; // $ type=y:bool let y = true || false; // $ type=y:bool
let mut a; let mut a;
if 34 == 33 { let cond = 34 == 33; // $ method=eq
if cond {
let z = (a = 1); // $ type=z:() type=a:i32 let z = (a = 1); // $ type=z:() type=a:i32
} else { } else {
a = 2; // $ type=a:i32 a = 2; // $ type=a:i32
@@ -1265,6 +1272,364 @@ mod operators {
} }
} }
// Tests for overloaded operators.
mod overloadable_operators {
use std::ops::*;
// A vector type with overloaded operators.
#[derive(Debug, Copy, Clone)]
struct Vec2 {
x: i64,
y: i64,
}
// Implement all overloadable operators for Vec2
impl Add for Vec2 {
type Output = Self;
// Vec2::add
fn add(self, rhs: Self) -> Self {
Vec2 {
x: self.x + rhs.x, // $ fieldof=Vec2 method=add
y: self.y + rhs.y, // $ fieldof=Vec2 method=add
}
}
}
impl AddAssign for Vec2 {
// Vec2::add_assign
#[rustfmt::skip]
fn add_assign(&mut self, rhs: Self) {
self.x += rhs.x; // $ fieldof=Vec2 method=add_assign
self.y += rhs.y; // $ fieldof=Vec2 method=add_assign
}
}
impl Sub for Vec2 {
type Output = Self;
// Vec2::sub
fn sub(self, rhs: Self) -> Self {
Vec2 {
x: self.x - rhs.x, // $ fieldof=Vec2 method=sub
y: self.y - rhs.y, // $ fieldof=Vec2 method=sub
}
}
}
impl SubAssign for Vec2 {
// Vec2::sub_assign
#[rustfmt::skip]
fn sub_assign(&mut self, rhs: Self) {
self.x -= rhs.x; // $ fieldof=Vec2 method=sub_assign
self.y -= rhs.y; // $ fieldof=Vec2 method=sub_assign
}
}
impl Mul for Vec2 {
type Output = Self;
// Vec2::mul
fn mul(self, rhs: Self) -> Self {
Vec2 {
x: self.x * rhs.x, // $ fieldof=Vec2 method=mul
y: self.y * rhs.y, // $ fieldof=Vec2 method=mul
}
}
}
impl MulAssign for Vec2 {
// Vec2::mul_assign
fn mul_assign(&mut self, rhs: Self) {
self.x *= rhs.x; // $ fieldof=Vec2 method=mul_assign
self.y *= rhs.y; // $ fieldof=Vec2 method=mul_assign
}
}
impl Div for Vec2 {
type Output = Self;
// Vec2::div
fn div(self, rhs: Self) -> Self {
Vec2 {
x: self.x / rhs.x, // $ fieldof=Vec2 method=div
y: self.y / rhs.y, // $ fieldof=Vec2 method=div
}
}
}
impl DivAssign for Vec2 {
// Vec2::div_assign
fn div_assign(&mut self, rhs: Self) {
self.x /= rhs.x; // $ fieldof=Vec2 method=div_assign
self.y /= rhs.y; // $ fieldof=Vec2 method=div_assign
}
}
impl Rem for Vec2 {
type Output = Self;
// Vec2::rem
fn rem(self, rhs: Self) -> Self {
Vec2 {
x: self.x % rhs.x, // $ fieldof=Vec2 method=rem
y: self.y % rhs.y, // $ fieldof=Vec2 method=rem
}
}
}
impl RemAssign for Vec2 {
// Vec2::rem_assign
fn rem_assign(&mut self, rhs: Self) {
self.x %= rhs.x; // $ fieldof=Vec2 method=rem_assign
self.y %= rhs.y; // $ fieldof=Vec2 method=rem_assign
}
}
impl BitAnd for Vec2 {
type Output = Self;
// Vec2::bitand
fn bitand(self, rhs: Self) -> Self {
Vec2 {
x: self.x & rhs.x, // $ fieldof=Vec2 method=bitand
y: self.y & rhs.y, // $ fieldof=Vec2 method=bitand
}
}
}
impl BitAndAssign for Vec2 {
// Vec2::bitand_assign
fn bitand_assign(&mut self, rhs: Self) {
self.x &= rhs.x; // $ fieldof=Vec2 method=bitand_assign
self.y &= rhs.y; // $ fieldof=Vec2 method=bitand_assign
}
}
impl BitOr for Vec2 {
type Output = Self;
// Vec2::bitor
fn bitor(self, rhs: Self) -> Self {
Vec2 {
x: self.x | rhs.x, // $ fieldof=Vec2 method=bitor
y: self.y | rhs.y, // $ fieldof=Vec2 method=bitor
}
}
}
impl BitOrAssign for Vec2 {
// Vec2::bitor_assign
fn bitor_assign(&mut self, rhs: Self) {
self.x |= rhs.x; // $ fieldof=Vec2 method=bitor_assign
self.y |= rhs.y; // $ fieldof=Vec2 method=bitor_assign
}
}
impl BitXor for Vec2 {
type Output = Self;
// Vec2::bitxor
fn bitxor(self, rhs: Self) -> Self {
Vec2 {
x: self.x ^ rhs.x, // $ fieldof=Vec2 method=bitxor
y: self.y ^ rhs.y, // $ fieldof=Vec2 method=bitxor
}
}
}
impl BitXorAssign for Vec2 {
// Vec2::bitxor_assign
fn bitxor_assign(&mut self, rhs: Self) {
self.x ^= rhs.x; // $ fieldof=Vec2 method=bitxor_assign
self.y ^= rhs.y; // $ fieldof=Vec2 method=bitxor_assign
}
}
impl Shl<u32> for Vec2 {
type Output = Self;
// Vec2::shl
fn shl(self, rhs: u32) -> Self {
Vec2 {
x: self.x << rhs, // $ fieldof=Vec2 method=shl
y: self.y << rhs, // $ fieldof=Vec2 method=shl
}
}
}
impl ShlAssign<u32> for Vec2 {
// Vec2::shl_assign
fn shl_assign(&mut self, rhs: u32) {
self.x <<= rhs; // $ fieldof=Vec2 method=shl_assign
self.y <<= rhs; // $ fieldof=Vec2 method=shl_assign
}
}
impl Shr<u32> for Vec2 {
type Output = Self;
// Vec2::shr
fn shr(self, rhs: u32) -> Self {
Vec2 {
x: self.x >> rhs, // $ fieldof=Vec2 method=shr
y: self.y >> rhs, // $ fieldof=Vec2 method=shr
}
}
}
impl ShrAssign<u32> for Vec2 {
// Vec2::shr_assign
fn shr_assign(&mut self, rhs: u32) {
self.x >>= rhs; // $ fieldof=Vec2 method=shr_assign
self.y >>= rhs; // $ fieldof=Vec2 method=shr_assign
}
}
impl Neg for Vec2 {
type Output = Self;
// Vec2::neg
fn neg(self) -> Self {
Vec2 {
x: -self.x, // $ fieldof=Vec2 method=neg
y: -self.y, // $ fieldof=Vec2 method=neg
}
}
}
impl Not for Vec2 {
type Output = Self;
// Vec2::not
fn not(self) -> Self {
Vec2 {
x: !self.x, // $ fieldof=Vec2 method=not
y: !self.y, // $ fieldof=Vec2 method=not
}
}
}
impl PartialEq for Vec2 {
// Vec2::eq
fn eq(&self, other: &Self) -> bool {
self.x == other.x && self.y == other.y // $ fieldof=Vec2 method=eq
}
// Vec2::ne
fn ne(&self, other: &Self) -> bool {
self.x != other.x || self.y != other.y // $ fieldof=Vec2 method=ne
}
}
impl PartialOrd for Vec2 {
// Vec2::partial_cmp
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
(self.x + self.y).partial_cmp(&(other.x + other.y)) // $ fieldof=Vec2 method=partial_cmp method=add
}
// Vec2::lt
fn lt(&self, other: &Self) -> bool {
self.x < other.x && self.y < other.y // $ fieldof=Vec2 method=lt
}
// Vec2::le
fn le(&self, other: &Self) -> bool {
self.x <= other.x && self.y <= other.y // $ fieldof=Vec2 method=le
}
// Vec2::gt
fn gt(&self, other: &Self) -> bool {
self.x > other.x && self.y > other.y // $ fieldof=Vec2 method=gt
}
// Vec2::ge
fn ge(&self, other: &Self) -> bool {
self.x >= other.x && self.y >= other.y // $ fieldof=Vec2 method=ge
}
}
pub fn f() {
// Test for all overloadable operators on `i64`
// Comparison operators
let i64_eq = (1i64 == 2i64); // $ type=i64_eq:bool method=eq
let i64_ne = (3i64 != 4i64); // $ type=i64_ne:bool method=ne
let i64_lt = (5i64 < 6i64); // $ type=i64_lt:bool method=lt
let i64_le = (7i64 <= 8i64); // $ type=i64_le:bool method=le
let i64_gt = (9i64 > 10i64); // $ type=i64_gt:bool method=gt
let i64_ge = (11i64 >= 12i64); // $ type=i64_ge:bool method=ge
// Arithmetic operators
let i64_add = 13i64 + 14i64; // $ type=i64_add:i64 method=add
let i64_sub = 15i64 - 16i64; // $ type=i64_sub:i64 method=sub
let i64_mul = 17i64 * 18i64; // $ type=i64_mul:i64 method=mul
let i64_div = 19i64 / 20i64; // $ type=i64_div:i64 method=div
let i64_rem = 21i64 % 22i64; // $ type=i64_rem:i64 method=rem
// Arithmetic assignment operators
let mut i64_add_assign = 23i64;
i64_add_assign += 24i64; // $ method=add_assign
let mut i64_sub_assign = 25i64;
i64_sub_assign -= 26i64; // $ method=sub_assign
let mut i64_mul_assign = 27i64;
i64_mul_assign *= 28i64; // $ method=mul_assign
let mut i64_div_assign = 29i64;
i64_div_assign /= 30i64; // $ method=div_assign
let mut i64_rem_assign = 31i64;
i64_rem_assign %= 32i64; // $ method=rem_assign
// Bitwise operators
let i64_bitand = 33i64 & 34i64; // $ type=i64_bitand:i64 method=bitand
let i64_bitor = 35i64 | 36i64; // $ type=i64_bitor:i64 method=bitor
let i64_bitxor = 37i64 ^ 38i64; // $ type=i64_bitxor:i64 method=bitxor
let i64_shl = 39i64 << 40i64; // $ type=i64_shl:i64 method=shl
let i64_shr = 41i64 >> 42i64; // $ type=i64_shr:i64 method=shr
// Bitwise assignment operators
let mut i64_bitand_assign = 43i64;
i64_bitand_assign &= 44i64; // $ method=bitand_assign
let mut i64_bitor_assign = 45i64;
i64_bitor_assign |= 46i64; // $ method=bitor_assign
let mut i64_bitxor_assign = 47i64;
i64_bitxor_assign ^= 48i64; // $ method=bitxor_assign
let mut i64_shl_assign = 49i64;
i64_shl_assign <<= 50i64; // $ method=shl_assign
let mut i64_shr_assign = 51i64;
i64_shr_assign >>= 52i64; // $ method=shr_assign
let i64_neg = -53i64; // $ type=i64_neg:i64 method=neg
let i64_not = !54i64; // $ type=i64_not:i64 method=not
// Test for all overloadable operators on Vec2
let v1 = Vec2 { x: 1, y: 2 };
let v2 = Vec2 { x: 3, y: 4 };
// Comparison operators
let vec2_eq = v1 == v2; // $ type=vec2_eq:bool method=Vec2::eq
let vec2_ne = v1 != v2; // $ type=vec2_ne:bool method=Vec2::ne
let vec2_lt = v1 < v2; // $ type=vec2_lt:bool method=Vec2::lt
let vec2_le = v1 <= v2; // $ type=vec2_le:bool method=Vec2::le
let vec2_gt = v1 > v2; // $ type=vec2_gt:bool method=Vec2::gt
let vec2_ge = v1 >= v2; // $ type=vec2_ge:bool method=Vec2::ge
// Arithmetic operators
let vec2_add = v1 + v2; // $ type=vec2_add:Vec2 method=Vec2::add
let vec2_sub = v1 - v2; // $ type=vec2_sub:Vec2 method=Vec2::sub
let vec2_mul = v1 * v2; // $ type=vec2_mul:Vec2 method=Vec2::mul
let vec2_div = v1 / v2; // $ type=vec2_div:Vec2 method=Vec2::div
let vec2_rem = v1 % v2; // $ type=vec2_rem:Vec2 method=Vec2::rem
// Arithmetic assignment operators
let mut vec2_add_assign = v1;
vec2_add_assign += v2; // $ method=Vec2::add_assign
let mut vec2_sub_assign = v1;
vec2_sub_assign -= v2; // $ method=Vec2::sub_assign
let mut vec2_mul_assign = v1;
vec2_mul_assign *= v2; // $ method=Vec2::mul_assign
let mut vec2_div_assign = v1;
vec2_div_assign /= v2; // $ method=Vec2::div_assign
let mut vec2_rem_assign = v1;
vec2_rem_assign %= v2; // $ method=Vec2::rem_assign
// Bitwise operators
let vec2_bitand = v1 & v2; // $ type=vec2_bitand:Vec2 method=Vec2::bitand
let vec2_bitor = v1 | v2; // $ type=vec2_bitor:Vec2 method=Vec2::bitor
let vec2_bitxor = v1 ^ v2; // $ type=vec2_bitxor:Vec2 method=Vec2::bitxor
let vec2_shl = v1 << 1u32; // $ type=vec2_shl:Vec2 method=Vec2::shl
let vec2_shr = v1 >> 1u32; // $ type=vec2_shr:Vec2 method=Vec2::shr
// Bitwise assignment operators
let mut vec2_bitand_assign = v1;
vec2_bitand_assign &= v2; // $ method=Vec2::bitand_assign
let mut vec2_bitor_assign = v1;
vec2_bitor_assign |= v2; // $ method=Vec2::bitor_assign
let mut vec2_bitxor_assign = v1;
vec2_bitxor_assign ^= v2; // $ method=Vec2::bitxor_assign
let mut vec2_shl_assign = v1;
vec2_shl_assign <<= 1u32; // $ method=Vec2::shl_assign
let mut vec2_shr_assign = v1;
vec2_shr_assign >>= 1u32; // $ method=Vec2::shr_assign
// Prefix operators
let vec2_neg = -v1; // $ type=vec2_neg:Vec2 method=Vec2::neg
let vec2_not = !v1; // $ type=vec2_not:Vec2 method=Vec2::not
}
}
fn main() { fn main() {
field_access::f(); field_access::f();
method_impl::f(); method_impl::f();