From a0a3850e1ac2722314f10e2658ee35e875692918 Mon Sep 17 00:00:00 2001 From: Juho Inkinen <34240031+juhoinkinen@users.noreply.github.com> Date: Tue, 23 Apr 2024 09:09:26 +0300 Subject: [PATCH] Move functions for HuggingFaceHub interactions to own file --- annif/cli.py | 22 ++--- annif/cli_util.py | 229 +------------------------------------------ annif/hfh_util.py | 240 ++++++++++++++++++++++++++++++++++++++++++++++ tests/test_cli.py | 21 ++-- 4 files changed, 263 insertions(+), 249 deletions(-) create mode 100644 annif/hfh_util.py diff --git a/annif/cli.py b/annif/cli.py index 3577db215..cc62a0f96 100644 --- a/annif/cli.py +++ b/annif/cli.py @@ -17,7 +17,7 @@ import annif.parallel import annif.project import annif.registry -from annif import cli_util +from annif import cli_util, hfh_util from annif.exception import ( NotInitializedException, NotSupportedException, @@ -617,7 +617,7 @@ def run_upload(project_ids_pattern, repo_id, token, revision, commit_message): from huggingface_hub import HfApi from huggingface_hub.utils import HfHubHTTPError, HFValidationError - projects = cli_util.get_matching_projects(project_ids_pattern) + projects = hfh_util.get_matching_projects(project_ids_pattern) click.echo(f"Uploading project(s): {', '.join([p.project_id for p in projects])}") commit_message = ( @@ -628,7 +628,7 @@ def run_upload(project_ids_pattern, repo_id, token, revision, commit_message): fobjs, operations = [], [] try: - fobjs, operations = cli_util.prepare_commits(projects, repo_id) + fobjs, operations = hfh_util.prepare_commits(projects, repo_id) api = HfApi() api.create_commit( repo_id=repo_id, @@ -680,28 +680,28 @@ def run_download(project_ids_pattern, repo_id, token, revision, force): be given with options. """ - project_ids = cli_util.get_matching_project_ids_from_hf_hub( + project_ids = hfh_util.get_matching_project_ids_from_hf_hub( project_ids_pattern, repo_id, token, revision ) click.echo(f"Downloading project(s): {', '.join(project_ids)}") vocab_ids = set() for project_id in project_ids: - project_zip_cache_path = cli_util.download_from_hf_hub( + project_zip_cache_path = hfh_util.download_from_hf_hub( f"projects/{project_id}.zip", repo_id, token, revision ) - cli_util.unzip_archive(project_zip_cache_path, force) - config_file_cache_path = cli_util.download_from_hf_hub( + hfh_util.unzip_archive(project_zip_cache_path, force) + config_file_cache_path = hfh_util.download_from_hf_hub( f"{project_id}.cfg", repo_id, token, revision ) - vocab_ids.add(cli_util.get_vocab_id_from_config(config_file_cache_path)) - cli_util.copy_project_config(config_file_cache_path, force) + vocab_ids.add(hfh_util.get_vocab_id_from_config(config_file_cache_path)) + hfh_util.copy_project_config(config_file_cache_path, force) for vocab_id in vocab_ids: - vocab_zip_cache_path = cli_util.download_from_hf_hub( + vocab_zip_cache_path = hfh_util.download_from_hf_hub( f"vocabs/{vocab_id}.zip", repo_id, token, revision ) - cli_util.unzip_archive(vocab_zip_cache_path, force) + hfh_util.unzip_archive(vocab_zip_cache_path, force) @cli.command("completion") diff --git a/annif/cli_util.py b/annif/cli_util.py index c0812bd6a..c47196169 100644 --- a/annif/cli_util.py +++ b/annif/cli_util.py @@ -2,20 +2,11 @@ from __future__ import annotations -import binascii import collections -import configparser -import importlib import io import itertools import os -import pathlib -import shutil import sys -import tempfile -import time -import zipfile -from fnmatch import fnmatch from typing import TYPE_CHECKING import click @@ -23,12 +14,11 @@ from flask import current_app import annif -from annif.exception import ConfigurationException, OperationFailedException +from annif.exception import ConfigurationException from annif.project import Access if TYPE_CHECKING: from datetime import datetime - from typing import Any from click.core import Argument, Context, Option @@ -241,223 +231,6 @@ def generate_filter_params(filter_batch_max_limit: int) -> list[tuple[int, float return list(itertools.product(limits, thresholds)) -def get_matching_projects(pattern: str) -> list[AnnifProject]: - """ - Get projects that match the given pattern. - """ - return [ - proj - for proj in annif.registry.get_projects(min_access=Access.private).values() - if fnmatch(proj.project_id, pattern) - ] - - -def prepare_commits(projects: list[AnnifProject], repo_id: str) -> tuple[list, list]: - """Prepare and pre-upload data and config commit operations for projects to a - Hugging Face Hub repository.""" - from huggingface_hub import preupload_lfs_files - - fobjs, operations = [], [] - data_dirs = {p.datadir for p in projects} - vocab_dirs = {p.vocab.datadir for p in projects} - all_dirs = data_dirs.union(vocab_dirs) - - for data_dir in all_dirs: - fobj, operation = _prepare_datadir_commit(data_dir) - preupload_lfs_files(repo_id, additions=[operation]) - fobjs.append(fobj) - operations.append(operation) - - for project in projects: - fobj, operation = _prepare_config_commit(project) - fobjs.append(fobj) - operations.append(operation) - - return fobjs, operations - - -def _prepare_datadir_commit(data_dir: str) -> tuple[io.BufferedRandom, Any]: - from huggingface_hub import CommitOperationAdd - - zip_repo_path = data_dir.split(os.path.sep, 1)[1] + ".zip" - fobj = _archive_dir(data_dir) - operation = CommitOperationAdd(path_in_repo=zip_repo_path, path_or_fileobj=fobj) - return fobj, operation - - -def _prepare_config_commit(project: AnnifProject) -> tuple[io.BytesIO, Any]: - from huggingface_hub import CommitOperationAdd - - config_repo_path = project.project_id + ".cfg" - fobj = _get_project_config(project) - operation = CommitOperationAdd(path_in_repo=config_repo_path, path_or_fileobj=fobj) - return fobj, operation - - -def _is_train_file(fname: str) -> bool: - train_file_patterns = ("-train", "tmp-") - for pat in train_file_patterns: - if pat in fname: - return True - return False - - -def _archive_dir(data_dir: str) -> io.BufferedRandom: - fp = tempfile.TemporaryFile() - path = pathlib.Path(data_dir) - fpaths = [fpath for fpath in path.glob("**/*") if not _is_train_file(fpath.name)] - with zipfile.ZipFile(fp, mode="w") as zfile: - zfile.comment = bytes( - f"Archived by Annif {importlib.metadata.version('annif')}", - encoding="utf-8", - ) - for fpath in fpaths: - logger.debug(f"Adding {fpath}") - arcname = os.path.join(*fpath.parts[1:]) - zfile.write(fpath, arcname=arcname) - fp.seek(0) - return fp - - -def _get_project_config(project: AnnifProject) -> io.BytesIO: - fp = tempfile.TemporaryFile(mode="w+t") - config = configparser.ConfigParser() - config[project.project_id] = project.config - config.write(fp) # This needs tempfile in text mode - fp.seek(0) - # But for upload fobj needs to be in binary mode - return io.BytesIO(fp.read().encode("utf8")) - - -def get_matching_project_ids_from_hf_hub( - project_ids_pattern: str, repo_id: str, token, revision: str -) -> list[str]: - """Get project IDs of the projects in a Hugging Face Model Hub repository that match - the given pattern.""" - all_repo_file_paths = _list_files_in_hf_hub(repo_id, token, revision) - return [ - path.rsplit(".cfg")[0] - for path in all_repo_file_paths - if fnmatch(path, f"{project_ids_pattern}.cfg") - ] - - -def _list_files_in_hf_hub(repo_id: str, token: str, revision: str) -> list[str]: - from huggingface_hub import list_repo_files - from huggingface_hub.utils import HfHubHTTPError, HFValidationError - - try: - return [ - repofile - for repofile in list_repo_files( - repo_id=repo_id, token=token, revision=revision - ) - ] - except (HfHubHTTPError, HFValidationError) as err: - raise OperationFailedException(str(err)) - - -def download_from_hf_hub( - filename: str, repo_id: str, token: str, revision: str -) -> list[str]: - from huggingface_hub import hf_hub_download - from huggingface_hub.utils import HfHubHTTPError, HFValidationError - - try: - return hf_hub_download( - repo_id=repo_id, - filename=filename, - token=token, - revision=revision, - ) - except (HfHubHTTPError, HFValidationError) as err: - raise OperationFailedException(str(err)) - - -def unzip_archive(src_path: str, force: bool) -> None: - """Unzip a zip archive of projects and vocabularies to a directory, by - default data/ under current directory.""" - datadir = current_app.config["DATADIR"] - with zipfile.ZipFile(src_path, "r") as zfile: - archive_comment = str(zfile.comment, encoding="utf-8") - logger.debug( - f'Extracting archive {src_path}; archive comment: "{archive_comment}"' - ) - for member in zfile.infolist(): - _unzip_member(zfile, member, datadir, force) - - -def _unzip_member( - zfile: zipfile.ZipFile, member: zipfile.ZipInfo, datadir: str, force: bool -) -> None: - dest_path = os.path.join(datadir, member.filename) - if os.path.exists(dest_path) and not force: - _handle_existing_file(member, dest_path) - return - logger.debug(f"Unzipping to {dest_path}") - zfile.extract(member, path=datadir) - _restore_timestamps(member, dest_path) - - -def _handle_existing_file(member: zipfile.ZipInfo, dest_path: str) -> None: - if _are_identical_member_and_file(member, dest_path): - logger.debug(f"Skipping unzip to {dest_path}; already in place") - else: - click.echo(f"Not overwriting {dest_path} (use --force to override)") - - -def _are_identical_member_and_file(member: zipfile.ZipInfo, dest_path: str) -> bool: - path_crc = _compute_crc32(dest_path) - return path_crc == member.CRC - - -def _restore_timestamps(member: zipfile.ZipInfo, dest_path: str) -> None: - date_time = time.mktime(member.date_time + (0, 0, -1)) - os.utime(dest_path, (date_time, date_time)) - - -def copy_project_config(src_path: str, force: bool) -> None: - """Copy a given project configuration file to projects.d/ directory.""" - project_configs_dest_dir = "projects.d" - os.makedirs(project_configs_dest_dir, exist_ok=True) - - dest_path = os.path.join(project_configs_dest_dir, os.path.basename(src_path)) - if os.path.exists(dest_path) and not force: - if _are_identical_files(src_path, dest_path): - logger.debug(f"Skipping copy to {dest_path}; already in place") - else: - click.echo(f"Not overwriting {dest_path} (use --force to override)") - else: - logger.debug(f"Copying to {dest_path}") - shutil.copy(src_path, dest_path) - - -def _are_identical_files(src_path: str, dest_path: str) -> bool: - src_crc32 = _compute_crc32(src_path) - dest_crc32 = _compute_crc32(dest_path) - return src_crc32 == dest_crc32 - - -def _compute_crc32(path: str) -> int: - if os.path.isdir(path): - return 0 - - size = 1024 * 1024 * 10 # 10 MiB chunks - with open(path, "rb") as fp: - crcval = 0 - while chunk := fp.read(size): - crcval = binascii.crc32(chunk, crcval) - return crcval - - -def get_vocab_id_from_config(config_path: str) -> str: - """Get the vocabulary ID from a configuration file.""" - config = configparser.ConfigParser() - config.read(config_path) - section = config.sections()[0] - return config[section]["vocab"] - - def _get_completion_choices( param: Argument, ) -> dict[str, AnnifVocabulary] | dict[str, AnnifProject] | list: diff --git a/annif/hfh_util.py b/annif/hfh_util.py new file mode 100644 index 000000000..045e4710f --- /dev/null +++ b/annif/hfh_util.py @@ -0,0 +1,240 @@ +"""Utility functions for interactions with Hugging Face Hub.""" + +import binascii +import configparser +import importlib +import io +import os +import pathlib +import shutil +import tempfile +import time +import zipfile +from fnmatch import fnmatch +from typing import Any + +import click +from flask import current_app + +import annif +from annif.exception import OperationFailedException +from annif.project import Access, AnnifProject + +logger = annif.logger + + +def get_matching_projects(pattern: str) -> list[AnnifProject]: + """ + Get projects that match the given pattern. + """ + return [ + proj + for proj in annif.registry.get_projects(min_access=Access.private).values() + if fnmatch(proj.project_id, pattern) + ] + + +def prepare_commits(projects: list[AnnifProject], repo_id: str) -> tuple[list, list]: + """Prepare and pre-upload data and config commit operations for projects to a + Hugging Face Hub repository.""" + from huggingface_hub import preupload_lfs_files + + fobjs, operations = [], [] + data_dirs = {p.datadir for p in projects} + vocab_dirs = {p.vocab.datadir for p in projects} + all_dirs = data_dirs.union(vocab_dirs) + + for data_dir in all_dirs: + fobj, operation = _prepare_datadir_commit(data_dir) + preupload_lfs_files(repo_id, additions=[operation]) + fobjs.append(fobj) + operations.append(operation) + + for project in projects: + fobj, operation = _prepare_config_commit(project) + fobjs.append(fobj) + operations.append(operation) + + return fobjs, operations + + +def _prepare_datadir_commit(data_dir: str) -> tuple[io.BufferedRandom, Any]: + from huggingface_hub import CommitOperationAdd + + zip_repo_path = data_dir.split(os.path.sep, 1)[1] + ".zip" + fobj = _archive_dir(data_dir) + operation = CommitOperationAdd(path_in_repo=zip_repo_path, path_or_fileobj=fobj) + return fobj, operation + + +def _prepare_config_commit(project: AnnifProject) -> tuple[io.BytesIO, Any]: + from huggingface_hub import CommitOperationAdd + + config_repo_path = project.project_id + ".cfg" + fobj = _get_project_config(project) + operation = CommitOperationAdd(path_in_repo=config_repo_path, path_or_fileobj=fobj) + return fobj, operation + + +def _is_train_file(fname: str) -> bool: + train_file_patterns = ("-train", "tmp-") + for pat in train_file_patterns: + if pat in fname: + return True + return False + + +def _archive_dir(data_dir: str) -> io.BufferedRandom: + fp = tempfile.TemporaryFile() + path = pathlib.Path(data_dir) + fpaths = [fpath for fpath in path.glob("**/*") if not _is_train_file(fpath.name)] + with zipfile.ZipFile(fp, mode="w") as zfile: + zfile.comment = bytes( + f"Archived by Annif {importlib.metadata.version('annif')}", + encoding="utf-8", + ) + for fpath in fpaths: + logger.debug(f"Adding {fpath}") + arcname = os.path.join(*fpath.parts[1:]) + zfile.write(fpath, arcname=arcname) + fp.seek(0) + return fp + + +def _get_project_config(project: AnnifProject) -> io.BytesIO: + fp = tempfile.TemporaryFile(mode="w+t") + config = configparser.ConfigParser() + config[project.project_id] = project.config + config.write(fp) # This needs tempfile in text mode + fp.seek(0) + # But for upload fobj needs to be in binary mode + return io.BytesIO(fp.read().encode("utf8")) + + +def get_matching_project_ids_from_hf_hub( + project_ids_pattern: str, repo_id: str, token, revision: str +) -> list[str]: + """Get project IDs of the projects in a Hugging Face Model Hub repository that match + the given pattern.""" + all_repo_file_paths = _list_files_in_hf_hub(repo_id, token, revision) + return [ + path.rsplit(".cfg")[0] + for path in all_repo_file_paths + if fnmatch(path, f"{project_ids_pattern}.cfg") + ] + + +def _list_files_in_hf_hub(repo_id: str, token: str, revision: str) -> list[str]: + from huggingface_hub import list_repo_files + from huggingface_hub.utils import HfHubHTTPError, HFValidationError + + try: + return [ + repofile + for repofile in list_repo_files( + repo_id=repo_id, token=token, revision=revision + ) + ] + except (HfHubHTTPError, HFValidationError) as err: + raise OperationFailedException(str(err)) + + +def download_from_hf_hub( + filename: str, repo_id: str, token: str, revision: str +) -> list[str]: + from huggingface_hub import hf_hub_download + from huggingface_hub.utils import HfHubHTTPError, HFValidationError + + try: + return hf_hub_download( + repo_id=repo_id, + filename=filename, + token=token, + revision=revision, + ) + except (HfHubHTTPError, HFValidationError) as err: + raise OperationFailedException(str(err)) + + +def unzip_archive(src_path: str, force: bool) -> None: + """Unzip a zip archive of projects and vocabularies to a directory, by + default data/ under current directory.""" + datadir = current_app.config["DATADIR"] + with zipfile.ZipFile(src_path, "r") as zfile: + archive_comment = str(zfile.comment, encoding="utf-8") + logger.debug( + f'Extracting archive {src_path}; archive comment: "{archive_comment}"' + ) + for member in zfile.infolist(): + _unzip_member(zfile, member, datadir, force) + + +def _unzip_member( + zfile: zipfile.ZipFile, member: zipfile.ZipInfo, datadir: str, force: bool +) -> None: + dest_path = os.path.join(datadir, member.filename) + if os.path.exists(dest_path) and not force: + _handle_existing_file(member, dest_path) + return + logger.debug(f"Unzipping to {dest_path}") + zfile.extract(member, path=datadir) + _restore_timestamps(member, dest_path) + + +def _handle_existing_file(member: zipfile.ZipInfo, dest_path: str) -> None: + if _are_identical_member_and_file(member, dest_path): + logger.debug(f"Skipping unzip to {dest_path}; already in place") + else: + click.echo(f"Not overwriting {dest_path} (use --force to override)") + + +def _are_identical_member_and_file(member: zipfile.ZipInfo, dest_path: str) -> bool: + path_crc = _compute_crc32(dest_path) + return path_crc == member.CRC + + +def _restore_timestamps(member: zipfile.ZipInfo, dest_path: str) -> None: + date_time = time.mktime(member.date_time + (0, 0, -1)) + os.utime(dest_path, (date_time, date_time)) + + +def copy_project_config(src_path: str, force: bool) -> None: + """Copy a given project configuration file to projects.d/ directory.""" + project_configs_dest_dir = "projects.d" + os.makedirs(project_configs_dest_dir, exist_ok=True) + + dest_path = os.path.join(project_configs_dest_dir, os.path.basename(src_path)) + if os.path.exists(dest_path) and not force: + if _are_identical_files(src_path, dest_path): + logger.debug(f"Skipping copy to {dest_path}; already in place") + else: + click.echo(f"Not overwriting {dest_path} (use --force to override)") + else: + logger.debug(f"Copying to {dest_path}") + shutil.copy(src_path, dest_path) + + +def _are_identical_files(src_path: str, dest_path: str) -> bool: + src_crc32 = _compute_crc32(src_path) + dest_crc32 = _compute_crc32(dest_path) + return src_crc32 == dest_crc32 + + +def _compute_crc32(path: str) -> int: + if os.path.isdir(path): + return 0 + + size = 1024 * 1024 * 10 # 10 MiB chunks + with open(path, "rb") as fp: + crcval = 0 + while chunk := fp.read(size): + crcval = binascii.crc32(chunk, crcval) + return crcval + + +def get_vocab_id_from_config(config_path: str) -> str: + """Get the vocabulary ID from a configuration file.""" + config = configparser.ConfigParser() + config.read(config_path) + section = config.sections()[0] + return config[section]["vocab"] diff --git a/tests/test_cli.py b/tests/test_cli.py index bb2a8b53d..c9d904ff2 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -18,6 +18,7 @@ import annif.cli import annif.cli_util +import annif.hfh_util import annif.parallel runner = CliRunner(env={"ANNIF_CONFIG": "annif.default_config.TestingConfig"}) @@ -1140,7 +1141,7 @@ def test_archive_dir(testdatadir): open(os.path.join(str(dirpath), "foo.txt"), "a").close() open(os.path.join(str(dirpath), "-train.txt"), "a").close() - fobj = annif.cli_util._archive_dir(dirpath) + fobj = annif.hfh_util._archive_dir(dirpath) assert isinstance(fobj, io.BufferedRandom) with zipfile.ZipFile(fobj, mode="r") as zfile: @@ -1150,7 +1151,7 @@ def test_archive_dir(testdatadir): def test_get_project_config(app_project): - result = annif.cli_util._get_project_config(app_project) + result = annif.hfh_util._get_project_config(app_project) assert isinstance(result, io.BytesIO) string_result = result.read().decode("UTF-8") assert "[dummy-en]" in string_result @@ -1175,7 +1176,7 @@ def hf_hub_download_mock_side_effect(filename, repo_id, token, revision): "huggingface_hub.hf_hub_download", side_effect=hf_hub_download_mock_side_effect, ) -@mock.patch("annif.cli_util.copy_project_config") +@mock.patch("annif.hfh_util.copy_project_config") def test_download_dummy_fi( copy_project_config, hf_hub_download, list_repo_files, testdatadir ): @@ -1233,7 +1234,7 @@ def test_download_dummy_fi( "huggingface_hub.hf_hub_download", side_effect=hf_hub_download_mock_side_effect, ) -@mock.patch("annif.cli_util.copy_project_config") +@mock.patch("annif.hfh_util.copy_project_config") def test_download_dummy_fi_and_en( copy_project_config, hf_hub_download, list_repo_files, testdatadir ): @@ -1352,7 +1353,7 @@ def test_download_hf_hub_download_failed( def test_unzip_archive_initial(testdatadir): dirpath = os.path.join(str(testdatadir), "projects", "dummy-fi") fpath = os.path.join(str(dirpath), "file.txt") - annif.cli_util.unzip_archive( + annif.hfh_util.unzip_archive( os.path.join("tests", "huggingface-cache", "projects", "dummy-fi.zip"), force=False, ) @@ -1371,7 +1372,7 @@ def test_unzip_archive_no_overwrite(testdatadir): with open(fpath, "wt") as pf: print("Existing content", file=pf) - annif.cli_util.unzip_archive( + annif.hfh_util.unzip_archive( os.path.join("tests", "huggingface-cache", "projects", "dummy-fi.zip"), force=False, ) @@ -1387,7 +1388,7 @@ def test_unzip_archive_overwrite(testdatadir): with open(fpath, "wt") as pf: print("Existing content", file=pf) - annif.cli_util.unzip_archive( + annif.hfh_util.unzip_archive( os.path.join("tests", "huggingface-cache", "projects", "dummy-fi.zip"), force=True, ) @@ -1400,10 +1401,10 @@ def test_unzip_archive_overwrite(testdatadir): @mock.patch("os.path.exists", return_value=True) -@mock.patch("annif.cli_util._compute_crc32", return_value=0) +@mock.patch("annif.hfh_util._compute_crc32", return_value=0) @mock.patch("shutil.copy") def test_copy_project_config_no_overwrite(copy, _compute_crc32, exists): - annif.cli_util.copy_project_config( + annif.hfh_util.copy_project_config( os.path.join("tests", "huggingface-cache", "dummy-fi.cfg"), force=False ) assert not copy.called @@ -1412,7 +1413,7 @@ def test_copy_project_config_no_overwrite(copy, _compute_crc32, exists): @mock.patch("os.path.exists", return_value=True) @mock.patch("shutil.copy") def test_copy_project_config_overwrite(copy, exists): - annif.cli_util.copy_project_config( + annif.hfh_util.copy_project_config( os.path.join("tests", "huggingface-cache", "dummy-fi.cfg"), force=True ) assert copy.called