diff --git a/backend/compilers/download.py b/backend/compilers/download.py
index 66fd11ee..643f26bf 100755
--- a/backend/compilers/download.py
+++ b/backend/compilers/download.py
@@ -2,27 +2,28 @@
 
 import argparse
 import datetime
-import functools
 import logging
 import os
 import platform
+import queue
 import shutil
 import sys
 import tempfile
+import threading
 
 from pathlib import Path
 
-from multiprocessing import Pool
-
 import requests
 import yaml
 
+
 logger = logging.getLogger(__name__)
 
 
-def get_token(docker_registry, github_repo, docker_image):
+def get_token(docker_registry, github_repo, docker_image, session=None):
     token_url = f"https://{docker_registry}/token?scope=repository:{github_repo}/{docker_image}:pull"
-    resp = requests.get(token_url, timeout=10)
+    getter = requests if session is None else session
+    resp = getter.get(token_url, timeout=10)
     if resp.status_code != 200:
         # hopefully the image does not exist in remote registry
         return None
@@ -34,8 +35,9 @@ def get_remote_image_digest(
     docker_registry="ghcr.io",
     github_repo="decompme/compilers",
     tag="latest",
+    session=None,
 ):
-    token = get_token(docker_registry, github_repo, docker_image)
+    token = get_token(docker_registry, github_repo, docker_image, session=session)
 
     image_url = (
         f"https://{docker_registry}/v2/{github_repo}/{docker_image}/manifests/{tag}"
@@ -44,7 +46,8 @@
         "Accept": "application/vnd.oci.image.index.v1+json",
         "Authorization": f"Bearer {token}",
     }
-    resp = requests.get(image_url, headers=headers, timeout=10)
+    getter = requests if session is None else session
+    resp = getter.get(image_url, headers=headers, timeout=10)
     if resp.status_code != 200:
         logger.debug(
             "Unable to get image manifest for %s:%s from %s: %s",
@@ -67,6 +70,7 @@ def get_compiler_raw(
     docker_registry="ghcr.io",
     github_repo="decompme/compilers",
     tag="latest",
+    session=None,
 ):
     logger.info("Processing %s (%s)", compiler_id, platform_id)
 
@@ -79,7 +83,11 @@ def get_compiler_raw(
     logger.debug("Checking for %s in registry", docker_image)
 
     remote_image_digest = get_remote_image_digest(
-        docker_image, docker_registry=docker_registry, github_repo=github_repo, tag=tag
+        docker_image,
+        docker_registry=docker_registry,
+        github_repo=github_repo,
+        tag=tag,
+        session=session,
     )
     if remote_image_digest is None:
         host_arch = platform.system().lower()
@@ -94,10 +102,11 @@
             docker_registry=docker_registry,
             github_repo=github_repo,
             tag=tag,
+            session=session,
         )
         if remote_image_digest is None:
             logger.error("%s not found in registry!", docker_image)
-            return
+            return None
 
     compiler_dir = compilers_dir / platform_id / compiler_id
     image_digest = compiler_dir / ".image_digest"
@@ -110,10 +119,15 @@
         logger.debug(
             "%s image is present and at latest version, skipping!", compiler_id
         )
-        return
+        return None
 
     # First, get a token to do our operations with
-    token = get_token(docker_registry, github_repo, docker_image)
+    token = get_token(
+        docker_registry,
+        github_repo,
+        docker_image,
+        session=session,
+    )
 
     # Then, get the container image index. This will give us all the
     # container images associated with this tag. There may be different
@@ -126,7 +140,8 @@ def get_compiler_raw(
         "Accept": "application/vnd.oci.image.index.v1+json",
         "Authorization": f"Bearer {token}",
     }
-    resp = requests.get(image_url, headers=headers, timeout=10)
+    getter = requests if session is None else session
+    resp = getter.get(image_url, headers=headers, timeout=10)
     if resp.status_code != 200:
         return None
     data = resp.json()
@@ -145,7 +160,7 @@
     url = (
         f"https://{docker_registry}/v2/{github_repo}/{docker_image}/manifests/{digest}"
    )
-    resp = requests.get(url, headers=headers, timeout=10)
+    resp = getter.get(url, headers=headers, timeout=10)
     if resp.status_code != 200:
         return None
     data = resp.json()
@@ -157,7 +172,7 @@
         mime = layer["mediaType"]
 
         url = f"https://{docker_registry}/v2/{github_repo}/{docker_image}/blobs/{digest}"
-        resp = requests.get(url, headers=headers, stream=True)
+        resp = getter.get(url, headers=headers, stream=True)
         resp.raise_for_status()
 
         # TODO: Get extension from mime
@@ -205,6 +220,59 @@
     return True
 
 
+class DownloadThread(threading.Thread):
+    def __init__(
+        self,
+        download_queue: queue.Queue,
+        results_queue: queue.Queue,
+        *args,
+        compilers_dir=Path("/tmp"),
+        force=False,
+        docker_registry="ghcr.io",
+        github_repo="decompme/compilers",
+        **kwargs,
+    ):
+        self.download_queue = download_queue
+        self.results_queue = results_queue
+
+        self.compilers_dir = compilers_dir
+        self.force = force
+        self.docker_registry = docker_registry
+        self.github_repo = github_repo
+
+        self.session = requests.Session()
+
+        super().__init__(*args, **kwargs)
+
+    def run(self):
+        while True:
+            try:
+                item = self.download_queue.get_nowait()
+            except queue.Empty:
+                break
+            try:
+                self.process_item(item)
+            except Exception as e:
+                logger.error("Exception thrown while processing item: %s", e)
+            finally:
+                self.download_queue.task_done()
+
+    def process_item(self, item):
+        platform_id, compiler_id = item
+
+        result = get_compiler_raw(
+            platform_id,
+            compiler_id,
+            compilers_dir=self.compilers_dir,
+            force=self.force,
+            docker_registry=self.docker_registry,
+            github_repo=self.github_repo,
+            session=self.session,
+        )
+
+        self.results_queue.put(result)
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -258,7 +326,8 @@
         )
         sys.exit(0)
 
-    to_download = []
+    download_queue = queue.Queue()
+    results_queue = queue.Queue()
 
    compilers_yaml = (
         Path(os.path.dirname(os.path.realpath(__file__)))
@@ -279,34 +348,55 @@
             compilers = filter(lambda x: x in args.compilers, compilers)
 
         if platform_enabled:
-            to_download += [(platform_id, compiler) for compiler in compilers]
+            for compiler in compilers:
+                download_queue.put(
+                    (platform_id, compiler),
+                )
 
-    if len(to_download) == 0:
+    if download_queue.qsize() == 0:
         logger.warning(
             "No platforms/compilers configured or enabled for host architecture (%s)",
             host_arch,
         )
         return
 
-    start = datetime.datetime.now()
-    with Pool(processes=args.threads) as pool:
-        results = pool.starmap(
-            functools.partial(
-                get_compiler_raw,
-                compilers_dir=compilers_dir,
-                force=args.force,
-                docker_registry=args.docker_registry,
-                github_repo=args.github_repo,
-            ),
-            to_download,
+    threads = []
+    for _ in range(args.threads):
+        thread = DownloadThread(
+            download_queue,
+            results_queue,
+            compilers_dir=compilers_dir,
+            force=args.force,
+            docker_registry=args.docker_registry,
+            github_repo=args.github_repo,
         )
+        threads.append(thread)
+
+    start = datetime.datetime.now()
+
+    for thread in threads:
+        thread.start()
+
+
+    download_queue.join()
+
+    for thread in threads:
+        thread.join(timeout=0.1)
+
+    results = []
+    while True:
+        try:
+            item = results_queue.get(timeout=1)
+            results.append(item)
+        except queue.Empty:
+            break
+
     end = datetime.datetime.now()
 
     compilers_downloaded = len(list(filter(lambda x: x, results)))
     logger.info(
         "Updated %i / %i compiler(s) in %.2f second(s)",
         compilers_downloaded,
-        len(to_download),
+        len(results),
         (end - start).total_seconds(),
     )
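A note on the `session=None` plumbing above: `requests.Session` maintains a connection pool, so a worker reuses its TCP/TLS connections to the registry across the token, manifest, and blob requests instead of re-handshaking each time. `requests` does not document `Session` as thread-safe, which is why each `DownloadThread` owns one rather than sharing a global. A minimal sketch of the fallback pattern in isolation (`fetch_manifest` is illustrative, not a function from this diff):

```python
from typing import Optional

import requests


def fetch_manifest(url: str, session: Optional[requests.Session] = None) -> dict:
    # Fall back to the module-level requests API when no Session is supplied,
    # so single-shot callers keep working without constructing one.
    getter = requests if session is None else session
    resp = getter.get(url, timeout=10)
    resp.raise_for_status()
    return resp.json()


# One Session per worker thread: pooled connections are never shared across threads.
session = requests.Session()
```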
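The worker loop leans on `queue.Queue`'s task-accounting protocol: every successful `get()` must be paired with exactly one `task_done()`, otherwise `download_queue.join()` in `main()` blocks forever. That is why `task_done()` sits in a `finally` clause in `run()` and why a failed item must not kill the loop. A standalone sketch of the pattern, with a placeholder `process` standing in for `get_compiler_raw`:

```python
import queue
import threading


def process(item):
    return item  # placeholder for the real per-item work


def worker(jobs: queue.Queue, results: queue.Queue) -> None:
    while True:
        try:
            item = jobs.get_nowait()  # the queue is fully populated before threads start
        except queue.Empty:
            break  # no work left: let the thread exit
        try:
            results.put(process(item))
        except Exception:
            results.put(None)  # record the failure so result counts still add up
        finally:
            jobs.task_done()  # pair every successful get() with one task_done()


jobs: queue.Queue = queue.Queue()
results: queue.Queue = queue.Queue()
for n in range(8):
    jobs.put(n)

threads = [threading.Thread(target=worker, args=(jobs, results)) for _ in range(4)]
for t in threads:
    t.start()
jobs.join()  # returns only once every item has been marked task_done()
for t in threads:
    t.join()
```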
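For context on the skip path: the downloader records the digest of the last image it unpacked in a `.image_digest` marker file next to the extracted compiler, and compares it against the digest reported by the registry. The comparison itself sits outside the hunks shown here; a plausible reconstruction of that check, under that assumption:

```python
from pathlib import Path


def needs_update(compiler_dir: Path, remote_digest: str, force: bool = False) -> bool:
    # Hypothetical helper mirroring the ".image_digest" cache check;
    # the real logic lives inside get_compiler_raw.
    marker = compiler_dir / ".image_digest"
    if force or not marker.exists():
        return True
    return marker.read_text().strip() != remote_digest
```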
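The same fan-out could also be written with `concurrent.futures.ThreadPoolExecutor`, which supplies the queueing, result collection, and exception propagation that the hand-rolled threads manage explicitly; the explicit version mainly keeps per-thread `Session` ownership obvious. A sketch of that alternative, assuming `get_compiler_raw` can be imported from this module and using `threading.local` for the per-thread sessions (both assumptions, not part of this PR):

```python
import threading
from concurrent.futures import ThreadPoolExecutor

import requests

from download import get_compiler_raw  # assumes download.py is importable as a module

_tls = threading.local()


def _session() -> requests.Session:
    # Lazily create one Session per executor thread.
    if not hasattr(_tls, "session"):
        _tls.session = requests.Session()
    return _tls.session


def download_one(item):
    platform_id, compiler_id = item
    return get_compiler_raw(platform_id, compiler_id, session=_session())


# Usage, with hypothetical inputs:
# with ThreadPoolExecutor(max_workers=8) as pool:
#     results = list(pool.map(download_one, [("n64", "ido5.3")]))
```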