mirror of
https://github.com/github/codeql.git
synced 2025-12-16 16:53:25 +01:00
Bazel: make git_lfs_probe.py try all available endpoints
This commit is contained in:
@@ -7,19 +7,19 @@ For each source file provided as input, this will print:
|
||||
* the sha256 hash, a space character and a transient download link obtained via the LFS protocol otherwise
|
||||
If --hash-only is provided, the transient URL will not be fetched and printed
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import sys
|
||||
import pathlib
|
||||
import subprocess
|
||||
import os
|
||||
import shutil
|
||||
import json
|
||||
import typing
|
||||
import urllib.request
|
||||
from urllib.parse import urlparse
|
||||
import re
|
||||
import base64
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
import argparse
|
||||
|
||||
|
||||
@@ -32,11 +32,13 @@ def options():
|
||||
|
||||
@dataclass
|
||||
class Endpoint:
|
||||
name: str
|
||||
href: str
|
||||
headers: Dict[str, str]
|
||||
ssh: str | None = None
|
||||
headers: typing.Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||
|
||||
def update_headers(self, d: Dict[str, str]):
|
||||
self.headers.update((k.capitalize(), v) for k, v in d.items())
|
||||
def update_headers(self, d: typing.Iterable[tuple[str, str]]):
|
||||
self.headers.update((k.capitalize(), v) for k, v in d)
|
||||
|
||||
|
||||
opts = options()
|
||||
@@ -47,88 +49,107 @@ source_dir = subprocess.check_output(
|
||||
).strip()
|
||||
|
||||
|
||||
def get_env(s, sep="="):
|
||||
ret = {}
|
||||
def get_env(s: str, sep: str = "=") -> typing.Iterable[tuple[str, str]]:
|
||||
for m in re.finditer(rf"(.*?){sep}(.*)", s, re.M):
|
||||
ret.setdefault(*m.groups())
|
||||
return ret
|
||||
yield m.groups()
|
||||
|
||||
|
||||
def git(*args, **kwargs):
|
||||
return subprocess.run(
|
||||
proc = subprocess.run(
|
||||
("git",) + args, stdout=subprocess.PIPE, text=True, cwd=source_dir, **kwargs
|
||||
).stdout.strip()
|
||||
|
||||
|
||||
def get_endpoint():
|
||||
lfs_env_items = iter(
|
||||
get_env(
|
||||
subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir)
|
||||
).items()
|
||||
)
|
||||
endpoint = next(v for k, v in lfs_env_items if k.startswith("Endpoint"))
|
||||
endpoint, _, _ = endpoint.partition(" ")
|
||||
# only take the ssh endpoint if it follows directly after the first endpoint we found
|
||||
# in a situation like
|
||||
# Endpoint (a)=...
|
||||
# Endpoint (b)=...
|
||||
# SSH=...
|
||||
# we want to ignore the SSH endpoint, as it's not linked to the default (a) endpoint
|
||||
following_key, following_value = next(lfs_env_items, (None, None))
|
||||
ssh_endpoint = following_value if following_key == " SSH" else None
|
||||
return proc.stdout.strip() if proc.returncode == 0 else None
|
||||
|
||||
endpoint = Endpoint(
|
||||
endpoint,
|
||||
{
|
||||
|
||||
endpoint_re = re.compile(r"^Endpoint(?: \((.*)\))?$")
|
||||
|
||||
|
||||
def get_endpoint_addresses() -> typing.Iterable[Endpoint]:
|
||||
"""Get all lfs endpoints, including SSH if present"""
|
||||
lfs_env_items = get_env(
|
||||
subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir)
|
||||
)
|
||||
current_endpoint = None
|
||||
for k, v in lfs_env_items:
|
||||
m = endpoint_re.match(k)
|
||||
if m:
|
||||
if current_endpoint:
|
||||
yield current_endpoint
|
||||
href, _, _ = v.partition(" ")
|
||||
current_endpoint = Endpoint(name=m[1] or "default", href=href)
|
||||
elif k == " SSH" and current_endpoint:
|
||||
current_endpoint.ssh = v
|
||||
if current_endpoint:
|
||||
yield current_endpoint
|
||||
|
||||
|
||||
def get_endpoints() -> typing.Iterable[Endpoint]:
|
||||
for endpoint in get_endpoint_addresses():
|
||||
endpoint.headers = {
|
||||
"Content-Type": "application/vnd.git-lfs+json",
|
||||
"Accept": "application/vnd.git-lfs+json",
|
||||
},
|
||||
)
|
||||
if ssh_endpoint:
|
||||
# see https://github.com/git-lfs/git-lfs/blob/main/docs/api/authentication.md
|
||||
server, _, path = ssh_endpoint.partition(":")
|
||||
ssh_command = shutil.which(
|
||||
os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh"))
|
||||
)
|
||||
assert ssh_command, "no ssh command found"
|
||||
resp = json.loads(
|
||||
subprocess.check_output(
|
||||
[
|
||||
ssh_command,
|
||||
"-oStrictHostKeyChecking=accept-new",
|
||||
server,
|
||||
"git-lfs-authenticate",
|
||||
path,
|
||||
"download",
|
||||
]
|
||||
}
|
||||
if endpoint.ssh:
|
||||
# see https://github.com/git-lfs/git-lfs/blob/main/docs/api/authentication.md
|
||||
server, _, path = endpoint.ssh.partition(":")
|
||||
ssh_command = shutil.which(
|
||||
os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh"))
|
||||
)
|
||||
)
|
||||
endpoint.href = resp.get("href", endpoint)
|
||||
endpoint.update_headers(resp.get("header", {}))
|
||||
url = urlparse(endpoint.href)
|
||||
# this is how actions/checkout persist credentials
|
||||
# see https://github.com/actions/checkout/blob/44c2b7a8a4ea60a981eaca3cf939b5f4305c123b/src/git-auth-helper.ts#L56-L63
|
||||
auth = git("config", f"http.{url.scheme}://{url.netloc}/.extraheader")
|
||||
endpoint.update_headers(get_env(auth, sep=": "))
|
||||
if os.environ.get("GITHUB_TOKEN"):
|
||||
endpoint.headers["Authorization"] = f"token {os.environ['GITHUB_TOKEN']}"
|
||||
if "Authorization" not in endpoint.headers:
|
||||
# last chance: use git credentials (possibly backed by a credential helper like the one installed by gh)
|
||||
# see https://git-scm.com/docs/git-credential
|
||||
credentials = get_env(
|
||||
git(
|
||||
assert ssh_command, "no ssh command found"
|
||||
cmd = [
|
||||
ssh_command,
|
||||
"-oStrictHostKeyChecking=accept-new",
|
||||
server,
|
||||
"git-lfs-authenticate",
|
||||
path,
|
||||
"download",
|
||||
]
|
||||
try:
|
||||
res = subprocess.run(cmd, stdout=subprocess.PIPE, timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
print(
|
||||
f"WARNING: ssh timed out when connecting to {server}, ignoring {endpoint.name} endpoint",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
if res.returncode != 0:
|
||||
print(
|
||||
f"WARNING: ssh failed when connecting to {server}, ignoring {endpoint.name} endpoint",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
ssh_resp = json.loads(res.stdout)
|
||||
endpoint.href = ssh_resp.get("href", endpoint)
|
||||
endpoint.update_headers(ssh_resp.get("header", {}).items())
|
||||
url = urlparse(endpoint.href)
|
||||
# this is how actions/checkout persist credentials
|
||||
# see https://github.com/actions/checkout/blob/44c2b7a8a4ea60a981eaca3cf939b5f4305c123b/src/git-auth-helper.ts#L56-L63
|
||||
auth = git("config", f"http.{url.scheme}://{url.netloc}/.extraheader") or ""
|
||||
endpoint.update_headers(get_env(auth, sep=": "))
|
||||
if os.environ.get("GITHUB_TOKEN"):
|
||||
endpoint.headers["Authorization"] = f"token {os.environ['GITHUB_TOKEN']}"
|
||||
if "Authorization" not in endpoint.headers:
|
||||
# last chance: use git credentials (possibly backed by a credential helper like the one installed by gh)
|
||||
# see https://git-scm.com/docs/git-credential
|
||||
credentials = git(
|
||||
"credential",
|
||||
"fill",
|
||||
check=True,
|
||||
# drop leading / from url.path
|
||||
input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n",
|
||||
)
|
||||
)
|
||||
auth = base64.b64encode(
|
||||
f'{credentials["username"]}:{credentials["password"]}'.encode()
|
||||
).decode("ascii")
|
||||
endpoint.headers["Authorization"] = f"Basic {auth}"
|
||||
return endpoint
|
||||
if credentials is None:
|
||||
print(
|
||||
f"WARNING: no authorization method found, ignoring {data.name} endpoint",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
credentials = dict(get_env(credentials))
|
||||
auth = base64.b64encode(
|
||||
f'{credentials["username"]}:{credentials["password"]}'.encode()
|
||||
).decode("ascii")
|
||||
endpoint.headers["Authorization"] = f"Basic {auth}"
|
||||
yield endpoint
|
||||
|
||||
|
||||
# see https://github.com/git-lfs/git-lfs/blob/310d1b4a7d01e8d9d884447df4635c7a9c7642c2/docs/api/basic-transfers.md
|
||||
@@ -142,26 +163,31 @@ def get_locations(objects):
|
||||
for i in indexes:
|
||||
ret[i] = objects[i]["oid"]
|
||||
return ret
|
||||
endpoint = get_endpoint()
|
||||
data = {
|
||||
"operation": "download",
|
||||
"transfers": ["basic"],
|
||||
"objects": [objects[i] for i in indexes],
|
||||
"hash_algo": "sha256",
|
||||
}
|
||||
req = urllib.request.Request(
|
||||
f"{endpoint.href}/objects/batch",
|
||||
headers=endpoint.headers,
|
||||
data=json.dumps(data).encode("ascii"),
|
||||
)
|
||||
with urllib.request.urlopen(req) as resp:
|
||||
data = json.load(resp)
|
||||
assert len(data["objects"]) == len(
|
||||
indexes
|
||||
), f"received {len(data)} objects, expected {len(indexes)}"
|
||||
for i, resp in zip(indexes, data["objects"]):
|
||||
ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}'
|
||||
return ret
|
||||
for endpoint in get_endpoints():
|
||||
req = urllib.request.Request(
|
||||
f"{endpoint.href}/objects/batch",
|
||||
headers=endpoint.headers,
|
||||
data=json.dumps(data).encode("ascii"),
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req) as resp:
|
||||
data = json.load(resp)
|
||||
except urllib.request.HTTPError as e:
|
||||
print(f"WARNING: encountered HTTPError {e}, ignoring endpoint {e.name}")
|
||||
continue
|
||||
assert len(data["objects"]) == len(
|
||||
indexes
|
||||
), f"received {len(data)} objects, expected {len(indexes)}"
|
||||
for i, resp in zip(indexes, data["objects"]):
|
||||
ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}'
|
||||
return ret
|
||||
raise Exception(f"no valid endpoint found")
|
||||
|
||||
|
||||
def get_lfs_object(path):
|
||||
@@ -171,7 +197,7 @@ def get_lfs_object(path):
|
||||
sha256 = size = None
|
||||
if lfs_header != actual_header:
|
||||
return None
|
||||
data = get_env(fileobj.read().decode("ascii"), sep=" ")
|
||||
data = dict(get_env(fileobj.read().decode("ascii"), sep=" "))
|
||||
assert data["oid"].startswith("sha256:"), f"unknown oid type: {data['oid']}"
|
||||
_, _, sha256 = data["oid"].partition(":")
|
||||
size = int(data["size"])
|
||||
|
||||
Reference in New Issue
Block a user