Skip to content

Commit

Permalink
Merge pull request #762 from NatLibFi/issue760-hugging-face-hub-integ…
Browse files Browse the repository at this point in the history
…ration

Implement `annif upload` and `annif download` commands for Hugging Face Hub integration
  • Loading branch information
juhoinkinen authored Apr 23, 2024
2 parents c3a86a6 + 6f35fff commit d9f3793
Show file tree
Hide file tree
Showing 13 changed files with 756 additions and 6 deletions.
3 changes: 2 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ RUN annif completion --bash >> /etc/bash.bashrc # Enable tab completion
RUN groupadd -g 998 annif_user && \
useradd -r -u 998 -g annif_user annif_user && \
chmod -R a+rX /Annif && \
mkdir -p /Annif/tests/data && \
mkdir -p /Annif/tests/data /Annif/projects.d && \
chown -R annif_user:annif_user /annif-projects /Annif/tests/data
USER annif_user
ENV HF_HOME="/tmp"

ENV GUNICORN_CMD_ARGS="--worker-class uvicorn.workers.UvicornWorker"

Expand Down
126 changes: 124 additions & 2 deletions annif/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@
import annif.parallel
import annif.project
import annif.registry
from annif import cli_util
from annif.exception import NotInitializedException, NotSupportedException
from annif import cli_util, hfh_util
from annif.exception import (
NotInitializedException,
NotSupportedException,
OperationFailedException,
)
from annif.project import Access
from annif.util import metric_code

Expand Down Expand Up @@ -594,6 +598,124 @@ def run_hyperopt(project_id, paths, docs_limit, trials, jobs, metric, results_fi
click.echo("---")


@cli.command("upload")
@click.argument("project_ids_pattern", shell_complete=cli_util.complete_param)
@click.argument("repo_id")
@click.option(
"--token",
help="""Authentication token, obtained from the Hugging Face Hub.
Will default to the stored token.""",
)
@click.option(
"--revision",
help="""An optional git revision to commit from. Defaults to the head of the "main"
branch.""",
)
@click.option(
"--commit-message",
help="""The summary / title / first line of the generated commit.""",
)
@cli_util.common_options
def run_upload(project_ids_pattern, repo_id, token, revision, commit_message):
"""
Upload selected projects and their vocabularies to a Hugging Face Hub repository.
\f
This command zips the project directories and vocabularies of the projects
that match the given `project_ids_pattern` to archive files, and uploads the
archives along with the project configurations to the specified Hugging Face
Hub repository. An authentication token and commit message can be given with
options.
"""
from huggingface_hub import HfApi
from huggingface_hub.utils import HfHubHTTPError, HFValidationError

projects = hfh_util.get_matching_projects(project_ids_pattern)
click.echo(f"Uploading project(s): {', '.join([p.project_id for p in projects])}")

commit_message = (
commit_message
if commit_message is not None
else f"Upload project(s) {project_ids_pattern} with Annif"
)

fobjs, operations = [], []
try:
fobjs, operations = hfh_util.prepare_commits(projects, repo_id)
api = HfApi()
api.create_commit(
repo_id=repo_id,
operations=operations,
commit_message=commit_message,
revision=revision,
token=token,
)
except (HfHubHTTPError, HFValidationError) as err:
raise OperationFailedException(str(err))
finally:
for fobj in fobjs:
fobj.close()


@cli.command("download")
@click.argument("project_ids_pattern")
@click.argument("repo_id")
@click.option(
"--token",
help="""Authentication token, obtained from the Hugging Face Hub.
Will default to the stored token.""",
)
@click.option(
"--revision",
help="""
An optional Git revision id which can be a branch name, a tag, or a commit
hash.
""",
)
@click.option(
"--force",
"-f",
default=False,
is_flag=True,
help="Replace an existing project/vocabulary/config with the downloaded one",
)
@cli_util.common_options
def run_download(project_ids_pattern, repo_id, token, revision, force):
"""
Download selected projects and their vocabularies from a Hugging Face Hub
repository.
\f
This command downloads the project and vocabulary archives and the
configuration files of the projects that match the given
`project_ids_pattern` from the specified Hugging Face Hub repository and
unzips the archives to `data/` directory and places the configuration files
to `projects.d/` directory. An authentication token and revision can
be given with options.
"""

project_ids = hfh_util.get_matching_project_ids_from_hf_hub(
project_ids_pattern, repo_id, token, revision
)
click.echo(f"Downloading project(s): {', '.join(project_ids)}")

vocab_ids = set()
for project_id in project_ids:
project_zip_cache_path = hfh_util.download_from_hf_hub(
f"projects/{project_id}.zip", repo_id, token, revision
)
hfh_util.unzip_archive(project_zip_cache_path, force)
config_file_cache_path = hfh_util.download_from_hf_hub(
f"{project_id}.cfg", repo_id, token, revision
)
vocab_ids.add(hfh_util.get_vocab_id_from_config(config_file_cache_path))
hfh_util.copy_project_config(config_file_cache_path, force)

for vocab_id in vocab_ids:
vocab_zip_cache_path = hfh_util.download_from_hf_hub(
f"vocabs/{vocab_id}.zip", repo_id, token, revision
)
hfh_util.unzip_archive(vocab_zip_cache_path, force)


@cli.command("completion")
@click.option("--bash", "shell", flag_value="bash")
@click.option("--zsh", "shell", flag_value="zsh")
Expand Down
6 changes: 3 additions & 3 deletions annif/cli_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from annif.project import Access

if TYPE_CHECKING:
import io
from datetime import datetime
from io import TextIOWrapper

from click.core import Argument, Context, Option

Expand Down Expand Up @@ -185,7 +185,7 @@ def show_hits(
hits: SuggestionResult,
project: AnnifProject,
lang: str,
file: TextIOWrapper | None = None,
file: io.TextIOWrapper | None = None,
) -> None:
"""
Print subject suggestions to the console or a file. The suggestions are displayed as
Expand Down Expand Up @@ -234,7 +234,7 @@ def generate_filter_params(filter_batch_max_limit: int) -> list[tuple[int, float
def _get_completion_choices(
param: Argument,
) -> dict[str, AnnifVocabulary] | dict[str, AnnifProject] | list:
if param.name == "project_id":
if param.name in ("project_id", "project_ids_pattern"):
return annif.registry.get_projects()
elif param.name == "vocab_id":
return annif.registry.get_vocabs()
Expand Down
Loading

0 comments on commit d9f3793

Please sign in to comment.