Skip to content

Commit

Permalink
Merge pull request #991 from gchq/bugfix/python-fix-file-download
Browse files Browse the repository at this point in the history
Remove context manager and add testing for file downloads
  • Loading branch information
GB27247 authored Jan 4, 2024
2 parents 4fb6802 + d9905f6 commit a9ca366
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 11 deletions.
2 changes: 1 addition & 1 deletion lib/python-beta/src/bailo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

__version__ = "0.2.2"
__version__ = "0.2.3"

from bailo.core.agent import Agent
from bailo.core.client import Client
Expand Down
7 changes: 2 additions & 5 deletions lib/python-beta/src/bailo/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,8 @@ def get_download_file(
:param buffer: BytesIO object for bailo to write to
:return: The unique file ID
"""
with self.agent.get(
f"{self.url}/v2/model/{model_id}/file/{file_id}/download", stream=True, timeout=10_000
) as req:
with buffer as file:
shutil.copyfileobj(req.raw, file)
req = self.agent.get(f"{self.url}/v2/model/{model_id}/file/{file_id}/download", stream=True, timeout=10_000)
shutil.copyfileobj(req.raw, buffer)
return file_id

def simple_upload(self, model_id: str, name: str, buffer: BytesIO):
Expand Down
2 changes: 1 addition & 1 deletion lib/python-beta/src/bailo/helper/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def update(self) -> None:
)
self.__unpack(res["model"])

def update_model_card(self, model_card: dict[str, Any] = None) -> None:
def update_model_card(self, model_card: dict[str, Any] | None = None) -> None:
"""Uploads and retrieves any changes to the model card on Bailo
:param model_card: Model card dictionary, defaults to None
Expand Down
4 changes: 2 additions & 2 deletions lib/python-beta/src/bailo/helper/release.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,12 @@ def upload(self, name: str, file: BytesIO) -> Any:
:param name: The name of the file to upload to bailo
:param f: A BytesIO object
:return: A JSON response object
:return: The unique file ID of the file uploaded
"""
res = self.client.simple_upload(self.model_id, name, file).json()
self.files.append(res["file"]["id"])
self.update()
return res
return res["file"]["id"]

def update(self) -> Any:
"""Updates the any changes to this release on Bailo.
Expand Down
3 changes: 1 addition & 2 deletions lib/python-beta/tests/test_access_request.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations

import pytest
from bailo.core.client import Client
from bailo.helper.access_request import AccessRequest
from bailo import AccessRequest, Client


def test_access_request():
Expand Down
24 changes: 24 additions & 0 deletions lib/python-beta/tests/test_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from io import BytesIO

import pytest
from bailo import Model


@pytest.mark.integration
def test_file_upload(integration_client, example_model: Model):
byte_obj = b"Test Binary"
file = BytesIO(byte_obj)

example_release = example_model.create_release("0.1.0", "test")

with file as file:
file_id = example_release.upload("test", file)

download_file = BytesIO()
example_release.download(file_id, download_file)

# Check that file uploaded has the same contents as the one downloaded
download_file.seek(0)
assert download_file.read() == byte_obj

0 comments on commit a9ca366

Please sign in to comment.