Bazel: improved lazy lfs files

This reintroduces lazy lfs file rules that were removed in
https://github.com/github/codeql/pull/16117, now improved.

The new rules will make the actual file download go through bazel's
download manager, which includes:
* caching into the repository cache
* sane limiting of concurrent downloads
* retries

The bulk of the work is done by `git_lfs_probe.py`, which uses the
LFS protocol (with authentication via SSH) to output short-lived
download URLs that can be consumed by `repository_ctx.download`.
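
For illustration, usage might look roughly like this (a minimal
sketch: the repository names and file labels are hypothetical, not
part of this change):

    load("//misc/bazel:lfs.bzl", "lfs_archive", "lfs_files")

    lfs_files(
        name = "my-lfs-files",
        srcs = ["//some/pkg:large.bin"],
    )

    lfs_archive(
        name = "my-lfs-archive",
        src = "//some/pkg:bundle.tar.gz",
        strip_prefix = "bundle",
        build_file = "//some/pkg:bundle.BUILD.bazel",
    )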
Paolo Tranquilli
2024-04-30 09:00:14 +02:00
parent a8f2cbc2b1
commit 677520aa8e
4 changed files with 196 additions and 0 deletions

.lfsconfig Normal file

@@ -0,0 +1,5 @@
[lfs]
# codeql is publicly forked by many users, and we don't want any LFS file polluting their working
# copies. We therefore exclude everything by default.
# For files required by bazel builds, use rules in `misc/bazel/lfs.bzl` to download them on demand.
fetchinclude = /nothing
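
With fetching excluded, LFS-tracked files stay as plain pointer files in
working copies. For reference, a pointer file follows the standard LFS
format (the hash and size here are made up); this is exactly what
`git_lfs_probe.py` parses below:

    version https://git-lfs.github.com/spec/v1
    oid sha256:9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08
    size 1048576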

misc/bazel/internal/git_lfs_probe.py Normal file

@@ -0,0 +1,112 @@
#!/usr/bin/env python3

"""
Probe LFS files.
For each source file provided as argument, this will print:
* "local", if the source file is not an LFS pointer
* the sha256 hash, a space character and a transient download link obtained via the LFS protocol otherwise
"""

import sys
import pathlib
import subprocess
import os
import shutil
import json
import urllib.request
from urllib.parse import urlparse
import re

sources = [pathlib.Path(arg).resolve() for arg in sys.argv[1:]]
source_dir = pathlib.Path(os.path.commonpath(src.parent for src in sources))
source_dir = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True).strip()

def get_endpoint():
    lfs_env = subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir)
    endpoint = ssh_server = ssh_path = None
    endpoint_re = re.compile(r'Endpoint(?: \(\S+\))?=(\S+)')
    ssh_re = re.compile(r'\s*SSH=(\S*):(.*)')
    for line in lfs_env.splitlines():
        m = endpoint_re.match(line)
        if m:
            if endpoint is None:
                endpoint = m[1]
            else:
                break
        m = ssh_re.match(line)
        if m:
            ssh_server, ssh_path = m.groups()
            break
    assert endpoint, f"no Endpoint= line found in git lfs env:\n{lfs_env}"
    headers = {
        "Content-Type": "application/vnd.git-lfs+json",
        "Accept": "application/vnd.git-lfs+json",
    }
    if ssh_server:
        ssh_command = shutil.which(os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh")))
        assert ssh_command, "no ssh command found"
        with subprocess.Popen([ssh_command, ssh_server, "git-lfs-authenticate", ssh_path, "download"],
                              stdout=subprocess.PIPE) as ssh:
            resp = json.load(ssh.stdout)
        assert ssh.wait() == 0, "ssh command failed"
        endpoint = resp.get("href", endpoint)
        for k, v in resp.get("header", {}).items():
            headers[k.capitalize()] = v
    url = urlparse(endpoint)
    # this is how actions/checkout persists credentials
    # see https://github.com/actions/checkout/blob/44c2b7a8a4ea60a981eaca3cf939b5f4305c123b/src/git-auth-helper.ts#L56-L63
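    # the persisted value typically has the shape "AUTHORIZATION: basic <base64 credentials>"
    # (an assumption based on the actions/checkout code linked above, shown for illustration)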
    auth = subprocess.run(["git", "config", f"http.{url.scheme}://{url.netloc}/.extraheader"], text=True,
                          stdout=subprocess.PIPE, cwd=source_dir).stdout.strip()
    for l in auth.splitlines():
        k, _, v = l.partition(": ")
        headers[k.capitalize()] = v
    if "GITHUB_TOKEN" in os.environ:
        headers["Authorization"] = f"token {os.environ['GITHUB_TOKEN']}"
    return endpoint, headers


# see https://github.com/git-lfs/git-lfs/blob/310d1b4a7d01e8d9d884447df4635c7a9c7642c2/docs/api/basic-transfers.md
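# For illustration, the exchange has roughly this shape (OIDs, sizes and URLs made up):
#   request:  {"operation": "download", "transfers": ["basic"],
#              "objects": [{"oid": "9f86d0...", "size": 1048576}], "hash_algo": "sha256"}
#   response: {"objects": [{"oid": "9f86d0...", "size": 1048576,
#                           "actions": {"download": {"href": "https://example.com/some/transient/url"}}}]}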
def get_locations(objects):
    href, headers = get_endpoint()
    indexes = [i for i, o in enumerate(objects) if o]
    ret = ["local" for _ in objects]
    req = urllib.request.Request(
        f"{href}/objects/batch",
        headers=headers,
        data=json.dumps({
            "operation": "download",
            "transfers": ["basic"],
            "objects": [o for o in objects if o],
            "hash_algo": "sha256",
        }).encode("ascii"),
    )
    with urllib.request.urlopen(req) as resp:
        data = json.load(resp)
    assert len(data["objects"]) == len(indexes), data
    for i, resp in zip(indexes, data["objects"]):
        ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}'
    return ret


def get_lfs_object(path):
    with open(path, 'rb') as fileobj:
        lfs_header = "version https://git-lfs.github.com/spec".encode()
        actual_header = fileobj.read(len(lfs_header))
        sha256 = size = None
        if lfs_header != actual_header:
            return None
        for line in fileobj:
            line = line.decode('ascii').strip()
            if line.startswith("oid sha256:"):
                sha256 = line[len("oid sha256:"):]
            elif line.startswith("size "):
                size = int(line[len("size "):])
        if not (sha256 and size):
            raise Exception("malformed pointer file")
        return {"oid": sha256, "size": size}


objects = [get_lfs_object(src) for src in sources]
for resp in get_locations(objects):
    print(resp)
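
A hypothetical run of the script (file names, hash and URL are
illustrative only):

    $ python3 git_lfs_probe.py resources/big.tar.gz README.md
    9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08 https://example.com/some/transient/url
    local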

misc/bazel/lfs.bzl Normal file

@@ -0,0 +1,79 @@
def lfs_smudge(repository_ctx, srcs):
    for src in srcs:
        repository_ctx.watch(src)
    script = Label("//misc/bazel/internal:git_lfs_probe.py")
    python = repository_ctx.which("python3") or repository_ctx.which("python")
    if not python:
        fail("neither python3 nor python executable found")
    res = repository_ctx.execute([python, script] + srcs, quiet = True)
    if res.return_code != 0:
        fail("git LFS probing failed while instantiating @%s:\n%s" % (repository_ctx.name, res.stderr))
    for src, loc in zip(srcs, res.stdout.splitlines()):
        if loc == "local":
            repository_ctx.symlink(src, src.basename)
        else:
            sha256, _, url = loc.partition(" ")
            repository_ctx.download(url, src.basename, sha256 = sha256)

def _download_and_extract_lfs(repository_ctx):
    attr = repository_ctx.attr
    src = repository_ctx.path(attr.src)
    if attr.build_file_content and attr.build_file:
        fail("Only one of build_file_content and build_file may be specified for rule @%s" % repository_ctx.name)
    lfs_smudge(repository_ctx, [src])
    repository_ctx.extract(src.basename, stripPrefix = attr.strip_prefix)
    repository_ctx.delete(src.basename)
    if attr.build_file_content:
        repository_ctx.file("BUILD.bazel", attr.build_file_content)
    elif attr.build_file:
        repository_ctx.symlink(attr.build_file, "BUILD.bazel")

def _download_lfs(repository_ctx):
    attr = repository_ctx.attr
    if int(bool(attr.srcs)) + int(bool(attr.dir)) != 1:
        fail("Exactly one of `srcs` and `dir` must be defined for @%s" % repository_ctx.name)
    if attr.srcs:
        srcs = [repository_ctx.path(src) for src in attr.srcs]
    else:
        dir = repository_ctx.path(attr.dir)
        if not dir.is_dir:
            fail("`dir` is not a directory in @%s" % repository_ctx.name)
        srcs = [f for f in dir.readdir() if not f.is_dir]
    lfs_smudge(repository_ctx, srcs)

    # with bzlmod the name is qualified with `~` separators, and we want the base name here
    name = repository_ctx.name.split("~")[-1]
    repository_ctx.file("BUILD.bazel", """
exports_files({files})

filegroup(
    name = "{name}",
    srcs = {files},
    visibility = ["//visibility:public"],
)
""".format(name = name, files = repr([src.basename for src in srcs])))
lfs_archive = repository_rule(
    doc = "Export the contents of an on-demand LFS archive. The corresponding path should be excluded from LFS " +
          "fetching in `.lfsconfig`.",
    implementation = _download_and_extract_lfs,
    attrs = {
        "src": attr.label(mandatory = True, doc = "Local path to the LFS archive to extract."),
        "build_file_content": attr.string(doc = "The content for the BUILD file for this repository. " +
                                                "Either build_file or build_file_content can be specified, but not both."),
        "build_file": attr.label(doc = "The file to use as the BUILD file for this repository. " +
                                       "Either build_file or build_file_content can be specified, but not both."),
        "strip_prefix": attr.string(default = "", doc = "A directory prefix to strip from the extracted files."),
    },
)

lfs_files = repository_rule(
    doc = "Export LFS files for on-demand download. Exactly one of `srcs` and `dir` must be defined. The " +
          "corresponding paths should be excluded from LFS fetching in `.lfsconfig`.",
    implementation = _download_lfs,
    attrs = {
        "srcs": attr.label_list(doc = "Local paths to the LFS files to export."),
        "dir": attr.label(doc = "Local path to a directory containing LFS files to export. Only the direct contents " +
                                "of the directory are exported."),
    },
)