Skip to content

Commit

Permalink
update and run /lib/modelscan_api pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
PE39806 committed Jan 28, 2025
1 parent 4c13420 commit 7974ad4
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 52 deletions.
7 changes: 4 additions & 3 deletions lib/modelscan_api/.pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,20 @@ repos:
rev: 6.0.0
hooks:
- id: isort
args: ['-a', 'from __future__ import annotations']
args: ['-a', 'from __future__ import annotations', '-l', '120']

- repo: https://github.com/asottile/pyupgrade
rev: v3.19.1
hooks:
- id: pyupgrade
args: [--py37-plus]
args: [--py39-plus]

- repo: https://github.com/hadialqattan/pycln
rev: v2.5.0
hooks:
- id: pycln
args: [--config=pyproject.toml]
# use the same config as for /lib/python
args: [--config=../python/pyproject.toml]
stages: [manual]

- repo: https://github.com/codespell-project/codespell
Expand Down
2 changes: 1 addition & 1 deletion lib/modelscan_api/bailo_modelscan_api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Settings(BaseSettings):
# Update frontend/pages/docs/administration/helm/configuration.mdx if bumping this.
app_version: str = "1.0.0"
# download_dir is used if it evaluates, otherwise a temporary directory is used.
download_dir: Optional[str] = None
download_dir: str | None = None
modelscan_settings: dict[str, Any] = DEFAULT_SETTINGS
block_size: int = 1024

Expand Down
16 changes: 5 additions & 11 deletions lib/modelscan_api/bailo_modelscan_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from modelscan.modelscan import ModelScan
from pydantic import BaseModel

# isort: split

from bailo_modelscan_api.config import Settings
from bailo_modelscan_api.dependencies import safe_join

Expand Down Expand Up @@ -225,9 +227,7 @@ async def info(settings: Annotated[Settings, Depends(get_settings)]) -> ApiInfor
"description": "The server could not complete the request",
"content": {
"application/json": {
"example": {
"detail": "An error occurred while processing the uploaded file's name."
}
"example": {"detail": "An error occurred while processing the uploaded file's name."}
}
},
},
Expand All @@ -251,11 +251,7 @@ async def scan_file(
modelscan_model = ModelScan(settings=settings.modelscan_settings)

# Use Setting's download_dir if defined else use a temporary directory.
with (
TemporaryDirectory()
if not settings.download_dir
else nullcontext(settings.download_dir)
) as download_dir:
with TemporaryDirectory() if not settings.download_dir else nullcontext(settings.download_dir) as download_dir:
if in_file.filename and str(in_file.filename).strip():
# Prevent escaping to a parent dir
try:
Expand Down Expand Up @@ -315,9 +311,7 @@ async def scan_file(
background_tasks.add_task(Path.unlink, pathlib_path, missing_ok=True)
except UnboundLocalError:
# pathlib_path may not exist.
logger.exception(
"An error occurred while trying to cleanup the downloaded file."
)
logger.exception("An error occurred while trying to cleanup the downloaded file.")


if __name__ == "__main__":
Expand Down
17 changes: 7 additions & 10 deletions lib/modelscan_api/tests/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,20 @@
from __future__ import annotations

import itertools
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Iterable
from typing import Any

import pytest

from bailo_modelscan_api.dependencies import (
parse_path,
safe_join,
sanitise_unix_filename,
)
# isort: split

from bailo_modelscan_api.dependencies import parse_path, safe_join, sanitise_unix_filename

# Helpers


def type_matrix(
data: Iterable[Any], types: Iterable[type]
) -> itertools.product[tuple[Any, ...]]:
def type_matrix(data: Iterable[Any], types: Iterable[type]) -> itertools.product[tuple[Any, ...]]:
"""Generate a matrix of all combinations of `data` converted to each type in `types`.
For example:
`list(type_matrix(["foo", "bar"], [str, Path])) -> [(str(foo), str(bar)), (str(foo), Path(bar)), (Path(foo), str(bar)), (Path(foo), Path(bar))]`
Expand Down Expand Up @@ -55,7 +52,7 @@ def type_matrix(
"".join(['\\[/\\?%*:|"<>0x7F0x00-0x1F]', chr(0x1F) * 15]),
"-[----------0x7F0x00-0x1F]---------------",
),
("ad\nbla'{-+\\)(ç?", "ad-bla'{-+-)(ç-"), # type: ignore
("ad\nbla'{-+\\)(ç?", "ad-bla'{-+-)(ç-"),
],
)
def test_sanitise_unix_filename(path: str, output: str) -> None:
Expand Down
2 changes: 2 additions & 0 deletions lib/modelscan_api/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import pytest
from fastapi.testclient import TestClient

# isort: split

from bailo_modelscan_api.config import Settings
from bailo_modelscan_api.main import app, get_settings

Expand Down
16 changes: 8 additions & 8 deletions lib/modelscan_api/tests/test_integration/generate_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ def __init__(self, file, protocol, inj_objs):

def dump(self, obj):
"Pickle data, inject object before or after"
if self.proto >= 2: # type: ignore
self.write(pickle.PROTO + struct.pack("<B", self.proto)) # type: ignore
if self.proto >= 4: # type: ignore
self.framer.start_framing() # type: ignore
if self.proto >= 2:
self.write(pickle.PROTO + struct.pack("<B", self.proto))
if self.proto >= 4:
self.framer.start_framing()
for inj_obj in self.inj_objs:
self.save(inj_obj) # type: ignore
self.save(obj) # type: ignore
self.write(pickle.STOP) # type: ignore
self.framer.end_framing() # type: ignore
self.save(inj_obj)
self.save(obj)
self.write(pickle.STOP)
self.framer.end_framing()


class _PickleInject:
Expand Down
26 changes: 8 additions & 18 deletions lib/modelscan_api/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import pytest
from fastapi.testclient import TestClient

# isort: split

from bailo_modelscan_api.config import Settings
from bailo_modelscan_api.dependencies import parse_path
from bailo_modelscan_api.main import app, get_settings
Expand Down Expand Up @@ -52,9 +54,7 @@ def test_info():
("-", EMPTY_CONTENTS, H5_MIME_TYPE),
],
)
def test_scan_file(
mock_scan: Mock, file_name: str, file_content: Any, file_mime_type: str
):
def test_scan_file(mock_scan: Mock, file_name: str, file_content: Any, file_mime_type: str):
mock_scan.return_value = {}
files = {"in_file": (file_name, file_content, file_mime_type)}

Expand All @@ -68,27 +68,21 @@ def test_scan_file(
("file_name", "file_content", "file_mime_type"),
[("..", EMPTY_CONTENTS, H5_MIME_TYPE)],
)
def test_scan_file_escape_path_error(
file_name: str, file_content: Any, file_mime_type: str
):
def test_scan_file_escape_path_error(file_name: str, file_content: Any, file_mime_type: str):
files = {"in_file": (file_name, file_content, file_mime_type)}

response = client.post("/scan/file", files=files)

assert response.status_code == 500
assert response.json() == {
"detail": "An error occurred while processing the uploaded file's name."
}
assert response.json() == {"detail": "An error occurred while processing the uploaded file's name."}


@patch("modelscan.modelscan.ModelScan.scan")
@pytest.mark.parametrize(
("file_name", "file_content", "file_mime_type"),
[("foo.h5", EMPTY_CONTENTS, H5_MIME_TYPE)],
)
def test_scan_file_exception(
mock_scan: Mock, file_name: str, file_content: Any, file_mime_type: str
):
def test_scan_file_exception(mock_scan: Mock, file_name: str, file_content: Any, file_mime_type: str):
mock_scan.side_effect = Exception("Mocked error!")
files = {"in_file": (file_name, file_content, file_mime_type)}

Expand All @@ -109,14 +103,10 @@ def test_scan_file_exception(
("file_name", "file_content", "file_mime_type"),
[(" ", EMPTY_CONTENTS, H5_MIME_TYPE)],
)
def test_scan_file_filename_missing(
file_name: str, file_content: Any, file_mime_type: str
):
def test_scan_file_filename_missing(file_name: str, file_content: Any, file_mime_type: str):
files = {"in_file": (file_name, file_content, file_mime_type)}

response = client.post("/scan/file", files=files)

assert response.status_code == 500
assert response.json() == {
"detail": "An error occurred while extracting the uploaded file's name."
}
assert response.json() == {"detail": "An error occurred while extracting the uploaded file's name."}
2 changes: 1 addition & 1 deletion lib/python/.pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
rev: 6.0.0
hooks:
- id: isort
args: ['-a', 'from __future__ import annotations']
args: ['-a', 'from __future__ import annotations', '-l', '120']

- repo: https://github.com/asottile/pyupgrade
rev: v3.19.1
Expand Down

0 comments on commit 7974ad4

Please sign in to comment.