Files
codeql/java/integration-tests-lib/mitm_proxy.py
Arthur Baars a35be08d18 Java: buildless proxy tests: add mitm_proxy.py
A mock implementation of an https man-in-the-middle proxy
2024-07-25 08:32:02 +02:00

174 lines
6.8 KiB
Python

import http.server
import sys
import os
import socket
import ssl
import random
from datetime import datetime, timedelta, timezone
from cryptography.hazmat.primitives import hashes, serialization
from cryptography import utils, x509
from cryptography.hazmat.primitives.asymmetric import rsa, dsa
import select
def generateCA(ca_cert_file, ca_key_file):
    """Generate a self-signed CA certificate and store it with its key.

    A new 4096-bit DSA key is created and used to self-sign a certificate
    whose subject and issuer are both "GitHub CodeQL Proxy". The certificate
    is marked as a CA, is valid from now for 3650 days, and is signed with
    SHA-256.

    Args:
        ca_cert_file: Path the PEM-encoded certificate is written to.
        ca_key_file: Path the PEM-encoded, unencrypted PKCS#8 private key
            is written to.
    """
    signing_key = dsa.generate_private_key(4096)
    public_key = signing_key.public_key()

    # Self-signed: the same name is used for subject and issuer.
    identity = x509.Name([
        x509.NameAttribute(x509.NameOID.COUNTRY_NAME, "US"),
        x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, "GitHub"),
        x509.NameAttribute(x509.NameOID.COMMON_NAME, "GitHub CodeQL Proxy"),
    ])

    builder = (
        x509.CertificateBuilder()
        .subject_name(identity)
        .issuer_name(identity)
        .public_key(public_key)
        .serial_number(random.randint(50000000, 100000000))
        .not_valid_before(datetime.now(timezone.utc))
        .not_valid_after(datetime.now(timezone.utc) + timedelta(days=3650))
        .add_extension(
            x509.BasicConstraints(ca=True, path_length=None),
            critical=True,
        )
        .add_extension(
            x509.SubjectKeyIdentifier.from_public_key(public_key),
            critical=False,
        )
    )
    certificate = builder.sign(signing_key, hashes.SHA256())

    with open(ca_cert_file, 'wb') as cert_out:
        cert_out.write(
            certificate.public_bytes(encoding=serialization.Encoding.PEM))

    with open(ca_key_file, 'wb') as key_out:
        key_out.write(signing_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        ))
def create_certificate(hostname):
    """Issue a leaf certificate for *hostname* signed by the loaded CA.

    A new 2048-bit RSA key is generated for the certificate. The issuer name
    is taken from the module-level ``ca_certificate`` and the certificate is
    signed (SHA-256) with the module-level ``ca_key``; both must have been
    loaded before this function is called. The certificate is valid from now
    for 3650 days and lists ``hostname`` and ``*.hostname`` as DNS subject
    alternative names.

    Args:
        hostname: Host name placed in the common name and the SAN entries.

    Returns:
        A ``(certificate, private_key)`` tuple.
    """
    leaf_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    leaf_name = x509.Name([
        x509.NameAttribute(x509.NameOID.COMMON_NAME, hostname),
    ])
    alt_names = x509.SubjectAlternativeName([
        x509.DNSName(hostname),
        x509.DNSName(f"*.{hostname}"),
    ])

    builder = (
        x509.CertificateBuilder()
        .subject_name(leaf_name)
        .issuer_name(ca_certificate.subject)
        .public_key(leaf_key.public_key())
        .serial_number(random.randint(50000000, 100000000))
        .not_valid_before(datetime.now(timezone.utc))
        .not_valid_after(datetime.now(timezone.utc) + timedelta(days=3650))
        .add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        )
        .add_extension(alt_names, critical=False)
    )
    signed = builder.sign(ca_key, hashes.SHA256())
    return (signed, leaf_key)
class Handler(http.server.SimpleHTTPRequestHandler):
    """Request handler implementing an intercepting HTTPS (CONNECT) proxy.

    For every ``CONNECT host:port`` request the handler optionally checks
    HTTP Basic credentials, then terminates the client's TLS session itself
    using a certificate issued for ``host`` by ``create_certificate`` and
    relays the decrypted traffic to a second TLS connection opened to the
    real ``host:port``.
    """

    # Directory holding one PEM file (private key + certificate) per host.
    CERT_DIR = "certs"
    # Seconds without traffic in either direction before a tunnel is closed.
    IDLE_TIMEOUT = 2.0
    # Number of bytes requested from a socket on the first read of a burst.
    BUFFER_SIZE = 8192

    def check_auth(self):
        """Return True when the request is allowed to use the proxy.

        Authentication is only enforced when both the ``PROXY_USER`` and
        ``PROXY_PASSWORD`` environment variables are set. In that case the
        ``Proxy-Authorization`` header (falling back to ``Authorization``)
        must carry matching HTTP Basic credentials.
        """
        username = os.getenv('PROXY_USER')
        password = os.getenv('PROXY_PASSWORD')
        if username is None or password is None:
            return True

        import base64
        import binascii

        header = self.headers.get(
            'Proxy-Authorization', self.headers.get('Authorization', ''))
        fields = header.split()
        if len(fields) != 2 or fields[0].lower() != "basic":
            return False
        try:
            credentials = base64.decodebytes(
                fields[1].encode('ascii')).decode('ascii')
        except (binascii.Error, UnicodeError):
            return False

        # Only the first colon separates user from password; the password
        # itself may contain further colons.
        user, separator, secret = credentials.partition(':')
        if not separator:
            return False
        return username == user and password == secret

    @staticmethod
    def _parse_target(target):
        """Split a CONNECT target into ``(host, port)``.

        Accepts ``host:port`` and bracketed IPv6 ``[addr]:port``. Raises
        ``ValueError`` when the host or port is missing, the port is not an
        integer in 1..65535, or the host contains a path separator (the host
        is later used as a file name).
        """
        host, separator, port_text = target.rpartition(':')
        if host.startswith('[') and host.endswith(']'):
            host = host[1:-1]
        if not separator or not host or '/' in host or '\\' in host:
            raise ValueError(f"invalid CONNECT target: {target!r}")
        port = int(port_text)
        if not 0 < port < 65536:
            raise ValueError(f"invalid port in CONNECT target: {target!r}")
        return host, port

    def do_CONNECT(self):
        """Handle a CONNECT request by intercepting the tunnelled TLS stream.

        Responds 407 when authentication fails and 400 when the target is
        not a valid ``host:port``; otherwise confirms the tunnel and hands
        the connection over to ``mitm``.
        """
        if not self.check_auth():
            self.send_response(
                http.HTTPStatus.PROXY_AUTHENTICATION_REQUIRED)
            self.send_header('Proxy-Authenticate', 'Basic realm="Proxy"')
            self.end_headers()
            return
        try:
            host, port = self._parse_target(self.path)
        except ValueError:
            self.send_error(
                http.HTTPStatus.BAD_REQUEST, 'Invalid CONNECT target')
            return
        self.send_response(http.HTTPStatus.OK, 'Connection established')
        self.send_header('Connection', 'close')
        self.end_headers()
        self.mitm(host, port)

    def mitm(self, host, port):
        """Man-in-the-middle the TLS connection between client and host.

        A certificate for ``host`` is created on first use and cached as
        ``<CERT_DIR>/<host>.pem``. The client connection is wrapped as the
        TLS server side with that certificate, a verified TLS connection is
        opened to ``host:port``, and data is relayed between the two until
        one side closes or the tunnel has been idle for ``IDLE_TIMEOUT``
        seconds. Both sockets are closed before returning.
        """
        ssl_client_context = ssl.create_default_context(
            purpose=ssl.Purpose.CLIENT_AUTH)
        cert_path = os.path.join(self.CERT_DIR, host + '.pem')
        if not os.path.exists(cert_path):
            cert, pkey = create_certificate(host)
            os.makedirs(self.CERT_DIR, exist_ok=True)
            with open(cert_path, 'wb') as f:
                # Key and certificate share one file so load_cert_chain
                # can read both from it.
                f.write(pkey.private_bytes(
                    encoding=serialization.Encoding.PEM,
                    format=serialization.PrivateFormat.TraditionalOpenSSL,
                    encryption_algorithm=serialization.NoEncryption()))
                f.write(cert.public_bytes(encoding=serialization.Encoding.PEM))
        ssl_client_context.load_cert_chain(cert_path)
        ssl_client_context.load_verify_locations(ca_certificate_path)

        # Wrap the client connection: towards the client we are the server.
        client = ssl_client_context.wrap_socket(
            self.connection, server_side=True)
        remote = None
        try:
            # Open a TLS connection to the real destination.
            remote = socket.create_connection((host, port))
            ssl_server_context = ssl.create_default_context(
                purpose=ssl.Purpose.SERVER_AUTH)
            remote = ssl_server_context.wrap_socket(
                remote, server_hostname=host)
            self._relay(client, remote)
        finally:
            # Close the client even when the upstream connection failed.
            if remote is not None:
                remote.close()
            client.close()

    def _relay(self, client, remote):
        """Copy data between the two TLS sockets until EOF or idle timeout."""
        while True:
            ready, _, _ = select.select(
                [client, remote], [], [], self.IDLE_TIMEOUT)
            if not ready:
                break
            for src in ready:
                dst = remote if src is client else client
                src.setblocking(False)
                dst.setblocking(True)
                # After the first read keep draining whatever the TLS layer
                # reports as pending, since select() only watches the
                # underlying socket.
                pending = self.BUFFER_SIZE
                while pending:
                    try:
                        data = src.recv(pending)
                    except ssl.SSLWantReadError:
                        break
                    if not data:
                        # Peer closed its side: tear the tunnel down.
                        return
                    pending = src.pending()
                    dst.sendall(data)

    def do_GET(self):
        """Plain GET requests are not supported by this proxy."""
        raise NotImplementedError()
if __name__ == '__main__':
    # Usage: <script> PORT [CA_CERT_FILE CA_KEY_FILE]
    # The CA certificate and key must be given together; reject a missing
    # port or a certificate without its key up front instead of failing
    # with an IndexError.
    if len(sys.argv) < 2 or len(sys.argv) == 3:
        sys.exit(f"usage: {sys.argv[0]} PORT [CA_CERT_FILE CA_KEY_FILE]")
    port = int(sys.argv[1])

    # Module-level state read by create_certificate() and Handler.mitm().
    ca_certificate = None
    ca_certificate_path = None
    ca_key = None
    if len(sys.argv) > 3:
        ca_certificate_path = sys.argv[2]
        with open(ca_certificate_path, 'rb') as f:
            ca_certificate = x509.load_pem_x509_certificate(f.read())
        with open(sys.argv[3], 'rb') as f:
            ca_key = serialization.load_pem_private_key(
                f.read(), password=None)

    server_address = ('localhost', port)
    # The context manager closes the listening socket when serving stops.
    with http.server.HTTPServer(server_address, Handler) as httpd:
        httpd.serve_forever()