Skip to content

Commit

Permalink
Add GraphQL transport (#71)
Browse files Browse the repository at this point in the history
Co-authored-by: Gary Yendell <[email protected]>
  • Loading branch information
marcelldls and GDYendell authored Dec 11, 2024
1 parent e2dfc01 commit 39e55d4
Show file tree
Hide file tree
Showing 8 changed files with 380 additions and 2 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies = [
"pvi~=0.10.0",
"pytango",
"softioc>=4.5.0",
"strawberry-graphql",
]
dynamic = ["version"]
license.file = "LICENSE"
Expand Down Expand Up @@ -63,7 +64,7 @@ version_file = "src/fastcs/_version.py"

[tool.pyright]
typeCheckingMode = "standard"
reportMissingImports = false # Ignore missing stubs in imported modules
reportMissingImports = false # Ignore missing stubs in imported modules

[tool.pytest.ini_options]
# Run pytest with all our checkers, and don't spam us with massive tracebacks on error
Expand Down
10 changes: 9 additions & 1 deletion src/fastcs/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
from .exceptions import LaunchError
from .transport.adapter import TransportAdapter
from .transport.epics.options import EpicsOptions
from .transport.graphQL.options import GraphQLOptions
from .transport.rest.options import RestOptions
from .transport.tango.options import TangoOptions

# Define a type alias for transport options
TransportOptions: TypeAlias = EpicsOptions | TangoOptions | RestOptions
TransportOptions: TypeAlias = EpicsOptions | TangoOptions | RestOptions | GraphQLOptions


class FastCS:
Expand All @@ -38,6 +39,13 @@ def __init__(
self._backend.dispatcher,
transport_options,
)
case GraphQLOptions():
from .transport.graphQL.adapter import GraphQLTransport

self._transport = GraphQLTransport(
controller,
transport_options,
)
case TangoOptions():
from .transport.tango.adapter import TangoTransport

Expand Down
2 changes: 2 additions & 0 deletions src/fastcs/transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from .epics.options import EpicsGUIOptions as EpicsGUIOptions
from .epics.options import EpicsIOCOptions as EpicsIOCOptions
from .epics.options import EpicsOptions as EpicsOptions
from .graphQL.options import GraphQLOptions as GraphQLOptions
from .graphQL.options import GraphQLServerOptions as GraphQLServerOptions
from .rest.options import RestOptions as RestOptions
from .rest.options import RestServerOptions as RestServerOptions
from .tango.options import TangoDSROptions as TangoDSROptions
Expand Down
Empty file.
24 changes: 24 additions & 0 deletions src/fastcs/transport/graphQL/adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from fastcs.controller import Controller
from fastcs.transport.adapter import TransportAdapter

from .graphQL import GraphQLServer
from .options import GraphQLOptions


class GraphQLTransport(TransportAdapter):
def __init__(
self,
controller: Controller,
options: GraphQLOptions | None = None,
):
self.options = options or GraphQLOptions()
self._server = GraphQLServer(controller)

def create_docs(self) -> None:
raise NotImplementedError

def create_gui(self) -> None:
raise NotImplementedError

def run(self) -> None:
self._server.run(self.options.gql)
172 changes: 172 additions & 0 deletions src/fastcs/transport/graphQL/graphQL.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from collections.abc import Awaitable, Callable, Coroutine
from typing import Any

import strawberry
import uvicorn
from strawberry.asgi import GraphQL
from strawberry.tools import create_type
from strawberry.types.field import StrawberryField

from fastcs.attributes import AttrR, AttrRW, AttrW, T
from fastcs.controller import (
BaseController,
Controller,
SingleMapping,
_get_single_mapping,
)
from fastcs.exceptions import FastCSException

from .options import GraphQLServerOptions


class GraphQLServer:
def __init__(self, controller: Controller):
self._controller = controller
self._app = self._create_app()

def _create_app(self) -> GraphQL:
api = GraphQLAPI(self._controller)
schema = api.create_schema()
app = GraphQL(schema)

return app

def run(self, options: GraphQLServerOptions | None = None) -> None:
if options is None:
options = GraphQLServerOptions()

uvicorn.run(
self._app,
host=options.host,
port=options.port,
log_level=options.log_level,
)


class GraphQLAPI:
"""A Strawberry API built dynamically from a Controller"""

def __init__(self, controller: BaseController):
self.queries: list[StrawberryField] = []
self.mutations: list[StrawberryField] = []

api = _get_single_mapping(controller)

self._process_attributes(api)
self._process_commands(api)
self._process_sub_controllers(api)

def _process_attributes(self, api: SingleMapping):
"""Create queries and mutations from api attributes."""
for attr_name, attribute in api.attributes.items():
match attribute:
# mutation for server changes https://graphql.org/learn/queries/
case AttrRW():
self.queries.append(
strawberry.field(_wrap_attr_get(attr_name, attribute))
)
self.mutations.append(
strawberry.mutation(_wrap_attr_set(attr_name, attribute))
)
case AttrR():
self.queries.append(
strawberry.field(_wrap_attr_get(attr_name, attribute))
)
case AttrW():
self.mutations.append(
strawberry.mutation(_wrap_attr_set(attr_name, attribute))
)

def _process_commands(self, api: SingleMapping):
"""Create mutations from api commands"""
for cmd_name, method in api.command_methods.items():
self.mutations.append(
strawberry.mutation(_wrap_command(cmd_name, method.fn, api.controller))
)

def _process_sub_controllers(self, api: SingleMapping):
"""Recursively add fields from the queries and mutations of sub controllers"""
for sub_controller in api.controller.get_sub_controllers().values():
name = "".join(sub_controller.path)
child_tree = GraphQLAPI(sub_controller)
if child_tree.queries:
self.queries.append(
_wrap_as_field(
name, create_type(f"{name}Query", child_tree.queries)
)
)
if child_tree.mutations:
self.mutations.append(
_wrap_as_field(
name, create_type(f"{name}Mutation", child_tree.mutations)
)
)

def create_schema(self) -> strawberry.Schema:
"""Create a Strawberry Schema to load into a GraphQL application."""
if not self.queries:
raise FastCSException(
"Can't create GraphQL transport from Controller with no read attributes"
)

query = create_type("Query", self.queries)
mutation = create_type("Mutation", self.mutations) if self.mutations else None

return strawberry.Schema(query=query, mutation=mutation)


def _wrap_attr_set(
attr_name: str, attribute: AttrW[T]
) -> Callable[[T], Coroutine[Any, Any, None]]:
"""Wrap an attribute in a function with annotations for strawberry"""

async def _dynamic_f(value):
await attribute.process(value)
return value

# Add type annotations for validation, schema, conversions
_dynamic_f.__name__ = attr_name
_dynamic_f.__annotations__["value"] = attribute.datatype.dtype
_dynamic_f.__annotations__["return"] = attribute.datatype.dtype

return _dynamic_f


def _wrap_attr_get(
attr_name: str, attribute: AttrR[T]
) -> Callable[[], Coroutine[Any, Any, Any]]:
"""Wrap an attribute in a function with annotations for strawberry"""

async def _dynamic_f() -> Any:
return attribute.get()

_dynamic_f.__name__ = attr_name
_dynamic_f.__annotations__["return"] = attribute.datatype.dtype

return _dynamic_f


def _wrap_as_field(field_name: str, operation: type) -> StrawberryField:
"""Wrap a strawberry type as a field of a parent type"""

def _dynamic_field():
return operation()

_dynamic_field.__name__ = field_name
_dynamic_field.__annotations__["return"] = operation

return strawberry.field(_dynamic_field)


def _wrap_command(
method_name: str, method: Callable, controller: BaseController
) -> Callable[..., Awaitable[bool]]:
"""Wrap a command in a function with annotations for strawberry"""

async def _dynamic_f() -> bool:
await getattr(controller, method.__name__)()
return True

_dynamic_f.__name__ = method_name

return _dynamic_f
13 changes: 13 additions & 0 deletions src/fastcs/transport/graphQL/options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from dataclasses import dataclass, field


@dataclass
class GraphQLServerOptions:
host: str = "localhost"
port: int = 8080
log_level: str = "info"


@dataclass
class GraphQLOptions:
gql: GraphQLServerOptions = field(default_factory=GraphQLServerOptions)
Loading

0 comments on commit 39e55d4

Please sign in to comment.