Bazel: format git_lfs_probe.py

This commit is contained in:
Paolo Tranquilli
2024-08-07 11:57:35 +02:00
parent 42be9e98c8
commit b63bd2ad14

View File

@@ -42,24 +42,32 @@ class Endpoint:
opts = options()
sources = [p.resolve() for p in opts.sources]
source_dir = pathlib.Path(os.path.commonpath(src.parent for src in sources))
source_dir = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True).strip()
source_dir = subprocess.check_output(
["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True
).strip()
def get_env(s, sep="="):
ret = {}
for m in re.finditer(fr'(.*?){sep}(.*)', s, re.M):
for m in re.finditer(rf"(.*?){sep}(.*)", s, re.M):
ret.setdefault(*m.groups())
return ret
def git(*args, **kwargs):
return subprocess.run(("git",) + args, stdout=subprocess.PIPE, text=True, cwd=source_dir, **kwargs).stdout.strip()
return subprocess.run(
("git",) + args, stdout=subprocess.PIPE, text=True, cwd=source_dir, **kwargs
).stdout.strip()
def get_endpoint():
lfs_env_items = iter(get_env(subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir)).items())
endpoint = next(v for k, v in lfs_env_items if k.startswith('Endpoint'))
endpoint, _, _ = endpoint.partition(' ')
lfs_env_items = iter(
get_env(
subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir)
).items()
)
endpoint = next(v for k, v in lfs_env_items if k.startswith("Endpoint"))
endpoint, _, _ = endpoint.partition(" ")
# only take the ssh endpoint if it follows directly after the first endpoint we found
# in a situation like
# Endpoint (a)=...
@@ -69,21 +77,32 @@ def get_endpoint():
following_key, following_value = next(lfs_env_items, (None, None))
ssh_endpoint = following_value if following_key == " SSH" else None
endpoint = Endpoint(endpoint, {
"Content-Type": "application/vnd.git-lfs+json",
"Accept": "application/vnd.git-lfs+json",
})
endpoint = Endpoint(
endpoint,
{
"Content-Type": "application/vnd.git-lfs+json",
"Accept": "application/vnd.git-lfs+json",
},
)
if ssh_endpoint:
# see https://github.com/git-lfs/git-lfs/blob/main/docs/api/authentication.md
server, _, path = ssh_endpoint.partition(":")
ssh_command = shutil.which(os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh")))
ssh_command = shutil.which(
os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh"))
)
assert ssh_command, "no ssh command found"
resp = json.loads(subprocess.check_output([ssh_command,
"-oStrictHostKeyChecking=accept-new",
server,
"git-lfs-authenticate",
path,
"download"]))
resp = json.loads(
subprocess.check_output(
[
ssh_command,
"-oStrictHostKeyChecking=accept-new",
server,
"git-lfs-authenticate",
path,
"download",
]
)
)
endpoint.href = resp.get("href", endpoint)
endpoint.update_headers(resp.get("header", {}))
url = urlparse(endpoint.href)
@@ -96,10 +115,18 @@ def get_endpoint():
if "Authorization" not in endpoint.headers:
# last chance: use git credentials (possibly backed by a credential helper like the one installed by gh)
# see https://git-scm.com/docs/git-credential
credentials = get_env(git("credential", "fill", check=True,
# drop leading / from url.path
input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n"))
auth = base64.b64encode(f'{credentials["username"]}:{credentials["password"]}'.encode()).decode('ascii')
credentials = get_env(
git(
"credential",
"fill",
check=True,
# drop leading / from url.path
input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n",
)
)
auth = base64.b64encode(
f'{credentials["username"]}:{credentials["password"]}'.encode()
).decode("ascii")
endpoint.headers["Authorization"] = f"Basic {auth}"
return endpoint
@@ -129,23 +156,25 @@ def get_locations(objects):
)
with urllib.request.urlopen(req) as resp:
data = json.load(resp)
assert len(data["objects"]) == len(indexes), f"received {len(data)} objects, expected {len(indexes)}"
assert len(data["objects"]) == len(
indexes
), f"received {len(data)} objects, expected {len(indexes)}"
for i, resp in zip(indexes, data["objects"]):
ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}'
return ret
def get_lfs_object(path):
with open(path, 'rb') as fileobj:
with open(path, "rb") as fileobj:
lfs_header = "version https://git-lfs.github.com/spec".encode()
actual_header = fileobj.read(len(lfs_header))
sha256 = size = None
if lfs_header != actual_header:
return None
data = get_env(fileobj.read().decode('ascii'), sep=' ')
assert data['oid'].startswith('sha256:'), f"unknown oid type: {data['oid']}"
_, _, sha256 = data['oid'].partition(':')
size = int(data['size'])
data = get_env(fileobj.read().decode("ascii"), sep=" ")
assert data["oid"].startswith("sha256:"), f"unknown oid type: {data['oid']}"
_, _, sha256 = data["oid"].partition(":")
size = int(data["size"])
return {"oid": sha256, "size": size}