diff --git a/misc/bazel/internal/git_lfs_probe.py b/misc/bazel/internal/git_lfs_probe.py index cecda623adf..01b201c5486 100755 --- a/misc/bazel/internal/git_lfs_probe.py +++ b/misc/bazel/internal/git_lfs_probe.py @@ -7,19 +7,19 @@ For each source file provided as input, this will print: * the sha256 hash, a space character and a transient download link obtained via the LFS protocol otherwise If --hash-only is provided, the transient URL will not be fetched and printed """ - +import dataclasses import sys import pathlib import subprocess import os import shutil import json +import typing import urllib.request from urllib.parse import urlparse import re import base64 from dataclasses import dataclass -from typing import Dict import argparse @@ -32,76 +32,124 @@ def options(): @dataclass class Endpoint: + name: str href: str - headers: Dict[str, str] + ssh: typing.Optional[str] = None + headers: typing.Dict[str, str] = dataclasses.field(default_factory=dict) - def update_headers(self, d: Dict[str, str]): - self.headers.update((k.capitalize(), v) for k, v in d.items()) + def update_headers(self, d: typing.Iterable[typing.Tuple[str, str]]): + self.headers.update((k.capitalize(), v) for k, v in d) opts = options() sources = [p.resolve() for p in opts.sources] source_dir = pathlib.Path(os.path.commonpath(src.parent for src in sources)) -source_dir = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True).strip() +source_dir = subprocess.check_output( + ["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True +).strip() -def get_env(s, sep="="): - ret = {} - for m in re.finditer(fr'(.*?){sep}(.*)', s, re.M): - ret.setdefault(*m.groups()) - return ret +def get_env(s: str, sep: str = "=") -> typing.Iterable[typing.Tuple[str, str]]: + for m in re.finditer(rf"(.*?){sep}(.*)", s, re.M): + yield m.groups() def git(*args, **kwargs): - return subprocess.run(("git",) + args, stdout=subprocess.PIPE, text=True, cwd=source_dir, **kwargs).stdout.strip() + proc = subprocess.run( + ("git",) + args, stdout=subprocess.PIPE, text=True, cwd=source_dir, **kwargs + ) + return proc.stdout.strip() if proc.returncode == 0 else None -def get_endpoint(): - lfs_env_items = iter(get_env(subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir)).items()) - endpoint = next(v for k, v in lfs_env_items if k.startswith('Endpoint')) - endpoint, _, _ = endpoint.partition(' ') - # only take the ssh endpoint if it follows directly after the first endpoint we found - # in a situation like - # Endpoint (a)=... - # Endpoint (b)=... - # SSH=... - # we want to ignore the SSH endpoint, as it's not linked to the default (a) endpoint - following_key, following_value = next(lfs_env_items, (None, None)) - ssh_endpoint = following_value if following_key == " SSH" else None +endpoint_re = re.compile(r"^Endpoint(?: \((.*)\))?$") - endpoint = Endpoint(endpoint, { - "Content-Type": "application/vnd.git-lfs+json", - "Accept": "application/vnd.git-lfs+json", - }) - if ssh_endpoint: - # see https://github.com/git-lfs/git-lfs/blob/main/docs/api/authentication.md - server, _, path = ssh_endpoint.partition(":") - ssh_command = shutil.which(os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh"))) - assert ssh_command, "no ssh command found" - resp = json.loads(subprocess.check_output([ssh_command, - "-oStrictHostKeyChecking=accept-new", - server, - "git-lfs-authenticate", - path, - "download"])) - endpoint.href = resp.get("href", endpoint) - endpoint.update_headers(resp.get("header", {})) - url = urlparse(endpoint.href) - # this is how actions/checkout persist credentials - # see https://github.com/actions/checkout/blob/44c2b7a8a4ea60a981eaca3cf939b5f4305c123b/src/git-auth-helper.ts#L56-L63 - auth = git("config", f"http.{url.scheme}://{url.netloc}/.extraheader") - endpoint.update_headers(get_env(auth, sep=": ")) - if os.environ.get("GITHUB_TOKEN"): - endpoint.headers["Authorization"] = f"token {os.environ['GITHUB_TOKEN']}" - if "Authorization" not in endpoint.headers: - # last chance: use git credentials (possibly backed by a credential helper like the one installed by gh) - # see https://git-scm.com/docs/git-credential - credentials = get_env(git("credential", "fill", check=True, - # drop leading / from url.path - input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n")) - auth = base64.b64encode(f'{credentials["username"]}:{credentials["password"]}'.encode()).decode('ascii') - endpoint.headers["Authorization"] = f"Basic {auth}" - return endpoint + +def get_endpoint_addresses() -> typing.Iterable[Endpoint]: + """Get all lfs endpoints, including SSH if present""" + lfs_env_items = get_env( + subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir) + ) + current_endpoint = None + for k, v in lfs_env_items: + m = endpoint_re.match(k) + if m: + if current_endpoint: + yield current_endpoint + href, _, _ = v.partition(" ") + current_endpoint = Endpoint(name=m[1] or "default", href=href) + elif k == " SSH" and current_endpoint: + current_endpoint.ssh = v + if current_endpoint: + yield current_endpoint + + +def get_endpoints() -> typing.Iterable[Endpoint]: + for endpoint in get_endpoint_addresses(): + endpoint.headers = { + "Content-Type": "application/vnd.git-lfs+json", + "Accept": "application/vnd.git-lfs+json", + } + if endpoint.ssh: + # see https://github.com/git-lfs/git-lfs/blob/main/docs/api/authentication.md + server, _, path = endpoint.ssh.partition(":") + ssh_command = shutil.which( + os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh")) + ) + assert ssh_command, "no ssh command found" + cmd = [ + ssh_command, + "-oStrictHostKeyChecking=accept-new", + server, + "git-lfs-authenticate", + path, + "download", + ] + try: + res = subprocess.run(cmd, stdout=subprocess.PIPE, timeout=15) + except subprocess.TimeoutExpired: + print( + f"WARNING: ssh timed out when connecting to {server}, ignoring {endpoint.name} endpoint", + file=sys.stderr, + ) + continue + if res.returncode != 0: + print( + f"WARNING: ssh failed when connecting to {server}, ignoring {endpoint.name} endpoint", + file=sys.stderr, + ) + continue + ssh_resp = json.loads(res.stdout) + endpoint.href = ssh_resp.get("href", endpoint) + endpoint.update_headers(ssh_resp.get("header", {}).items()) + url = urlparse(endpoint.href) + # this is how actions/checkout persist credentials + # see https://github.com/actions/checkout/blob/44c2b7a8a4ea60a981eaca3cf939b5f4305c123b/src/git-auth-helper.ts#L56-L63 + auth = git("config", f"http.{url.scheme}://{url.netloc}/.extraheader") or "" + endpoint.update_headers(get_env(auth, sep=": ")) + if os.environ.get("GITHUB_TOKEN"): + endpoint.headers["Authorization"] = f"token {os.environ['GITHUB_TOKEN']}" + if "Authorization" not in endpoint.headers: + # last chance: use git credentials (possibly backed by a credential helper like the one installed by gh) + # see https://git-scm.com/docs/git-credential + credentials = git( + "credential", + "fill", + check=True, + # drop leading / from url.path + input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n", + ) + if credentials is None: + print( + f"WARNING: no authorization method found, ignoring {data.name} endpoint", + file=sys.stderr, + ) + continue + credentials = dict(get_env(credentials)) + auth = base64.b64encode( + f'{credentials["username"]}:{credentials["password"]}'.encode() + ).decode("ascii") + endpoint.headers["Authorization"] = f"Basic {auth}" + yield endpoint # see https://github.com/git-lfs/git-lfs/blob/310d1b4a7d01e8d9d884447df4635c7a9c7642c2/docs/api/basic-transfers.md @@ -115,37 +163,44 @@ def get_locations(objects): for i in indexes: ret[i] = objects[i]["oid"] return ret - endpoint = get_endpoint() data = { "operation": "download", "transfers": ["basic"], "objects": [objects[i] for i in indexes], "hash_algo": "sha256", } - req = urllib.request.Request( - f"{endpoint.href}/objects/batch", - headers=endpoint.headers, - data=json.dumps(data).encode("ascii"), - ) - with urllib.request.urlopen(req) as resp: - data = json.load(resp) - assert len(data["objects"]) == len(indexes), f"received {len(data)} objects, expected {len(indexes)}" - for i, resp in zip(indexes, data["objects"]): - ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}' - return ret + for endpoint in get_endpoints(): + req = urllib.request.Request( + f"{endpoint.href}/objects/batch", + headers=endpoint.headers, + data=json.dumps(data).encode("ascii"), + ) + try: + with urllib.request.urlopen(req) as resp: + data = json.load(resp) + except urllib.request.HTTPError as e: + print(f"WARNING: encountered HTTPError {e}, ignoring endpoint {e.name}") + continue + assert len(data["objects"]) == len( + indexes + ), f"received {len(data)} objects, expected {len(indexes)}" + for i, resp in zip(indexes, data["objects"]): + ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}' + return ret + raise Exception(f"no valid endpoint found") def get_lfs_object(path): - with open(path, 'rb') as fileobj: + with open(path, "rb") as fileobj: lfs_header = "version https://git-lfs.github.com/spec".encode() actual_header = fileobj.read(len(lfs_header)) sha256 = size = None if lfs_header != actual_header: return None - data = get_env(fileobj.read().decode('ascii'), sep=' ') - assert data['oid'].startswith('sha256:'), f"unknown oid type: {data['oid']}" - _, _, sha256 = data['oid'].partition(':') - size = int(data['size']) + data = dict(get_env(fileobj.read().decode("ascii"), sep=" ")) + assert data["oid"].startswith("sha256:"), f"unknown oid type: {data['oid']}" + _, _, sha256 = data["oid"].partition(":") + size = int(data["size"]) return {"oid": sha256, "size": size}