From b63bd2ad14ebd51e1b39d4bed8b5302413c79ed0 Mon Sep 17 00:00:00 2001 From: Paolo Tranquilli Date: Wed, 7 Aug 2024 11:57:35 +0200 Subject: [PATCH 1/4] Bazel: format `git_lfs_probe.py` --- misc/bazel/internal/git_lfs_probe.py | 83 +++++++++++++++++++--------- 1 file changed, 56 insertions(+), 27 deletions(-) diff --git a/misc/bazel/internal/git_lfs_probe.py b/misc/bazel/internal/git_lfs_probe.py index cecda623adf..09dde110d0f 100755 --- a/misc/bazel/internal/git_lfs_probe.py +++ b/misc/bazel/internal/git_lfs_probe.py @@ -42,24 +42,32 @@ class Endpoint: opts = options() sources = [p.resolve() for p in opts.sources] source_dir = pathlib.Path(os.path.commonpath(src.parent for src in sources)) -source_dir = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True).strip() +source_dir = subprocess.check_output( + ["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True +).strip() def get_env(s, sep="="): ret = {} - for m in re.finditer(fr'(.*?){sep}(.*)', s, re.M): + for m in re.finditer(rf"(.*?){sep}(.*)", s, re.M): ret.setdefault(*m.groups()) return ret def git(*args, **kwargs): - return subprocess.run(("git",) + args, stdout=subprocess.PIPE, text=True, cwd=source_dir, **kwargs).stdout.strip() + return subprocess.run( + ("git",) + args, stdout=subprocess.PIPE, text=True, cwd=source_dir, **kwargs + ).stdout.strip() def get_endpoint(): - lfs_env_items = iter(get_env(subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir)).items()) - endpoint = next(v for k, v in lfs_env_items if k.startswith('Endpoint')) - endpoint, _, _ = endpoint.partition(' ') + lfs_env_items = iter( + get_env( + subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir) + ).items() + ) + endpoint = next(v for k, v in lfs_env_items if k.startswith("Endpoint")) + endpoint, _, _ = endpoint.partition(" ") # only take the ssh endpoint if it follows directly after the first endpoint we found # in a situation like # Endpoint (a)=... @@ -69,21 +77,32 @@ def get_endpoint(): following_key, following_value = next(lfs_env_items, (None, None)) ssh_endpoint = following_value if following_key == " SSH" else None - endpoint = Endpoint(endpoint, { - "Content-Type": "application/vnd.git-lfs+json", - "Accept": "application/vnd.git-lfs+json", - }) + endpoint = Endpoint( + endpoint, + { + "Content-Type": "application/vnd.git-lfs+json", + "Accept": "application/vnd.git-lfs+json", + }, + ) if ssh_endpoint: # see https://github.com/git-lfs/git-lfs/blob/main/docs/api/authentication.md server, _, path = ssh_endpoint.partition(":") - ssh_command = shutil.which(os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh"))) + ssh_command = shutil.which( + os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh")) + ) assert ssh_command, "no ssh command found" - resp = json.loads(subprocess.check_output([ssh_command, - "-oStrictHostKeyChecking=accept-new", - server, - "git-lfs-authenticate", - path, - "download"])) + resp = json.loads( + subprocess.check_output( + [ + ssh_command, + "-oStrictHostKeyChecking=accept-new", + server, + "git-lfs-authenticate", + path, + "download", + ] + ) + ) endpoint.href = resp.get("href", endpoint) endpoint.update_headers(resp.get("header", {})) url = urlparse(endpoint.href) @@ -96,10 +115,18 @@ def get_endpoint(): if "Authorization" not in endpoint.headers: # last chance: use git credentials (possibly backed by a credential helper like the one installed by gh) # see https://git-scm.com/docs/git-credential - credentials = get_env(git("credential", "fill", check=True, - # drop leading / from url.path - input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n")) - auth = base64.b64encode(f'{credentials["username"]}:{credentials["password"]}'.encode()).decode('ascii') + credentials = get_env( + git( + "credential", + "fill", + check=True, + # drop leading / from url.path + input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n", + ) + ) + auth = base64.b64encode( + f'{credentials["username"]}:{credentials["password"]}'.encode() + ).decode("ascii") endpoint.headers["Authorization"] = f"Basic {auth}" return endpoint @@ -129,23 +156,25 @@ def get_locations(objects): ) with urllib.request.urlopen(req) as resp: data = json.load(resp) - assert len(data["objects"]) == len(indexes), f"received {len(data)} objects, expected {len(indexes)}" + assert len(data["objects"]) == len( + indexes + ), f"received {len(data)} objects, expected {len(indexes)}" for i, resp in zip(indexes, data["objects"]): ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}' return ret def get_lfs_object(path): - with open(path, 'rb') as fileobj: + with open(path, "rb") as fileobj: lfs_header = "version https://git-lfs.github.com/spec".encode() actual_header = fileobj.read(len(lfs_header)) sha256 = size = None if lfs_header != actual_header: return None - data = get_env(fileobj.read().decode('ascii'), sep=' ') - assert data['oid'].startswith('sha256:'), f"unknown oid type: {data['oid']}" - _, _, sha256 = data['oid'].partition(':') - size = int(data['size']) + data = get_env(fileobj.read().decode("ascii"), sep=" ") + assert data["oid"].startswith("sha256:"), f"unknown oid type: {data['oid']}" + _, _, sha256 = data["oid"].partition(":") + size = int(data["size"]) return {"oid": sha256, "size": size} From c576a116f5baf6a3b76c7c8a6c4f87673c301acf Mon Sep 17 00:00:00 2001 From: Paolo Tranquilli Date: Wed, 7 Aug 2024 12:38:43 +0200 Subject: [PATCH 2/4] Bazel: make `git_lfs_probe.py` try all available endpoints --- misc/bazel/internal/git_lfs_probe.py | 198 +++++++++++++++------------ 1 file changed, 112 insertions(+), 86 deletions(-) diff --git a/misc/bazel/internal/git_lfs_probe.py b/misc/bazel/internal/git_lfs_probe.py index 09dde110d0f..a6695a05db2 100755 --- a/misc/bazel/internal/git_lfs_probe.py +++ b/misc/bazel/internal/git_lfs_probe.py @@ -7,19 +7,19 @@ For each source file provided as input, this will print: * the sha256 hash, a space character and a transient download link obtained via the LFS protocol otherwise If --hash-only is provided, the transient URL will not be fetched and printed """ - +import dataclasses import sys import pathlib import subprocess import os import shutil import json +import typing import urllib.request from urllib.parse import urlparse import re import base64 from dataclasses import dataclass -from typing import Dict import argparse @@ -32,11 +32,13 @@ def options(): @dataclass class Endpoint: + name: str href: str - headers: Dict[str, str] + ssh: str | None = None + headers: typing.Dict[str, str] = dataclasses.field(default_factory=dict) - def update_headers(self, d: Dict[str, str]): - self.headers.update((k.capitalize(), v) for k, v in d.items()) + def update_headers(self, d: typing.Iterable[tuple[str, str]]): + self.headers.update((k.capitalize(), v) for k, v in d) opts = options() @@ -47,88 +49,107 @@ source_dir = subprocess.check_output( ).strip() -def get_env(s, sep="="): - ret = {} +def get_env(s: str, sep: str = "=") -> typing.Iterable[tuple[str, str]]: for m in re.finditer(rf"(.*?){sep}(.*)", s, re.M): - ret.setdefault(*m.groups()) - return ret + yield m.groups() def git(*args, **kwargs): - return subprocess.run( + proc = subprocess.run( ("git",) + args, stdout=subprocess.PIPE, text=True, cwd=source_dir, **kwargs - ).stdout.strip() - - -def get_endpoint(): - lfs_env_items = iter( - get_env( - subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir) - ).items() ) - endpoint = next(v for k, v in lfs_env_items if k.startswith("Endpoint")) - endpoint, _, _ = endpoint.partition(" ") - # only take the ssh endpoint if it follows directly after the first endpoint we found - # in a situation like - # Endpoint (a)=... - # Endpoint (b)=... - # SSH=... - # we want to ignore the SSH endpoint, as it's not linked to the default (a) endpoint - following_key, following_value = next(lfs_env_items, (None, None)) - ssh_endpoint = following_value if following_key == " SSH" else None + return proc.stdout.strip() if proc.returncode == 0 else None - endpoint = Endpoint( - endpoint, - { + +endpoint_re = re.compile(r"^Endpoint(?: \((.*)\))?$") + + +def get_endpoint_addresses() -> typing.Iterable[Endpoint]: + """Get all lfs endpoints, including SSH if present""" + lfs_env_items = get_env( + subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir) + ) + current_endpoint = None + for k, v in lfs_env_items: + m = endpoint_re.match(k) + if m: + if current_endpoint: + yield current_endpoint + href, _, _ = v.partition(" ") + current_endpoint = Endpoint(name=m[1] or "default", href=href) + elif k == " SSH" and current_endpoint: + current_endpoint.ssh = v + if current_endpoint: + yield current_endpoint + + +def get_endpoints() -> typing.Iterable[Endpoint]: + for endpoint in get_endpoint_addresses(): + endpoint.headers = { "Content-Type": "application/vnd.git-lfs+json", "Accept": "application/vnd.git-lfs+json", - }, - ) - if ssh_endpoint: - # see https://github.com/git-lfs/git-lfs/blob/main/docs/api/authentication.md - server, _, path = ssh_endpoint.partition(":") - ssh_command = shutil.which( - os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh")) - ) - assert ssh_command, "no ssh command found" - resp = json.loads( - subprocess.check_output( - [ - ssh_command, - "-oStrictHostKeyChecking=accept-new", - server, - "git-lfs-authenticate", - path, - "download", - ] + } + if endpoint.ssh: + # see https://github.com/git-lfs/git-lfs/blob/main/docs/api/authentication.md + server, _, path = endpoint.ssh.partition(":") + ssh_command = shutil.which( + os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh")) ) - ) - endpoint.href = resp.get("href", endpoint) - endpoint.update_headers(resp.get("header", {})) - url = urlparse(endpoint.href) - # this is how actions/checkout persist credentials - # see https://github.com/actions/checkout/blob/44c2b7a8a4ea60a981eaca3cf939b5f4305c123b/src/git-auth-helper.ts#L56-L63 - auth = git("config", f"http.{url.scheme}://{url.netloc}/.extraheader") - endpoint.update_headers(get_env(auth, sep=": ")) - if os.environ.get("GITHUB_TOKEN"): - endpoint.headers["Authorization"] = f"token {os.environ['GITHUB_TOKEN']}" - if "Authorization" not in endpoint.headers: - # last chance: use git credentials (possibly backed by a credential helper like the one installed by gh) - # see https://git-scm.com/docs/git-credential - credentials = get_env( - git( + assert ssh_command, "no ssh command found" + cmd = [ + ssh_command, + "-oStrictHostKeyChecking=accept-new", + server, + "git-lfs-authenticate", + path, + "download", + ] + try: + res = subprocess.run(cmd, stdout=subprocess.PIPE, timeout=15) + except subprocess.TimeoutExpired: + print( + f"WARNING: ssh timed out when connecting to {server}, ignoring {endpoint.name} endpoint", + file=sys.stderr, + ) + continue + if res.returncode != 0: + print( + f"WARNING: ssh failed when connecting to {server}, ignoring {endpoint.name} endpoint", + file=sys.stderr, + ) + continue + ssh_resp = json.loads(res.stdout) + endpoint.href = ssh_resp.get("href", endpoint) + endpoint.update_headers(ssh_resp.get("header", {}).items()) + url = urlparse(endpoint.href) + # this is how actions/checkout persist credentials + # see https://github.com/actions/checkout/blob/44c2b7a8a4ea60a981eaca3cf939b5f4305c123b/src/git-auth-helper.ts#L56-L63 + auth = git("config", f"http.{url.scheme}://{url.netloc}/.extraheader") or "" + endpoint.update_headers(get_env(auth, sep=": ")) + if os.environ.get("GITHUB_TOKEN"): + endpoint.headers["Authorization"] = f"token {os.environ['GITHUB_TOKEN']}" + if "Authorization" not in endpoint.headers: + # last chance: use git credentials (possibly backed by a credential helper like the one installed by gh) + # see https://git-scm.com/docs/git-credential + credentials = git( "credential", "fill", check=True, # drop leading / from url.path input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n", ) - ) - auth = base64.b64encode( - f'{credentials["username"]}:{credentials["password"]}'.encode() - ).decode("ascii") - endpoint.headers["Authorization"] = f"Basic {auth}" - return endpoint + if credentials is None: + print( + f"WARNING: no authorization method found, ignoring {data.name} endpoint", + file=sys.stderr, + ) + continue + credentials = dict(get_env(credentials)) + auth = base64.b64encode( + f'{credentials["username"]}:{credentials["password"]}'.encode() + ).decode("ascii") + endpoint.headers["Authorization"] = f"Basic {auth}" + yield endpoint # see https://github.com/git-lfs/git-lfs/blob/310d1b4a7d01e8d9d884447df4635c7a9c7642c2/docs/api/basic-transfers.md @@ -142,26 +163,31 @@ def get_locations(objects): for i in indexes: ret[i] = objects[i]["oid"] return ret - endpoint = get_endpoint() data = { "operation": "download", "transfers": ["basic"], "objects": [objects[i] for i in indexes], "hash_algo": "sha256", } - req = urllib.request.Request( - f"{endpoint.href}/objects/batch", - headers=endpoint.headers, - data=json.dumps(data).encode("ascii"), - ) - with urllib.request.urlopen(req) as resp: - data = json.load(resp) - assert len(data["objects"]) == len( - indexes - ), f"received {len(data)} objects, expected {len(indexes)}" - for i, resp in zip(indexes, data["objects"]): - ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}' - return ret + for endpoint in get_endpoints(): + req = urllib.request.Request( + f"{endpoint.href}/objects/batch", + headers=endpoint.headers, + data=json.dumps(data).encode("ascii"), + ) + try: + with urllib.request.urlopen(req) as resp: + data = json.load(resp) + except urllib.request.HTTPError as e: + print(f"WARNING: encountered HTTPError {e}, ignoring endpoint {e.name}") + continue + assert len(data["objects"]) == len( + indexes + ), f"received {len(data)} objects, expected {len(indexes)}" + for i, resp in zip(indexes, data["objects"]): + ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}' + return ret + raise Exception(f"no valid endpoint found") def get_lfs_object(path): @@ -171,7 +197,7 @@ def get_lfs_object(path): sha256 = size = None if lfs_header != actual_header: return None - data = get_env(fileobj.read().decode("ascii"), sep=" ") + data = dict(get_env(fileobj.read().decode("ascii"), sep=" ")) assert data["oid"].startswith("sha256:"), f"unknown oid type: {data['oid']}" _, _, sha256 = data["oid"].partition(":") size = int(data["size"]) From 58088b62dfefa1f55b92557228d9b35ebd951d13 Mon Sep 17 00:00:00 2001 From: Paolo Tranquilli Date: Wed, 7 Aug 2024 16:46:31 +0200 Subject: [PATCH 3/4] Bazel: make `git_lfs_probe.py` a bit more backward compatible --- misc/bazel/internal/git_lfs_probe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/misc/bazel/internal/git_lfs_probe.py b/misc/bazel/internal/git_lfs_probe.py index a6695a05db2..8b376966540 100755 --- a/misc/bazel/internal/git_lfs_probe.py +++ b/misc/bazel/internal/git_lfs_probe.py @@ -34,10 +34,10 @@ def options(): class Endpoint: name: str href: str - ssh: str | None = None + ssh: typing.Opional[str] = None headers: typing.Dict[str, str] = dataclasses.field(default_factory=dict) - def update_headers(self, d: typing.Iterable[tuple[str, str]]): + def update_headers(self, d: typing.Iterable[typing.Tuple[str, str]]): self.headers.update((k.capitalize(), v) for k, v in d) @@ -49,7 +49,7 @@ source_dir = subprocess.check_output( ).strip() -def get_env(s: str, sep: str = "=") -> typing.Iterable[tuple[str, str]]: +def get_env(s: str, sep: str = "=") -> typing.Iterable[typing.Tuple[str, str]]: for m in re.finditer(rf"(.*?){sep}(.*)", s, re.M): yield m.groups() From e451f2b34364084d7a15f4c3ad0f8c4668cb9868 Mon Sep 17 00:00:00 2001 From: Paolo Tranquilli Date: Wed, 7 Aug 2024 20:54:40 +0200 Subject: [PATCH 4/4] Bazel: fix typo --- misc/bazel/internal/git_lfs_probe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/misc/bazel/internal/git_lfs_probe.py b/misc/bazel/internal/git_lfs_probe.py index 8b376966540..01b201c5486 100755 --- a/misc/bazel/internal/git_lfs_probe.py +++ b/misc/bazel/internal/git_lfs_probe.py @@ -34,7 +34,7 @@ def options(): class Endpoint: name: str href: str - ssh: typing.Opional[str] = None + ssh: typing.Optional[str] = None headers: typing.Dict[str, str] = dataclasses.field(default_factory=dict) def update_headers(self, d: typing.Iterable[typing.Tuple[str, str]]):