diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 289b5005..367ad1e4 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -12,6 +12,7 @@ jobs: test: runs-on: ubuntu-24.04 strategy: + fail-fast: false matrix: transmission: ["version-3.00-r8", "4.0.5"] python: ["3.8", "3.9", "3.10", "3.11", "3.12"] @@ -52,6 +53,33 @@ jobs: flags: "${{ matrix.python }}" token: ${{ secrets.CODECOV_TOKEN }} + test-unix-socket: + # At the time of writing ubuntu-latest is ubuntu-22.04. But we need + # an even later version because ubuntu-22.04 provides Transmission 3.0.0 but + # Unix socket support was added in Transmission 4.0.0. Use 24.04 which is + # currently in beta. + runs-on: ubuntu-24.04 + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 2 + - run: sudo apt-get install -y transmission-daemon + - run: mkdir -p $HOME/Downloads + - run: transmission-daemon --rpc-bind-address unix:/tmp/transmission.socket + - uses: actions/setup-python@v5 + with: + python-version: 3.9 # Oldest version available for ubuntu-24.04 + cache: pip + - run: pip install -e .[dev] + - run: coverage run -m pytest + env: + TR_PROTOCOL: 'http+unix' + TR_HOST: '/tmp/transmission.socket' + - uses: codecov/codecov-action@v4 + with: + flags: "unix-socket" + token: ${{ secrets.CODECOV_TOKEN }} + dist-files: runs-on: ubuntu-24.04 diff --git a/pyproject.toml b/pyproject.toml index dfe20954..bd5635ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,8 @@ classifiers = [ ] dependencies = [ - 'requests~=2.23', + 'urllib3~=2.2', + 'certifi>=2017.4.17', 'typing-extensions>=4.5.0', ] @@ -46,7 +47,6 @@ dev = [ 'coverage==7.6.1', # types 'mypy==1.13.0', - 'types-requests==2.32.0.20241016', # docs 'sphinx>=8,<=8.1.3; python_version >= "3.10"', 'furo==2024.8.6; python_version >= "3.10"', @@ -149,4 +149,5 @@ ignore = [ 'PLR0915', 'PLR2004', 'PGH003', + 'TCH002', ] diff --git a/tests/conftest.py b/tests/conftest.py index 00905c08..2720bbee 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,21 +1,40 @@ +# ruff: noqa: SIM117 +import contextlib import os import secrets +import socket +import time import pytest from transmission_rpc import LOGGER from transmission_rpc.client import Client +PROTOCOL = os.getenv("TR_PROTOCOL", "http") HOST = os.getenv("TR_HOST", "127.0.0.1") PORT = int(os.getenv("TR_PORT", "9091")) USER = os.getenv("TR_USER", "admin") PASSWORD = os.getenv("TR_PASSWORD", "password") +def pytest_configure(): + start = time.time() + while True: + with contextlib.suppress(ConnectionError, FileNotFoundError): + is_unix = PROTOCOL == "http+unix" + with socket.socket(socket.AF_UNIX if is_unix else socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(3) + sock.connect(HOST if is_unix else (HOST, PORT)) + break + + if time.time() - start > 30: + raise ConnectionError("timeout trying to connect to transmission-daemon, is transmission daemon started?") + + @pytest.fixture def tr_client(): LOGGER.setLevel("INFO") - with Client(host=HOST, port=PORT, username=USER, password=PASSWORD) as c: + with Client(protocol=PROTOCOL, host=HOST, port=PORT, username=USER, password=PASSWORD) as c: for torrent in c.get_torrents(): c.remove_torrent(torrent.id, delete_data=True) yield c diff --git a/tests/test_client.py b/tests/test_client.py index 6e7c54b0..37567e48 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -5,7 +5,6 @@ from urllib.parse import urljoin import pytest -import yarl from typing_extensions import Literal from tests.util import ServerTooLowError, skip_on @@ -48,18 +47,8 @@ def test_client_parse_url(protocol: Literal["http", "https"], username, password port=port, path=path, ) - u = str( - yarl.URL.build( - scheme=protocol, - user=username, - password=password, - host=host, - port=port, - path=urljoin(path, "rpc"), - ) - ) - assert client._url == u # noqa: SLF001 + assert client._url == f'{protocol}://{host}:{port}{urljoin(path, "rpc")}' # noqa: SLF001 def hash_to_magnet(h): @@ -222,8 +211,8 @@ def test_real_torrent_get_files(tr_client: Client): [401, 403], ) def test_raise_unauthorized(status_code): - m = mock.Mock(return_value=mock.Mock(status_code=status_code)) - with mock.patch("requests.Session.post", m), pytest.raises(TransmissionAuthError): + m = mock.Mock(return_value=mock.Mock(status=status_code)) + with mock.patch("urllib3.HTTPConnectionPool.request", m), pytest.raises(TransmissionAuthError): Client() diff --git a/tests/test_utils.py b/tests/test_utils.py index 97e48c75..5f21f0e4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,8 +9,8 @@ import pytest -from transmission_rpc import from_url, utils -from transmission_rpc.constants import DEFAULT_TIMEOUT, LOGGER +from transmission_rpc import DEFAULT_TIMEOUT, from_url, utils +from transmission_rpc.constants import LOGGER def assert_almost_eq(value: float, expected: float): @@ -106,6 +106,14 @@ def test_format_timedelta(delta, expected): "port": 443, "path": "/", }, + "http+unix://%2Fvar%2Frun%2Ftransmission.sock/transmission/rpc": { + "protocol": "http+unix", + "username": None, + "password": None, + "host": "/var/run/transmission.sock", + "port": None, + "path": "/transmission/rpc", + }, }.items(), ) def test_from_url(url: str, kwargs: dict[str, Any]): diff --git a/transmission_rpc/__init__.py b/transmission_rpc/__init__.py index a5106e23..28acf244 100644 --- a/transmission_rpc/__init__.py +++ b/transmission_rpc/__init__.py @@ -1,8 +1,8 @@ import logging import urllib.parse -from transmission_rpc.client import Client -from transmission_rpc.constants import DEFAULT_TIMEOUT, LOGGER, IdleMode, Priority, RatioLimitMode +from transmission_rpc.client import DEFAULT_TIMEOUT, Client +from transmission_rpc.constants import LOGGER, IdleMode, Priority, RatioLimitMode from transmission_rpc.error import ( TransmissionAuthError, TransmissionConnectError, @@ -50,6 +50,7 @@ def from_url( from_url("https://127.0.0.1/transmission/rpc") # https://127.0.0.1:443/transmission/rpc from_url("http://127.0.0.1") # http://127.0.0.1:80/transmission/rpc from_url("http://127.0.0.1/") # http://127.0.0.1:80/ + from_url("http+unix://%2Fvar%2Frun%2Ftransmission.sock/transmission/rpc") # /transmission/rpc on /var/run/transmission.sock Unix socket Warnings: you can't ignore scheme, ``127.0.0.1:9091`` is not valid url, please use ``http://127.0.0.1:9091`` @@ -61,10 +62,16 @@ def from_url( u = urllib.parse.urlparse(url) protocol = u.scheme + host = u.hostname + default_port = None if protocol == "http": default_port = 80 elif protocol == "https": default_port = 443 + elif protocol == "http+unix": + if host is None: + raise ValueError("http+unix URL is missing Unix socket path") + host = urllib.parse.unquote(host, errors="strict") else: raise ValueError(f"unknown url scheme {u.scheme}") @@ -72,7 +79,7 @@ def from_url( protocol=protocol, # type: ignore username=u.username, password=u.password, - host=u.hostname or "127.0.0.1", + host=host or "127.0.0.1", port=u.port or default_port, path=u.path or "/transmission/rpc", timeout=timeout, diff --git a/transmission_rpc/_unix_socket.py b/transmission_rpc/_unix_socket.py new file mode 100644 index 00000000..adaf8b6f --- /dev/null +++ b/transmission_rpc/_unix_socket.py @@ -0,0 +1,53 @@ +# Inspired from: +# https://github.com/getsentry/sentry/blob/9d03adef66f63e29a5d95189447d02ba0b68c2af/src/sentry/net/http.py#L215-L244 +# See also: +# https://github.com/urllib3/urllib3/issues/1465 + +from __future__ import annotations + +import socket +from typing import Any + +from urllib3.connection import HTTPConnection +from urllib3.connectionpool import HTTPConnectionPool +from urllib3.util.connection import _TYPE_SOCKET_OPTIONS +from urllib3.util.timeout import _DEFAULT_TIMEOUT + + +class UnixHTTPConnection(HTTPConnection): + def __init__( + self, + host: str, + *, + # The default socket options include `TCP_NODELAY` which won't work here. + socket_options: None | _TYPE_SOCKET_OPTIONS = None, + **kwargs: Any, + ): + self.socket_path = host + # We're using the `host` as the socket path, but + # urllib3 uses this host as the Host header by default. + # If we send along the socket path as a Host header, this is + # never what you want and would typically be malformed value. + # So we fake this by sending along `localhost` by default as + # other libraries do. + super().__init__(host="localhost", socket_options=socket_options, **kwargs) + + def _new_conn(self) -> socket.socket: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + + socket_options = self.socket_options + if socket_options is not None: + for opt in socket_options: + sock.setsockopt(*opt) + + if self.timeout is not _DEFAULT_TIMEOUT: # type: ignore + sock.settimeout(self.timeout) + sock.connect(self.socket_path) + return sock + + +class UnixHTTPConnectionPool(HTTPConnectionPool): + ConnectionCls = UnixHTTPConnection + + def __str__(self) -> str: + return f"{type(self).__name__}(host={self.host})" diff --git a/transmission_rpc/client.py b/transmission_rpc/client.py index a96d5e4c..37b09a6f 100644 --- a/transmission_rpc/client.py +++ b/transmission_rpc/client.py @@ -1,5 +1,6 @@ from __future__ import annotations +import importlib.metadata import json import logging import pathlib @@ -7,14 +8,15 @@ import time import types from typing import Any, BinaryIO, Iterable, List, TypeVar, Union -from urllib.parse import quote -import requests -import requests.auth -import requests.exceptions +import certifi +import urllib3 from typing_extensions import Literal, Self, TypedDict, deprecated +from urllib3 import Timeout +from urllib3.util import make_headers -from transmission_rpc.constants import DEFAULT_TIMEOUT, LOGGER, RpcMethod +from transmission_rpc._unix_socket import UnixHTTPConnectionPool +from transmission_rpc.constants import LOGGER, RpcMethod from transmission_rpc.error import ( TransmissionAuthError, TransmissionConnectError, @@ -23,15 +25,27 @@ ) from transmission_rpc.session import Session, SessionStats from transmission_rpc.torrent import Torrent -from transmission_rpc.types import Group, _Timeout +from transmission_rpc.types import Group from transmission_rpc.utils import _try_read_torrent, get_torrent_arguments +try: + __version__ = importlib.metadata.version("transmission-rpc") +except ImportError: + __version__ = "develop" + +__USER_AGENT__ = f"transmission-rpc/{__version__} (https://github.com/trim21/transmission-rpc)" + _hex_chars = frozenset(string.hexdigits.lower()) _TorrentID = Union[int, str] _TorrentIDs = Union[_TorrentID, List[_TorrentID], None] -_header_session_id = "x-transmission-session-id" +_header_session_id_key = "x-transmission-session-id" + +DEFAULT_TIMEOUT = 30.0 + +# urllib3 may remove support for int/float in the future +_Timeout = Union[Timeout, int, float] class ResponseData(TypedDict): @@ -78,16 +92,18 @@ def _parse_torrent_ids(args: Any) -> str | list[str | int]: class Client: + __query_timeout: Timeout | None + def __init__( self, *, - protocol: Literal["http", "https"] = "http", + protocol: Literal["http", "https", "http+unix"] = "http", username: str | None = None, password: str | None = None, host: str = "127.0.0.1", - port: int = 9091, + port: int | None = 9091, path: str = "/transmission/rpc", - timeout: float = DEFAULT_TIMEOUT, + timeout: float | Timeout | None = DEFAULT_TIMEOUT, logger: logging.Logger = LOGGER, ): """ @@ -101,6 +117,9 @@ def __init__( path: rpc request target path, default ``/transmission/rpc`` timeout: logger: + + To connect to a Unix socket, pass "http+unix" as `protocol` and the path to + the socket as `host`. """ if isinstance(logger, logging.Logger): self.logger = logger @@ -108,25 +127,40 @@ def __init__( raise TypeError( "logger must be instance of `logging.Logger`, default: logging.getLogger('transmission-rpc')" ) - self._query_timeout: _Timeout = timeout + if isinstance(timeout, (int, float)): + self.__query_timeout = Timeout(timeout) + elif isinstance(timeout, Timeout) or timeout is None: + self.__query_timeout = timeout + else: + raise TypeError(f"unsupported value {timeout!r}, only Timeout/float/int are supported") - username = quote(username or "", safe="$-_.+!*'(),;&=", encoding="utf8") if username else "" - password = ":" + quote(password or "", safe="$-_.+!*'(),;&=", encoding="utf8") if password else "" - auth = f"{username}{password}@" if (username or password) else "" + if username or password: + self.__auth_headers = make_headers(basic_auth=f"{username}:{password}", user_agent=__USER_AGENT__) + else: + self.__auth_headers = make_headers(user_agent=__USER_AGENT__) if path == "/transmission/": path = "/transmission/rpc" - url = f"{protocol}://{auth}{host}:{port}{path}" + url_host = "localhost" if protocol == "http+unix" else host + url = f"{protocol}://{url_host}{'' if port is None else f':{port}'}{path}" self._url = str(url) + self._path = path + self.__raw_session: dict[str, Any] = {} self.__session_id = "0" + self.__server_version: str = "(unknown)" self.__protocol_version: int = 17 # default 17 - self._http_session = requests.Session() - self._http_session.trust_env = False self.__semver_version = None - self.get_session() + + common_args: dict[str, Any] = {"host": host, "timeout": self.timeout, "retries": False} + self.__http_client = { + "http": urllib3.HTTPConnectionPool(port=port, **common_args), + "https": urllib3.HTTPSConnectionPool(port=port, ca_certs=certifi.where(), **common_args), + "http+unix": UnixHTTPConnectionPool(**common_args), + }[protocol] + self.get_session(arguments=["rpc-version", "rpc-version-semver", "version"]) self.__torrent_get_arguments = get_torrent_arguments(self.__protocol_version) @property @@ -155,82 +189,74 @@ def server_version(self) -> str: return self.__server_version @property - def timeout(self) -> _Timeout: + def timeout(self) -> Timeout | None: """ Get current timeout for HTTP queries. """ - return self._query_timeout + return self.__query_timeout @timeout.setter - def timeout(self, value: _Timeout) -> None: + def timeout(self, value: Timeout) -> None: """ Set timeout for HTTP queries. """ - if isinstance(value, (tuple, list)): - if len(value) != 2: - raise ValueError("timeout tuple can only include 2 numbers elements") - for v in value: - if not isinstance(v, (float, int)): - raise TypeError("element of timeout tuple can only be int or float") - self._query_timeout = (value[0], value[1]) # for type checker - elif value is None: - self._query_timeout = DEFAULT_TIMEOUT - else: - self._query_timeout = float(value) + if not isinstance(value, Timeout): + raise TypeError("must use Timeout instance") + + self.__query_timeout = value @timeout.deleter def timeout(self) -> None: """ Reset the HTTP query timeout to the default. """ - self._query_timeout = DEFAULT_TIMEOUT + self.__query_timeout = Timeout(DEFAULT_TIMEOUT) - @property - def _http_header(self) -> dict[str, str]: - return {_header_session_id: self.__session_id} + def __get_headers(self) -> dict[str, str]: + self.__auth_headers[_header_session_id_key] = self.__session_id + + return self.__auth_headers def _http_query(self, query: dict[str, Any], timeout: _Timeout | None = None) -> str: """ Query Transmission through HTTP. """ request_count = 0 + if timeout is None: - timeout = self.timeout + timeout = self.__query_timeout + while True: if request_count >= 3: raise TransmissionError("too much request, try enable logger to see what happened") - self.logger.debug( - { - "url": self._url, - "headers": self._http_header, - "data": query, - "timeout": timeout, - } - ) + + headers = self.__get_headers() + self.logger.debug({"path": self._path, "headers": headers, "data": query, "timeout": timeout}) request_count += 1 try: - r = self._http_session.post( - self._url, - headers=self._http_header, + r = self.__http_client.request( + "POST", + url=self._path, + headers=headers, json=query, timeout=timeout, ) - except requests.exceptions.Timeout as e: + except urllib3.exceptions.TimeoutError as e: raise TransmissionTimeoutError("timeout when connection to transmission daemon") from e - except requests.exceptions.ConnectionError as e: + except urllib3.exceptions.ConnectionError as e: raise TransmissionConnectError(f"can't connect to transmission daemon: {e!s}") from e - self.logger.debug(r.text) - if r.status_code in {401, 403}: - self.logger.debug(r.request.headers) + self.logger.debug(r.data) + if r.status in {401, 403}: + self.logger.debug(headers) raise TransmissionAuthError("transmission daemon require auth", original=r) - if _header_session_id in r.headers: - self.__session_id = r.headers["x-transmission-session-id"] + if _header_session_id_key in r.headers: + self.__session_id = r.headers[_header_session_id_key] - if r.status_code != 409: - return r.text + if r.status != 409: + return r.data.decode("utf-8") def _request( self, @@ -814,11 +840,20 @@ def queue_down(self, ids: _TorrentIDs, timeout: _Timeout | None = None) -> None: """Move transfer down in the queue.""" self._request(RpcMethod.QueueMoveDown, ids=ids, require_ids=True, timeout=timeout) - def get_session(self, timeout: _Timeout | None = None) -> Session: + def get_session( + self, + timeout: _Timeout | None = None, + arguments: Iterable[str] | None = None, + ) -> Session: """ Get session parameters. See the Session class for more information. """ - self._request(RpcMethod.SessionGet, timeout=timeout) + + data = {} + if arguments: + data["fields"] = list(arguments) + + self._request(RpcMethod.SessionGet, timeout=timeout, arguments=data) self._update_server_version() return Session(fields=self.__raw_session) @@ -1146,7 +1181,7 @@ def __exit__( exc_val: BaseException | None, exc_tb: types.TracebackType | None, ) -> None: - self._http_session.close() + self.__http_client.close() T = TypeVar("T") diff --git a/transmission_rpc/constants.py b/transmission_rpc/constants.py index 52e20fc7..cc9b95a9 100644 --- a/transmission_rpc/constants.py +++ b/transmission_rpc/constants.py @@ -11,9 +11,6 @@ LOGGER.setLevel(logging.ERROR) -DEFAULT_TIMEOUT = 30.0 - - class Priority(enum.IntEnum): Low = -1 Normal = 0 diff --git a/transmission_rpc/error.py b/transmission_rpc/error.py index 2306a05a..65f2cc56 100644 --- a/transmission_rpc/error.py +++ b/transmission_rpc/error.py @@ -4,12 +4,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import Any import typing_extensions - -if TYPE_CHECKING: - from requests.models import Response +from urllib3 import BaseHTTPResponse class TransmissionError(Exception): @@ -23,7 +21,7 @@ class TransmissionError(Exception): argument: Any | None # rpc call arguments response: Any | None # parsed json response, may be dict with keys 'result' and 'arguments' raw_response: str | None # raw text http response - original: Response | None # original http requests + original: BaseHTTPResponse | None # original http requests def __init__( self, @@ -32,7 +30,7 @@ def __init__( argument: Any | None = None, response: Any | None = None, raw_response: str | None = None, - original: Response | None = None, + original: BaseHTTPResponse | None = None, ): super().__init__() self.message = message diff --git a/transmission_rpc/types.py b/transmission_rpc/types.py index 9dcc0d4e..54e3db21 100644 --- a/transmission_rpc/types.py +++ b/transmission_rpc/types.py @@ -1,12 +1,9 @@ from __future__ import annotations -from typing import Any, NamedTuple, Optional, Tuple, TypeVar, Union +from typing import Any, NamedTuple, TypeVar from transmission_rpc.constants import Priority -_Number = Union[int, float] -_Timeout = Optional[Union[_Number, Tuple[_Number, _Number]]] - T = TypeVar("T")