Bazel: make git_lfs_probe.py try all available endpoints

This commit is contained in:
Paolo Tranquilli
2024-08-07 12:38:43 +02:00
parent b63bd2ad14
commit c576a116f5

View File

@@ -7,19 +7,19 @@ For each source file provided as input, this will print:
* the sha256 hash, a space character and a transient download link obtained via the LFS protocol otherwise
If --hash-only is provided, the transient URL will not be fetched and printed
"""
import dataclasses
import sys
import pathlib
import subprocess
import os
import shutil
import json
import typing
import urllib.request
from urllib.parse import urlparse
import re
import base64
from dataclasses import dataclass
from typing import Dict
import argparse
@@ -32,11 +32,13 @@ def options():
@dataclass
class Endpoint:
    """A git-lfs endpoint together with the HTTP headers to send to it."""

    # label taken from the `Endpoint (<name>)` key of `git lfs env`
    name: str
    # base URL of the LFS API for this endpoint
    href: str
    # `server:path` address used for `git-lfs-authenticate`, when one is advertised
    ssh: typing.Optional[str] = None
    # request headers; a fresh dict per instance
    headers: typing.Dict[str, str] = dataclasses.field(default_factory=dict)

    def update_headers(self, d: typing.Iterable[tuple[str, str]]):
        """Merge header pairs into `headers`, normalizing names with `str.capitalize`.

        `d` may be an iterable of `(name, value)` pairs or a mapping; `dict()`
        accepts both, so callers passing a plain dict keep working.
        """
        self.headers.update((k.capitalize(), v) for k, v in dict(d).items())
opts = options()
@@ -47,88 +49,107 @@ source_dir = subprocess.check_output(
).strip()
def get_env(s: str, sep: str = "=") -> typing.Iterable[tuple[str, str]]:
    """Yield `(key, value)` pairs for every line of `s` containing `sep`.

    Each line is split at the first occurrence of `sep`; lines without it are
    skipped. Keys are yielded verbatim, including any leading whitespace.
    Wrap the result in `dict(...)` when a mapping is needed.
    """
    # sep is a literal separator, so escape it before embedding it in the pattern
    for m in re.finditer(rf"(.*?){re.escape(sep)}(.*)", s, re.M):
        yield m.groups()
def git(*args, **kwargs):
    """Run `git` with `args` inside `source_dir` and return its stripped stdout.

    Returns None when git exits with a non-zero status. Extra keyword
    arguments are forwarded to `subprocess.run`; note that forwarding
    `check=True` makes a failure raise `CalledProcessError` instead of
    producing the None result.
    """
    proc = subprocess.run(
        ("git",) + args, stdout=subprocess.PIPE, text=True, cwd=source_dir, **kwargs
    )
    return proc.stdout.strip() if proc.returncode == 0 else None
# matches the `Endpoint` and `Endpoint (<name>)` keys printed by `git lfs env`
endpoint_re = re.compile(r"^Endpoint(?: \((.*)\))?$")


def get_endpoint_addresses() -> typing.Iterable[Endpoint]:
    """Get all lfs endpoints, including SSH if present

    Parses the output of `git lfs env` and yields one `Endpoint` per
    `Endpoint...` key, in output order. An `SSH` entry is attached to the
    endpoint that precedes it; an `SSH` entry seen before any endpoint is
    ignored. Unnamed endpoints get the name "default".
    """
    lfs_env_items = get_env(
        subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir)
    )
    current_endpoint = None
    for k, v in lfs_env_items:
        m = endpoint_re.match(k)
        if m:
            # a new endpoint starts: the previous one is complete
            if current_endpoint:
                yield current_endpoint
            # keep only the URL, dropping anything after the first space
            href, _, _ = v.partition(" ")
            current_endpoint = Endpoint(name=m[1] or "default", href=href)
        elif k.strip() == "SSH" and current_endpoint:
            # the SSH key is indented in the output; compare the stripped key
            # so the match does not depend on the exact indentation width
            current_endpoint.ssh = v
    if current_endpoint:
        yield current_endpoint
def _warn(message: str) -> None:
    """Print a warning on stderr."""
    print(f"WARNING: {message}", file=sys.stderr)


def get_endpoints() -> typing.Iterable[Endpoint]:
    """Yield every lfs endpoint for which request headers could be prepared.

    For each endpoint from `get_endpoint_addresses()`:
    * the git-lfs JSON content headers are set;
    * if an SSH address is present, `git-lfs-authenticate` is run over ssh and
      its `href`/`header` response is applied;
    * an `http.<url>/.extraheader` git config entry, if any, is added;
    * `GITHUB_TOKEN`, when set in the environment, overrides `Authorization`;
    * otherwise `git credential fill` is used to build a Basic `Authorization`.
    Endpoints whose ssh call or credential lookup fails are skipped with a
    warning on stderr.
    """
    for endpoint in get_endpoint_addresses():
        endpoint.headers = {
            "Content-Type": "application/vnd.git-lfs+json",
            "Accept": "application/vnd.git-lfs+json",
        }
        if endpoint.ssh:
            # see https://github.com/git-lfs/git-lfs/blob/main/docs/api/authentication.md
            server, _, path = endpoint.ssh.partition(":")
            ssh_command = shutil.which(
                os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh"))
            )
            assert ssh_command, "no ssh command found"
            cmd = [
                ssh_command,
                "-oStrictHostKeyChecking=accept-new",
                server,
                "git-lfs-authenticate",
                path,
                "download",
            ]
            try:
                res = subprocess.run(cmd, stdout=subprocess.PIPE, timeout=15)
            except subprocess.TimeoutExpired:
                _warn(
                    f"ssh timed out when connecting to {server}, ignoring {endpoint.name} endpoint"
                )
                continue
            if res.returncode != 0:
                _warn(
                    f"ssh failed when connecting to {server}, ignoring {endpoint.name} endpoint"
                )
                continue
            ssh_resp = json.loads(res.stdout)
            # keep the current href when the response does not provide one
            endpoint.href = ssh_resp.get("href", endpoint.href)
            endpoint.update_headers(ssh_resp.get("header", {}).items())
        url = urlparse(endpoint.href)
        # this is how actions/checkout persist credentials
        # see https://github.com/actions/checkout/blob/44c2b7a8a4ea60a981eaca3cf939b5f4305c123b/src/git-auth-helper.ts#L56-L63
        auth = git("config", f"http.{url.scheme}://{url.netloc}/.extraheader") or ""
        endpoint.update_headers(get_env(auth, sep=": "))
        if os.environ.get("GITHUB_TOKEN"):
            endpoint.headers["Authorization"] = f"token {os.environ['GITHUB_TOKEN']}"
        if "Authorization" not in endpoint.headers:
            # last chance: use git credentials (possibly backed by a credential helper like the one installed by gh)
            # see https://git-scm.com/docs/git-credential
            # no check=True here: a failing `git credential fill` must come back
            # as None so this endpoint is skipped instead of aborting the probe
            credentials = git(
                "credential",
                "fill",
                # drop leading / from url.path
                input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n",
            )
            fields = dict(get_env(credentials)) if credentials else {}
            if "username" not in fields or "password" not in fields:
                _warn(
                    f"no authorization method found, ignoring {endpoint.name} endpoint"
                )
                continue
            auth = base64.b64encode(
                f'{fields["username"]}:{fields["password"]}'.encode()
            ).decode("ascii")
            endpoint.headers["Authorization"] = f"Basic {auth}"
        yield endpoint
# see https://github.com/git-lfs/git-lfs/blob/310d1b4a7d01e8d9d884447df4635c7a9c7642c2/docs/api/basic-transfers.md
@@ -142,26 +163,31 @@ def get_locations(objects):
for i in indexes:
ret[i] = objects[i]["oid"]
return ret
endpoint = get_endpoint()
data = {
"operation": "download",
"transfers": ["basic"],
"objects": [objects[i] for i in indexes],
"hash_algo": "sha256",
}
req = urllib.request.Request(
f"{endpoint.href}/objects/batch",
headers=endpoint.headers,
data=json.dumps(data).encode("ascii"),
)
with urllib.request.urlopen(req) as resp:
data = json.load(resp)
assert len(data["objects"]) == len(
indexes
), f"received {len(data)} objects, expected {len(indexes)}"
for i, resp in zip(indexes, data["objects"]):
ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}'
return ret
for endpoint in get_endpoints():
req = urllib.request.Request(
f"{endpoint.href}/objects/batch",
headers=endpoint.headers,
data=json.dumps(data).encode("ascii"),
)
try:
with urllib.request.urlopen(req) as resp:
data = json.load(resp)
except urllib.request.HTTPError as e:
print(f"WARNING: encountered HTTPError {e}, ignoring endpoint {e.name}")
continue
assert len(data["objects"]) == len(
indexes
), f"received {len(data)} objects, expected {len(indexes)}"
for i, resp in zip(indexes, data["objects"]):
ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}'
return ret
raise Exception(f"no valid endpoint found")
def get_lfs_object(path):
@@ -171,7 +197,7 @@ def get_lfs_object(path):
sha256 = size = None
if lfs_header != actual_header:
return None
data = get_env(fileobj.read().decode("ascii"), sep=" ")
data = dict(get_env(fileobj.read().decode("ascii"), sep=" "))
assert data["oid"].startswith("sha256:"), f"unknown oid type: {data['oid']}"
_, _, sha256 = data["oid"].partition(":")
size = int(data["size"])