Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: fix some minor typing problem #527

Merged
merged 7 commits into from
Feb 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,12 @@ ignore = [
]

[tool.pyright]
root = '.'
include = ['./transmission_rpc/']
exclude = ['./tests/', './.venv/', './docs/']
ignore = ['./tests/', './docs/']
pythonVersion = '3.8'
pythonPlatform = 'Linux'
typeCheckingMode = "standard"
typeCheckingMode = "strict"
# reportUnnecessaryComparison = false
reportUnnecessaryIsInstance = false
reportUnknownVariableType = false
3 changes: 1 addition & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
from typing_extensions import Literal

from tests.util import ServerTooLowError, skip_on
from transmission_rpc.client import Client, ensure_location_str
from transmission_rpc.client import Client, _try_read_torrent, ensure_location_str
from transmission_rpc.error import TransmissionAuthError
from transmission_rpc.types import File
from transmission_rpc.utils import _try_read_torrent


@pytest.mark.parametrize(
Expand Down
38 changes: 0 additions & 38 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,44 +17,6 @@ def assert_almost_eq(value: float, expected: float):
assert abs(value - expected) < 1


@pytest.mark.parametrize(
("size", "expected"),
{
512: (512, "B"),
1024: (1.0, "KiB"),
1048575: (1023.999, "KiB"),
1048576: (1.0, "MiB"),
1073741824: (1.0, "GiB"),
1099511627776: (1.0, "TiB"),
1125899906842624: (1.0, "PiB"),
1152921504606846976: (1.0, "EiB"),
}.items(),
)
def test_format_size(size, expected: tuple[float, str]):
result = utils.format_size(size)
assert_almost_eq(result[0], expected[0])
assert result[1] == expected[1]


@pytest.mark.parametrize(
("size", "expected"),
[
(512, (512, "B/s")),
(1024, (1.0, "KiB/s")),
(1048575, (1023.999, "KiB/s")),
(1048576, (1.0, "MiB/s")),
(1073741824, (1.0, "GiB/s")),
(1099511627776, (1.0, "TiB/s")),
(1125899906842624, (1.0, "PiB/s")),
(1152921504606846976, (1.0, "EiB/s")),
],
)
def test_format_speed(size, expected):
result = utils.format_speed(size)
assert_almost_eq(result[0], expected[0])
assert result[1] == expected[1]


@pytest.mark.parametrize(
("delta", "expected"),
{
Expand Down
10 changes: 4 additions & 6 deletions transmission_rpc/_unix_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

from urllib3.connection import HTTPConnection
from urllib3.connectionpool import HTTPConnectionPool
from urllib3.util.connection import _TYPE_SOCKET_OPTIONS
from urllib3.util.timeout import _DEFAULT_TIMEOUT


class UnixHTTPConnection(HTTPConnection):
Expand All @@ -20,7 +18,7 @@
host: str,
*,
# The default socket options include `TCP_NODELAY` which won't work here.
socket_options: None | _TYPE_SOCKET_OPTIONS = None,
socket_options: None | list[tuple[int, int, int | bytes]] = None,
**kwargs: Any,
):
self.socket_path = host
Expand All @@ -37,10 +35,10 @@

socket_options = self.socket_options
if socket_options is not None:
for opt in socket_options:
sock.setsockopt(*opt)
for lvl, opt, value in socket_options:
sock.setsockopt(lvl, opt, value)

Check warning on line 39 in transmission_rpc/_unix_socket.py

View check run for this annotation

Codecov / codecov/patch

transmission_rpc/_unix_socket.py#L38-L39

Added lines #L38 - L39 were not covered by tests

if self.timeout is not _DEFAULT_TIMEOUT: # type: ignore
if self.timeout is not None:

Check warning on line 41 in transmission_rpc/_unix_socket.py

View check run for this annotation

Codecov / codecov/patch

transmission_rpc/_unix_socket.py#L41

Added line #L41 was not covered by tests
sock.settimeout(self.timeout)
sock.connect(self.socket_path)
return sock
Expand Down
56 changes: 34 additions & 22 deletions transmission_rpc/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import base64
import importlib.metadata
import json
import logging
Expand All @@ -8,6 +9,7 @@
import time
import types
from typing import Any, BinaryIO, Iterable, List, TypeVar, Union
from urllib.parse import urlparse

import certifi
import urllib3
Expand All @@ -16,7 +18,7 @@
from urllib3.util import make_headers

from transmission_rpc._unix_socket import UnixHTTPConnectionPool
from transmission_rpc.constants import LOGGER, RpcMethod
from transmission_rpc.constants import LOGGER, RpcMethod, get_torrent_arguments
from transmission_rpc.error import (
TransmissionAuthError,
TransmissionConnectError,
Expand All @@ -26,7 +28,6 @@
from transmission_rpc.session import Session, SessionStats
from transmission_rpc.torrent import Torrent
from transmission_rpc.types import Group, PortTestResult
from transmission_rpc.utils import _try_read_torrent, get_torrent_arguments

try:
__version__ = importlib.metadata.version("transmission-rpc")
Expand Down Expand Up @@ -66,7 +67,7 @@
return str(s)


def _parse_torrent_id(raw_torrent_id: int | str) -> int | str:
def _parse_torrent_id(raw_torrent_id: Any) -> int | str:
if isinstance(raw_torrent_id, int):
if raw_torrent_id >= 0:
return raw_torrent_id
Expand Down Expand Up @@ -327,10 +328,10 @@

res = data["arguments"]

results = {}
if method == RpcMethod.TorrentGet:
return res
if method == RpcMethod.TorrentAdd:
results: dict[str, Any] = {}
item = None
if "torrent-added" in res:
item = res["torrent-added"]
Expand All @@ -346,24 +347,16 @@
response=data,
raw_response=http_data,
)
elif method == RpcMethod.SessionGet:
return results
if method == RpcMethod.SessionGet:
self.__raw_session.update(res)
elif method == RpcMethod.SessionStats:
if method == RpcMethod.SessionStats:
# older versions of T has the return data in "session-stats"
if "session-stats" in res:
return res["session-stats"]
return res
elif method in (
RpcMethod.PortTest,
RpcMethod.BlocklistUpdate,
RpcMethod.FreeSpace,
RpcMethod.TorrentRenamePath,
):
return res
else:
return res

return results
return res

def _update_server_version(self) -> None:
"""Decode the Transmission version string, if available."""
Expand Down Expand Up @@ -460,9 +453,6 @@
Array of string labels.
Add in rpc 17.
"""
if torrent is None:
raise ValueError("add_torrent requires data or a URI.")

if labels is not None:
self._rpc_version_warning(17)

Expand Down Expand Up @@ -581,8 +571,6 @@
else:
arguments = self.__torrent_get_arguments
torrent_id = _parse_torrent_id(torrent_id)
if torrent_id is None:
raise ValueError("Invalid id")

result = self._request(
RpcMethod.TorrentGet,
Expand Down Expand Up @@ -853,7 +841,7 @@
Get session parameters. See the Session class for more information.
"""

data = {}
data: dict[str, Any] = {}
if arguments:
data["fields"] = list(arguments)

Expand Down Expand Up @@ -1217,3 +1205,27 @@

def remove_unset_value(data: dict[str, Any]) -> dict[str, Any]:
return {key: value for key, value in data.items() if value is not None}


def _try_read_torrent(torrent: BinaryIO | str | bytes | pathlib.Path) -> str | None:
"""
if torrent should be encoded with base64, return a non-None value.
"""
# torrent is a str, may be a url
if isinstance(torrent, str):
parsed_uri = urlparse(torrent)
# torrent starts with file, read from local disk and encode it to base64 url.
if parsed_uri.scheme in ["https", "http", "magnet"]:
return None

if parsed_uri.scheme in ["file"]:
raise ValueError("support for `file://` URL has been removed.")

Check warning on line 1222 in transmission_rpc/client.py

View check run for this annotation

Codecov / codecov/patch

transmission_rpc/client.py#L1221-L1222

Added lines #L1221 - L1222 were not covered by tests
elif isinstance(torrent, pathlib.Path):
return base64.b64encode(torrent.read_bytes()).decode("utf-8")
elif isinstance(torrent, bytes):
return base64.b64encode(torrent).decode("utf-8")
# maybe a file, try read content and encode it.
elif hasattr(torrent, "read"):
return base64.b64encode(torrent.read()).decode("utf-8")

return None

Check warning on line 1231 in transmission_rpc/client.py

View check run for this annotation

Codecov / codecov/patch

transmission_rpc/client.py#L1231

Added line #L1231 was not covered by tests
16 changes: 16 additions & 0 deletions transmission_rpc/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,19 @@ class RpcMethod(str, enum.Enum):
PortTest = "port-test"

BlocklistUpdate = "blocklist-update"


def get_torrent_arguments(rpc_version: int) -> list[str]:
"""
Get torrent arguments for method in specified Transmission RPC version.
"""
accessible: list[str] = []
for argument, info in TORRENT_GET_ARGS.items():
valid_version = True
if rpc_version < info.added_version:
valid_version = False
if info.removed_version is not None and info.removed_version <= rpc_version:
valid_version = False
if valid_version:
accessible.append(argument)
return accessible
73 changes: 0 additions & 73 deletions transmission_rpc/utils.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,4 @@
# Copyright (c) 2018-2021 Trim21 <[email protected]>
# Copyright (c) 2008-2014 Erik Svensson <[email protected]>
# Licensed under the MIT license.
from __future__ import annotations

import base64
import datetime
import pathlib
from typing import BinaryIO
from urllib.parse import urlparse

from transmission_rpc import constants

UNITS = ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB"]


def format_size(size: int) -> tuple[float, str]:
"""
Format byte size into IEC prefixes, B, KiB, MiB ...
"""
s = float(size)
i = 0
while s >= 1024.0 and i < len(UNITS):
i += 1
s /= 1024.0
return s, UNITS[i]


def format_speed(size: int) -> tuple[float, str]:
"""
Format bytes per second speed into IEC prefixes, B/s, KiB/s, MiB/s ...
"""
(s, unit) = format_size(size)
return s, f"{unit}/s"


def format_timedelta(delta: datetime.timedelta) -> str:
Expand All @@ -41,43 +8,3 @@ def format_timedelta(delta: datetime.timedelta) -> str:
minutes, seconds = divmod(delta.seconds, 60)
hours, minutes = divmod(minutes, 60)
return f"{delta.days:d} {hours:02d}:{minutes:02d}:{seconds:02d}"


def get_torrent_arguments(rpc_version: int) -> list[str]:
"""
Get torrent arguments for method in specified Transmission RPC version.
"""
accessible = []
for argument, info in constants.TORRENT_GET_ARGS.items():
valid_version = True
if rpc_version < info.added_version:
valid_version = False
if info.removed_version is not None and info.removed_version <= rpc_version:
valid_version = False
if valid_version:
accessible.append(argument)
return accessible


def _try_read_torrent(torrent: BinaryIO | str | bytes | pathlib.Path) -> str | None:
"""
if torrent should be encoded with base64, return a non-None value.
"""
# torrent is a str, may be a url
if isinstance(torrent, str):
parsed_uri = urlparse(torrent)
# torrent starts with file, read from local disk and encode it to base64 url.
if parsed_uri.scheme in ["https", "http", "magnet"]:
return None

if parsed_uri.scheme in ["file"]:
raise ValueError("support for `file://` URL has been removed.")
elif isinstance(torrent, pathlib.Path):
return base64.b64encode(torrent.read_bytes()).decode("utf-8")
elif isinstance(torrent, bytes):
return base64.b64encode(torrent).decode("utf-8")
# maybe a file, try read content and encode it.
elif hasattr(torrent, "read"):
return base64.b64encode(torrent.read()).decode("utf-8")

return None
Loading