diff --git a/src/fastcs/datatypes.py b/src/fastcs/datatypes.py index 467612b..5caf5a6 100644 --- a/src/fastcs/datatypes.py +++ b/src/fastcs/datatypes.py @@ -8,6 +8,7 @@ from typing import Generic, TypeVar import numpy as np +from numpy.typing import DTypeLike T = TypeVar("T", int, float, bool, str, enum.Enum, np.ndarray) @@ -138,7 +139,7 @@ def initial_value(self) -> T_Enum: @dataclass(frozen=True) class WaveForm(DataType[np.ndarray]): - array_dtype: np.typing.DTypeLike + array_dtype: DTypeLike shape: tuple[int, ...] = (2000,) @property diff --git a/src/fastcs/transport/epics/gui.py b/src/fastcs/transport/epics/gui.py index 8385cba..e6f79c7 100644 --- a/src/fastcs/transport/epics/gui.py +++ b/src/fastcs/transport/epics/gui.py @@ -85,7 +85,7 @@ def _get_attribute_component( case AttrRW(): read_widget = self._get_read_widget(attribute) write_widget = self._get_write_widget(attribute) - if write_widget is None or write_widget is None: + if write_widget is None or read_widget is None: return None return SignalRW( name=name, diff --git a/src/fastcs/transport/epics/ioc.py b/src/fastcs/transport/epics/ioc.py index f4bebb8..5765866 100644 --- a/src/fastcs/transport/epics/ioc.py +++ b/src/fastcs/transport/epics/ioc.py @@ -1,5 +1,4 @@ import asyncio -import warnings from collections.abc import Callable from types import MethodType from typing import Any, Literal @@ -13,8 +12,8 @@ from fastcs.datatypes import DataType, T from fastcs.transport.epics.util import ( builder_callable_from_attribute, - get_callable_from_epics_type, - get_callable_to_epics_type, + cast_from_epics_type, + cast_to_epics_type, record_metadata_from_attribute, record_metadata_from_datatype, ) @@ -154,10 +153,8 @@ def _create_and_link_attribute_pvs(pv_prefix: str, controller: Controller) -> No def _create_and_link_read_pv( pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrR[T] ) -> None: - cast_to_epics_type = get_callable_to_epics_type(attribute.datatype) - async def async_record_set(value: T): - record.set(cast_to_epics_type(value)) + record.set(cast_to_epics_type(attribute.datatype, value)) record = _make_record(f"{pv_prefix}:{pv_name}", attribute) _add_attr_pvi_info(record, pv_prefix, attr_name, "r") @@ -191,14 +188,13 @@ def datatype_updater(datatype: DataType): def _create_and_link_write_pv( pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrW[T] ) -> None: - cast_from_epics_type = get_callable_from_epics_type(attribute.datatype) - cast_to_epics_type = get_callable_to_epics_type(attribute.datatype) - async def on_update(value): - await attribute.process_without_display_update(cast_from_epics_type(value)) + await attribute.process_without_display_update( + cast_from_epics_type(attribute.datatype, value) + ) async def async_write_display(value: T): - record.set(cast_to_epics_type(value), process=False) + record.set(cast_to_epics_type(attribute.datatype, value), process=False) record = _make_record(f"{pv_prefix}:{pv_name}", attribute, on_update=on_update) diff --git a/src/fastcs/transport/epics/util.py b/src/fastcs/transport/epics/util.py index f2c6ba0..84a049e 100644 --- a/src/fastcs/transport/epics/util.py +++ b/src/fastcs/transport/epics/util.py @@ -1,4 +1,3 @@ -from collections.abc import Callable from dataclasses import asdict from softioc import builder @@ -78,35 +77,24 @@ def record_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]: return arguments -def get_callable_from_epics_type(datatype: DataType[T]) -> Callable[[object], T]: +def cast_from_epics_type(datatype: DataType[T], value: object) -> T: match datatype: case Enum(): - - def cast_from_epics_type(value: object) -> T: - return datatype.validate(datatype.members[value]) - + return datatype.validate(datatype.members[value]) case datatype if issubclass(type(datatype), EPICS_ALLOWED_DATATYPES): - - def cast_from_epics_type(value) -> T: - return datatype.validate(value) + return datatype.validate(value) # type: ignore case _: raise ValueError(f"Unsupported datatype {datatype}") - return cast_from_epics_type -def get_callable_to_epics_type(datatype: DataType[T]) -> Callable[[T], object]: +def cast_to_epics_type(datatype: DataType[T], value: T) -> object: match datatype: case Enum(): - - def cast_to_epics_type(value) -> object: - return datatype.index_of(datatype.validate(value)) + return datatype.index_of(datatype.validate(value)) case datatype if issubclass(type(datatype), EPICS_ALLOWED_DATATYPES): - - def cast_to_epics_type(value) -> object: - return datatype.validate(value) + return datatype.validate(value) case _: raise ValueError(f"Unsupported datatype {datatype}") - return cast_to_epics_type def builder_callable_from_attribute( diff --git a/src/fastcs/transport/rest/rest.py b/src/fastcs/transport/rest/rest.py index 3c68624..10b36a3 100644 --- a/src/fastcs/transport/rest/rest.py +++ b/src/fastcs/transport/rest/rest.py @@ -10,9 +10,9 @@ from .options import RestServerOptions from .util import ( + cast_from_rest_type, + cast_to_rest_type, convert_datatype, - get_cast_method_from_rest_type, - get_cast_method_to_rest_type, ) @@ -58,10 +58,8 @@ def _put_request_body(attribute: AttrW[T]): def _wrap_attr_put( attribute: AttrW[T], ) -> Callable[[T], Coroutine[Any, Any, None]]: - cast_method = get_cast_method_from_rest_type(attribute.datatype) - async def attr_set(request): - await attribute.process(cast_method(request.value)) + await attribute.process(cast_from_rest_type(attribute.datatype, request.value)) # Fast api uses type annotations for validation, schema, conversions attr_set.__annotations__["request"] = _put_request_body(attribute) @@ -86,11 +84,9 @@ def _get_response_body(attribute: AttrR[T]): def _wrap_attr_get( attribute: AttrR[T], ) -> Callable[[], Coroutine[Any, Any, Any]]: - cast_method = get_cast_method_to_rest_type(attribute.datatype) - async def attr_get() -> Any: # Must be any as response_model is set value = attribute.get() # type: ignore - return {"value": cast_method(value)} + return {"value": cast_to_rest_type(attribute.datatype, value)} return attr_get diff --git a/src/fastcs/transport/rest/util.py b/src/fastcs/transport/rest/util.py index f70ce36..6aa232e 100644 --- a/src/fastcs/transport/rest/util.py +++ b/src/fastcs/transport/rest/util.py @@ -1,5 +1,3 @@ -from collections.abc import Callable - import numpy as np from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, WaveForm @@ -15,33 +13,21 @@ def convert_datatype(datatype: DataType[T]) -> type: return datatype.dtype -def get_cast_method_to_rest_type(datatype: DataType[T]) -> Callable[[T], object]: +def cast_to_rest_type(datatype: DataType[T], value: T) -> object: match datatype: case WaveForm(): - - def cast_to_rest_type(value) -> list: - return value.tolist() + return value.tolist() case datatype if issubclass(type(datatype), REST_ALLOWED_DATATYPES): - - def cast_to_rest_type(value): - return datatype.validate(value) + return datatype.validate(value) case _: raise ValueError(f"Unsupported datatype {datatype}") - return cast_to_rest_type - -def get_cast_method_from_rest_type(datatype: DataType[T]) -> Callable[[object], T]: +def cast_from_rest_type(datatype: DataType[T], value: object) -> T: match datatype: case WaveForm(): - - def cast_from_rest_type(value) -> T: - return datatype.validate(np.array(value, dtype=datatype.array_dtype)) + return datatype.validate(np.array(value, dtype=datatype.array_dtype)) case datatype if issubclass(type(datatype), REST_ALLOWED_DATATYPES): - - def cast_from_rest_type(value) -> T: - return datatype.validate(value) + return datatype.validate(value) # type: ignore case _: raise ValueError(f"Unsupported datatype {datatype}") - - return cast_from_rest_type diff --git a/src/fastcs/transport/tango/dsr.py b/src/fastcs/transport/tango/dsr.py index f38c56e..64dae60 100644 --- a/src/fastcs/transport/tango/dsr.py +++ b/src/fastcs/transport/tango/dsr.py @@ -11,8 +11,8 @@ from .options import TangoDSROptions from .util import ( - get_cast_method_from_tango_type, - get_cast_method_to_tango_type, + cast_from_tango_type, + cast_to_tango_type, get_server_metadata_from_attribute, get_server_metadata_from_datatype, ) @@ -23,23 +23,13 @@ def _wrap_updater_fget( attribute: AttrR, controller: BaseController, ) -> Callable[[Any], Any]: - cast_method = get_cast_method_to_tango_type(attribute.datatype) - async def fget(tango_device: Device): tango_device.info_stream(f"called fget method: {attr_name}") - return cast_method(attribute.get()) + return cast_to_tango_type(attribute.datatype, attribute.get()) return fget -def _tango_display_format(attribute: Attribute) -> str: - match attribute.datatype: - case Float(prec): - return f"%.{prec}" - - return "6.2f" # `tango.server.attribute` default for `format` - - async def _run_threadsafe_blocking( coro: Coroutine[Any, Any, Any], loop: asyncio.AbstractEventLoop ) -> None: @@ -57,11 +47,9 @@ def _wrap_updater_fset( controller: BaseController, loop: asyncio.AbstractEventLoop, ) -> Callable[[Any, Any], Any]: - cast_method = get_cast_method_from_tango_type(attribute.datatype) - - async def fset(tango_device: Device, val): + async def fset(tango_device: Device, value): tango_device.info_stream(f"called fset method: {attr_name}") - coro = attribute.process(val) + coro = attribute.process(cast_from_tango_type(attribute.datatype, value)) await _run_threadsafe_blocking(coro, loop) return fset diff --git a/src/fastcs/transport/tango/util.py b/src/fastcs/transport/tango/util.py index 3e4bb47..dc3d663 100644 --- a/src/fastcs/transport/tango/util.py +++ b/src/fastcs/transport/tango/util.py @@ -1,4 +1,3 @@ -from collections.abc import Callable from dataclasses import asdict from typing import Any @@ -61,33 +60,21 @@ def get_server_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]: return arguments -def get_cast_method_to_tango_type(datatype: DataType[T]) -> Callable[[T], object]: +def cast_to_tango_type(datatype: DataType[T], value: T) -> object: match datatype: case Enum(): - - def cast_to_tango_type(value) -> int: - return datatype.index_of(datatype.validate(value)) + return datatype.index_of(datatype.validate(value)) case datatype if issubclass(type(datatype), TANGO_ALLOWED_DATATYPES): - - def cast_to_tango_type(value) -> object: - return datatype.validate(value) + return datatype.validate(value) case _: raise ValueError(f"Unsupported datatype {datatype}") - return cast_to_tango_type -def get_cast_method_from_tango_type(datatype: DataType[T]) -> Callable[[object], T]: +def cast_from_tango_type(datatype: DataType[T], value: object) -> T: match datatype: case Enum(): - - def cast_from_tango_type(value: object) -> T: - return datatype.validate(datatype.members[value]) - + return datatype.validate(datatype.members[value]) case datatype if issubclass(type(datatype), TANGO_ALLOWED_DATATYPES): - - def cast_from_tango_type(value) -> T: - return datatype.validate(value) + return datatype.validate(value) # type: ignore case _: raise ValueError(f"Unsupported datatype {datatype}") - - return cast_from_tango_type diff --git a/tests/conftest.py b/tests/conftest.py index acf8619..7169c80 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,14 +10,9 @@ import pytest from aioca import purge_channel_caches -from fastcs.attributes import AttrR, AttrRW, AttrW, Handler, Sender, Updater -from fastcs.controller import Controller, SubController -from fastcs.datatypes import Bool, Float, Int, String -from fastcs.transport.tango.dsr import register_dev -from fastcs.datatypes import Bool, Enum, Float, Int, String, WaveForm -from fastcs.wrappers import command, scan from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.datatypes import Bool, Float, Int, String +from fastcs.transport.tango.dsr import register_dev from tests.assertable_controller import ( TestController, TestHandler, diff --git a/tests/transport/tango/test_dsr.py b/tests/transport/tango/test_dsr.py index 374664f..13f03c0 100644 --- a/tests/transport/tango/test_dsr.py +++ b/tests/transport/tango/test_dsr.py @@ -1,7 +1,6 @@ import asyncio -from unittest import mock - import enum +from unittest import mock import numpy as np import pytest @@ -22,6 +21,8 @@ async def patch_run_threadsafe_blocking(coro, loop): await coro + + class TangoAssertableController(AssertableController): read_int = AttrR(Int(), handler=TestUpdater()) read_write_int = AttrRW(Int(), handler=TestHandler()) @@ -34,6 +35,11 @@ class TangoAssertableController(AssertableController): two_d_waveform = AttrRW(WaveForm(np.int32, (10, 10))) +@pytest.fixture(scope="class") +def assertable_controller(class_mocker: MockerFixture): + return TangoAssertableController(class_mocker) + + class TestTangoDevice: @pytest.fixture(scope="class") def tango_context(self, assertable_controller):