Skip to content

Commit

Permalink
Reuse requests.Session when downloading compilers (#1414)
Browse files Browse the repository at this point in the history
* Reuse requests.Session when downloading compilers

* Actually reuse the session...
  • Loading branch information
mkst authored Jan 24, 2025
1 parent c3f4c4a commit b75f772
Showing 1 changed file with 119 additions and 29 deletions.
148 changes: 119 additions & 29 deletions backend/compilers/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,28 @@

import argparse
import datetime
import functools
import logging
import os
import platform
import queue
import shutil
import sys
import tempfile
import threading

from pathlib import Path

from multiprocessing import Pool

import requests
import yaml


logger = logging.getLogger(__name__)


def get_token(docker_registry, github_repo, docker_image):
def get_token(docker_registry, github_repo, docker_image, session=None):
token_url = f"https://{docker_registry}/token?scope=repository:{github_repo}/{docker_image}:pull"
resp = requests.get(token_url, timeout=10)
getter = requests if session is None else session
resp = getter.get(token_url, timeout=10)
if resp.status_code != 200:
# hopefully the image does not exist in remote registry
return None
Expand All @@ -34,8 +35,9 @@ def get_remote_image_digest(
docker_registry="ghcr.io",
github_repo="decompme/compilers",
tag="latest",
session=None,
):
token = get_token(docker_registry, github_repo, docker_image)
token = get_token(docker_registry, github_repo, docker_image, session=session)

image_url = (
f"https://{docker_registry}/v2/{github_repo}/{docker_image}/manifests/{tag}"
Expand All @@ -44,7 +46,8 @@ def get_remote_image_digest(
"Accept": "application/vnd.oci.image.index.v1+json",
"Authorization": f"Bearer {token}",
}
resp = requests.get(image_url, headers=headers, timeout=10)
getter = requests if session is None else session
resp = getter.get(image_url, headers=headers, timeout=10)
if resp.status_code != 200:
logger.debug(
"Unable to get image manifest for %s:%s from %s: %s",
Expand All @@ -67,6 +70,7 @@ def get_compiler_raw(
docker_registry="ghcr.io",
github_repo="decompme/compilers",
tag="latest",
session=None,
):
logger.info("Processing %s (%s)", compiler_id, platform_id)

Expand All @@ -79,7 +83,11 @@ def get_compiler_raw(

logger.debug("Checking for %s in registry", docker_image)
remote_image_digest = get_remote_image_digest(
docker_image, docker_registry=docker_registry, github_repo=github_repo, tag=tag
docker_image,
docker_registry=docker_registry,
github_repo=github_repo,
tag=tag,
session=session,
)
if remote_image_digest is None:
host_arch = platform.system().lower()
Expand All @@ -94,10 +102,11 @@ def get_compiler_raw(
docker_registry=docker_registry,
github_repo=github_repo,
tag=tag,
session=session,
)
if remote_image_digest is None:
logger.error("%s not found in registry!", docker_image)
return
return None

compiler_dir = compilers_dir / platform_id / compiler_id
image_digest = compiler_dir / ".image_digest"
Expand All @@ -110,10 +119,15 @@ def get_compiler_raw(
logger.debug(
"%s image is present and at latest version, skipping!", compiler_id
)
return
return None

# First, get a token to do our operations with
token = get_token(docker_registry, github_repo, docker_image)
token = get_token(
docker_registry,
github_repo,
docker_image,
session=session,
)

# Then, get the container image index. This will give us all the
# container images associated with this tag. There may be different
Expand All @@ -126,7 +140,8 @@ def get_compiler_raw(
"Accept": "application/vnd.oci.image.index.v1+json",
"Authorization": f"Bearer {token}",
}
resp = requests.get(image_url, headers=headers, timeout=10)
getter = requests if session is None else session
resp = getter.get(image_url, headers=headers, timeout=10)
if resp.status_code != 200:
return None
data = resp.json()
Expand All @@ -145,7 +160,7 @@ def get_compiler_raw(
url = (
f"https://{docker_registry}/v2/{github_repo}/{docker_image}/manifests/{digest}"
)
resp = requests.get(url, headers=headers, timeout=10)
resp = getter.get(url, headers=headers, timeout=10)
if resp.status_code != 200:
return None
data = resp.json()
Expand All @@ -157,7 +172,7 @@ def get_compiler_raw(
mime = layer["mediaType"]
url = f"https://{docker_registry}/v2/{github_repo}/{docker_image}/blobs/{digest}"

resp = requests.get(url, headers=headers, stream=True)
resp = getter.get(url, headers=headers, stream=True)
resp.raise_for_status()

# TODO: Get extension from mime
Expand Down Expand Up @@ -205,6 +220,59 @@ def get_compiler_raw(
return True


class DownloadThread(threading.Thread):
    """Worker thread that drains a queue of (platform_id, compiler_id) pairs.

    Each thread owns its own ``requests.Session`` so HTTP connections to the
    registry are reused across downloads without sharing a session between
    threads (``requests.Session`` is not guaranteed thread-safe).
    """

    def __init__(
        self,
        download_queue: queue.Queue,
        results_queue: queue.Queue,
        *args,
        compilers_dir=Path("/tmp"),
        force=False,
        docker_registry="ghcr.io",
        github_repo="decompme/compilers",
        **kwargs,
    ):
        self.download_queue = download_queue
        self.results_queue = results_queue

        self.compilers_dir = compilers_dir
        self.force = force
        self.docker_registry = docker_registry
        self.github_repo = github_repo

        # One session per thread: reused for every registry request this
        # worker makes (token fetch, manifests, blob downloads).
        self.session = requests.Session()

        super().__init__(*args, **kwargs)

    def run(self):
        """Process items until the download queue is empty.

        ``task_done()`` is called exactly once for every item removed from
        the queue — including when ``process_item`` raises — otherwise the
        ``download_queue.join()`` in the caller would block forever. A
        failing item is logged and skipped rather than aborting the thread,
        so the remaining queue entries still get processed.
        """
        while True:
            try:
                item = self.download_queue.get_nowait()
            except queue.Empty:
                break
            try:
                self.process_item(item)
            except Exception as e:
                logger.error("Exception thrown while processing item: %s", e)
            finally:
                # Must balance the successful get_nowait() above, even on
                # failure, or download_queue.join() deadlocks.
                self.download_queue.task_done()

    def process_item(self, item):
        """Download one compiler and push get_compiler_raw's result onto results_queue."""
        platform_id, compiler_id = item

        result = get_compiler_raw(
            platform_id,
            compiler_id,
            compilers_dir=self.compilers_dir,
            force=self.force,
            docker_registry=self.docker_registry,
            github_repo=self.github_repo,
            session=self.session,
        )

        self.results_queue.put(result)


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -258,7 +326,8 @@ def main():
)
sys.exit(0)

to_download = []
download_queue = queue.Queue()
results_queue = queue.Queue()

compilers_yaml = (
Path(os.path.dirname(os.path.realpath(__file__)))
Expand All @@ -279,34 +348,55 @@ def main():
compilers = filter(lambda x: x in args.compilers, compilers)

if platform_enabled:
to_download += [(platform_id, compiler) for compiler in compilers]
for compiler in compilers:
download_queue.put(
(platform_id, compiler),
)

if len(to_download) == 0:
if download_queue.qsize() == 0:
logger.warning(
"No platforms/compilers configured or enabled for host architecture (%s)",
host_arch,
)
return

start = datetime.datetime.now()
with Pool(processes=args.threads) as pool:
results = pool.starmap(
functools.partial(
get_compiler_raw,
compilers_dir=compilers_dir,
force=args.force,
docker_registry=args.docker_registry,
github_repo=args.github_repo,
),
to_download,
threads = []
for _ in range(args.threads):
thread = DownloadThread(
download_queue,
results_queue,
compilers_dir=compilers_dir,
force=args.force,
docker_registry=args.docker_registry,
github_repo=args.github_repo,
)
threads.append(thread)

start = datetime.datetime.now()

for thread in threads:
thread.start()

download_queue.join()

for thread in threads:
thread.join(timeout=0.1)

results = []
while True:
try:
item = results_queue.get(timeout=1)
results.append(item)
except queue.Empty:
break

end = datetime.datetime.now()

compilers_downloaded = len(list(filter(lambda x: x, results)))
logger.info(
"Updated %i / %i compiler(s) in %.2f second(s)",
compilers_downloaded,
len(to_download),
len(results),
(end - start).total_seconds(),
)

Expand Down

0 comments on commit b75f772

Please sign in to comment.