diff --git a/CHANGELOG.md b/CHANGELOG.md index ee50112b..c5ed968d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## 0.14.0 (Unreleased) +- Added `ClientForwardRefsPlugin` to standard plugins. - Re-added `model_rebuild` calls for input types with forward references. diff --git a/README.md b/README.md index 999748c3..c7d4a916 100644 --- a/README.md +++ b/README.md @@ -96,6 +96,8 @@ Ariadne Codegen ships with optional plugins importable from the `ariadne_codegen - [`ariadne_codegen.contrib.extract_operations.ExtractOperationsPlugin`](ariadne_codegen/contrib/extract_operations.py) - This extracts query strings from generated client's methods into separate `operations.py` module. It also modifies the generated client to import these definitions. Generated module name can be customized by adding `operations_module_name="custom_name"` to the `[tool.ariadne-codegen.operations]` section in config. Eg.: +- [`ariadne_codegen.contrib.client_forward_refs.ClientForwardRefsPlugin`](ariadne_codegen/contrib/client_forward_refs.py) - This plugin changes generated client module moving all Pydantic models imports under the `TYPE_CHECKING` condition, making them forward references. This greatly improves the import performance of the `client` module. + ```toml [tool.ariadne-codegen] ... diff --git a/ariadne_codegen/contrib/__init__.py b/ariadne_codegen/contrib/__init__.py index 63e5565f..c9e56fec 100644 --- a/ariadne_codegen/contrib/__init__.py +++ b/ariadne_codegen/contrib/__init__.py @@ -1,5 +1,11 @@ +from .client_forward_refs import ClientForwardRefsPlugin from .extract_operations import ExtractOperationsPlugin from .no_reimports import NoReimportsPlugin from .shorter_results import ShorterResultsPlugin -__all__ = ["ExtractOperationsPlugin", "NoReimportsPlugin", "ShorterResultsPlugin"] +__all__ = [ + "ClientForwardRefsPlugin", + "ExtractOperationsPlugin", + "NoReimportsPlugin", + "ShorterResultsPlugin", +] diff --git a/ariadne_codegen/contrib/client_forward_refs.py b/ariadne_codegen/contrib/client_forward_refs.py new file mode 100644 index 00000000..d3ac9833 --- /dev/null +++ b/ariadne_codegen/contrib/client_forward_refs.py @@ -0,0 +1,395 @@ +""" +Plugin that delays imports of Pydantic models in client module. + +Puts all imports under the `typing.TYPE_CHECKING` flag, making +type annotations for generated client's methods forward references. + +This greatly improves import time of generated `client` module when +there are many Pydantic models. + +Because generated client's methods need type definitions for models +they are using, those models imports will be also inserted in their +bodies. + +This will massively reduce import times for larger projects since you only have +to load the input types when loading the client. + +All input and return types that's used to process the server response will +only be imported when the method is called. +""" + +import ast +from typing import Dict, List, Optional, Set, Union + +from graphql import GraphQLSchema + +from ariadne_codegen import Plugin + +TYPE_CHECKING_MODULE: str = "typing" +TYPE_CHECKING_FLAG: str = "TYPE_CHECKING" + + +class ClientForwardRefsPlugin(Plugin): + """Only import types when you call an endpoint needing it""" + + def __init__(self, schema: GraphQLSchema, config_dict: Dict) -> None: + """Constructor""" + # Types that should only be imported in a `TYPE_CHECKING` context. This + # is all the types used as arguments to a method or as a return type, + # i.e. for type checking. + self.input_and_return_types: Set[str] = set() + + # Imported classes are classes imported from local imports. We keep a + # map between name and module so we know how to import them in each + # method. + self.imported_classes: Dict[str, str] = {} + + # Imported classes in each method definition. + self.imported_in_method: Set[str] = set() + + super().__init__(schema, config_dict) + + def generate_client_module(self, module: ast.Module) -> ast.Module: + """ + Update the generated client. + + This will parse all current imports to map them to a path. It will then + traverse all methods and look for the actual return type. The return + node will be converted to an `ast.Constant` if it's an `ast.Name` and + the return type will be imported only under `if TYPE_CHECKING` + conditions. + + It will also move all imports of the types used to parse the response + inside each method since that's the only place where they're used. The + result will be that we end up with imports in the global scope only for + types used as input types. + + :param module: The ast for the module + :returns: A modified `ast.Module` + """ + self._store_imported_classes(module.body) + + # Find the actual client class so we can grab all input and output + # types. We also ensure to manipulate the ast while we do this. + client_class_def = next( + filter(lambda o: isinstance(o, ast.ClassDef), module.body), None + ) + if not client_class_def or not isinstance(client_class_def, ast.ClassDef): + return super().generate_client_module(module) + + for method_def in [ + m + for m in client_class_def.body + if isinstance(m, (ast.FunctionDef, ast.AsyncFunctionDef)) + ]: + method_def = self._rewrite_input_args_to_constants(method_def) + + # If the method returns anything, update whatever it returns. + if method_def.returns: + method_def.returns = self._update_name_to_constant(method_def.returns) + + self._insert_import_statement_in_method(method_def) + + self._update_imports(module) + + return super().generate_client_module(module) + + def _store_imported_classes(self, module_body: List[ast.stmt]): + """Fetch and store imported classes. + + Grab all imported classes with level 1 or starting with `.` because + these are the ones generated by us. We store a map between the class and + which module it was imported from so we can easily import it when + needed. This can be in a `TYPE_CHECKING` condition or inside a method. + + :param module_body: The body of an `ast.Module` + """ + for node in module_body: + if not isinstance(node, ast.ImportFrom): + continue + + if node.module is None: + continue + + # We only care about local imports from our generated code. + if node.level != 1 and not node.module.startswith("."): + continue + + for name in node.names: + from_ = "." * node.level + node.module + if isinstance(name, ast.alias): + self.imported_classes[name.name] = from_ + + def _rewrite_input_args_to_constants( + self, method_def: Union[ast.FunctionDef, ast.AsyncFunctionDef] + ) -> Union[ast.FunctionDef, ast.AsyncFunctionDef]: + """Rewrite the arguments to a method. + + For any `ast.Name` that requires an import convert it to an + `ast.Constant` instead. The actual class will be noted and imported + in a `TYPE_CHECKING` context. + + :param method_def: Method definition + :returns: The same definition but updated + """ + if not isinstance(method_def, (ast.FunctionDef, ast.AsyncFunctionDef)): + return method_def + + for i, input_arg in enumerate(method_def.args.args): + annotation = input_arg.annotation + if isinstance(annotation, (ast.Name, ast.Subscript, ast.Tuple)): + method_def.args.args[i].annotation = self._update_name_to_constant( + annotation + ) + + return method_def + + def _insert_import_statement_in_method( + self, method_def: Union[ast.FunctionDef, ast.AsyncFunctionDef] + ): + """Insert import statement in method. + + Each method will eventually pass the returned value to a class we've + generated. Since we only need it in the scope of the method ensure we + add it at the top of the method only. It will be removed from the global + scope. + + :param method_def: The method definition to updated + """ + # Find the last statement in the body, the call to this class is + # what we need to import first. + return_stmt = method_def.body[-1] + if isinstance(return_stmt, ast.Return): + call = self._get_call_arg_from_return(return_stmt) + elif isinstance(return_stmt, ast.AsyncFor): + call = self._get_call_arg_from_async_for(return_stmt) + else: + return + + if call is None: + return + + import_class = self._get_class_from_call(call) + if import_class is None: + return + + import_class_id = import_class.id + + # We add the class to our set of imported in methods - these classes + # don't need to be imported at all in the global scope. + self.imported_in_method.add(import_class.id) + method_def.body.insert( + 0, + ast.ImportFrom( + module=self.imported_classes[import_class_id], + names=[import_class], + ), + ) + + def _get_call_arg_from_return(self, return_stmt: ast.Return) -> Optional[ast.Call]: + """Get the class used in the return statement. + + :param return_stmt: The statement used for return + """ + # If it's a call of the class like produced by + # `ShorterResultsPlugin` we have an attribute. + if isinstance(return_stmt.value, ast.Attribute) and isinstance( + return_stmt.value.value, ast.Call + ): + return return_stmt.value.value + + # If not it's just a call statement to the generated class. + if isinstance(return_stmt.value, ast.Call): + return return_stmt.value + + return None + + def _get_call_arg_from_async_for( + self, last_stmt: ast.AsyncFor + ) -> Optional[ast.Call]: + """Get the class used in the yield expression. + + :param last_stmt: The statement used in `ast.AsyncFor` + """ + if isinstance(last_stmt.body, list) and isinstance(last_stmt.body[0], ast.Expr): + body = last_stmt.body[0] + elif isinstance(last_stmt.body, ast.Expr): + body = last_stmt.body + else: + return None + + if not isinstance(body, ast.Expr): + return None + + if not isinstance(body.value, ast.Yield): + return None + + # If it's a call of the class like produced by + # `ShorterResultsPlugin` we have an attribute. + if isinstance(body.value.value, ast.Attribute) and isinstance( + body.value.value.value, ast.Call + ): + return body.value.value.value + + # If not it's just a call statement to the generated class. + if isinstance(body.value.value, ast.Call): + return body.value.value + + return None + + def _get_class_from_call(self, call: ast.Call) -> Optional[ast.Name]: + """Get the class from an `ast.Call`. + + :param call: The `ast.Call` arg + :returns: `ast.Name` or `None` + """ + if not isinstance(call.func, ast.Attribute): + return None + + if not isinstance(call.func.value, ast.Name): + return None + + return call.func.value + + def _update_imports(self, module: ast.Module) -> None: + """Update all imports. + + Iterate over all imports and remove the aliases that we use as input or + return value. These will be moved and added to an `if TYPE_CHECKING` + block. + + We do this by storing all imports that we want to keep in an array, we + then drop all from the body and re-insert the ones to keep. Lastly we + import `TYPE_CHECKING` and add all our imports in the `if TYPE_CHECKING` + block. + + :param module: The ast for the whole module. + """ + # We now know all our input types and all our return types. The return + # types that are _not_ used as import types should be in an `if + # TYPE_CHECKING` import block. + return_types_not_used_as_input = set(self.input_and_return_types) + + # The ones we import in the method don't need to be imported at all - + # unless that's the type we return. This behaviour can differ if you use + # a plugin such as `ShorterResultsPlugin` that will import a type that + # is different from the type returned. + return_types_not_used_as_input |= ( + self.imported_in_method - self.input_and_return_types + ) + + if len(return_types_not_used_as_input) == 0: + return None + + non_empty_imports = self._update_existing_imports( + module, return_types_not_used_as_input + ) + self._add_forward_ref_imports(module, non_empty_imports) + + return None + + def _update_existing_imports( + self, module: ast.Module, return_types_not_used_as_input: set[str] + ) -> List[Union[ast.Import, ast.ImportFrom]]: + """Update existing imports. + + Remove all import or import from statements that would otherwise be + useless after moving them to forward refs. + + It's very important that we get this right, if we keep any `ImportFrom` + that ends up without any names, the formatting will not work! It will + only remove the empty `import from` but not other unused imports. + + :param module: The ast module to update + :param return_types_not_used_as_input: Set of return types not used as + input + """ + non_empty_imports: List[Union[ast.Import, ast.ImportFrom]] = [] + last_import_at = 0 + for i, node in enumerate(module.body): + if isinstance(node, ast.Import): + last_import_at = i + non_empty_imports.append(node) + + if not isinstance(node, ast.ImportFrom): + continue + + last_import_at = i + reduced_names = [] + for name in node.names: + if name.name not in return_types_not_used_as_input: + reduced_names.append(name) + + node.names = reduced_names + + if len(reduced_names) > 0: + non_empty_imports.append(node) + + # We can now remove all imports and re-insert the ones that's not empty. + module.body = non_empty_imports + module.body[last_import_at + 1 :] + + return non_empty_imports + + def _add_forward_ref_imports( + self, + module: ast.Module, + non_empty_imports: List[Union[ast.Import, ast.ImportFrom]], + ) -> None: + """Add forward ref imports. + + Add all the forward ref imports meaning all the types needed for type + checking under the `if TYPE_CHECKING` condition. + """ + type_checking_imports = {} + for cls in self.input_and_return_types: + module_name = self.imported_classes[cls] + if module_name not in type_checking_imports: + type_checking_imports[module_name] = ast.ImportFrom( + module=module_name, names=[] + ) + + type_checking_imports[module_name].names.append(ast.alias(cls)) + + import_if_type_checking = ast.If( + test=ast.Name(id=TYPE_CHECKING_FLAG), + body=list(type_checking_imports.values()), + orelse=[], + ) + + module.body.insert(len(non_empty_imports), import_if_type_checking) + + # Import `TYPE_CHECKING`. + module.body.insert( + len(non_empty_imports), + ast.ImportFrom( + module=TYPE_CHECKING_MODULE, + names=[ast.Name(TYPE_CHECKING_FLAG)], + ), + ) + + def _update_name_to_constant(self, node: ast.expr) -> ast.expr: + """Update return types. + + If the return type contains any type that resolves to an `ast.Name`, + convert it to an `ast.Constant`. We only need the type for type checking + and can avoid importing the type in the global scope unless needed. + + :param node: The ast node used as return type + :returns: A modified ast node + """ + if isinstance(node, ast.Name): + if node.id in self.imported_classes: + self.input_and_return_types.add(node.id) + return ast.Constant(value=node.id) + + if isinstance(node, ast.Subscript): + node.slice = self._update_name_to_constant(node.slice) + return node + + if isinstance(node, ast.Tuple): + for i, _ in enumerate(node.elts): + node.elts[i] = self._update_name_to_constant(node.elts[i]) + + return node + + return node diff --git a/tests/main/clients/client_forward_refs/custom_scalars.py b/tests/main/clients/client_forward_refs/custom_scalars.py new file mode 100644 index 00000000..5feb063a --- /dev/null +++ b/tests/main/clients/client_forward_refs/custom_scalars.py @@ -0,0 +1,14 @@ +SimpleScalar = str + + +class ComplexScalar: + def __init__(self, value: str) -> None: + self.value = value + + +def parse_complex_scalar(value: str) -> ComplexScalar: + return ComplexScalar(value) + + +def serialize_complex_scalar(value: ComplexScalar) -> str: + return value.value diff --git a/tests/main/clients/client_forward_refs/expected_client/__init__.py b/tests/main/clients/client_forward_refs/expected_client/__init__.py new file mode 100644 index 00000000..d518989d --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/__init__.py @@ -0,0 +1,76 @@ +from .async_base_client import AsyncBaseClient +from .base_model import BaseModel, Upload +from .client import Client +from .client_forward_refs_fragments import ( + FragmentWithSingleField, + FragmentWithSingleFieldQueryUnwrapFragment, + ListAnimalsFragment, + ListAnimalsFragmentListAnimals, +) +from .exceptions import ( + GraphQLClientError, + GraphQLClientGraphQLError, + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidResponseError, +) +from .get_animal_by_name import ( + GetAnimalByName, + GetAnimalByNameAnimalByNameAnimal, + GetAnimalByNameAnimalByNameCat, + GetAnimalByNameAnimalByNameDog, +) +from .get_animal_fragment_with_extra import GetAnimalFragmentWithExtra +from .get_authenticated_user import GetAuthenticatedUser, GetAuthenticatedUserMe +from .get_complex_scalar import GetComplexScalar +from .get_simple_scalar import GetSimpleScalar +from .list_animals import ( + ListAnimals, + ListAnimalsListAnimalsAnimal, + ListAnimalsListAnimalsCat, + ListAnimalsListAnimalsDog, +) +from .list_strings_1 import ListStrings1 +from .list_strings_2 import ListStrings2 +from .list_strings_3 import ListStrings3 +from .list_strings_4 import ListStrings4 +from .list_type_a import ListTypeA, ListTypeAListOptionalTypeA +from .subscribe_strings import SubscribeStrings +from .unwrap_fragment import UnwrapFragment + +__all__ = [ + "AsyncBaseClient", + "BaseModel", + "Client", + "FragmentWithSingleField", + "FragmentWithSingleFieldQueryUnwrapFragment", + "GetAnimalByName", + "GetAnimalByNameAnimalByNameAnimal", + "GetAnimalByNameAnimalByNameCat", + "GetAnimalByNameAnimalByNameDog", + "GetAnimalFragmentWithExtra", + "GetAuthenticatedUser", + "GetAuthenticatedUserMe", + "GetComplexScalar", + "GetSimpleScalar", + "GraphQLClientError", + "GraphQLClientGraphQLError", + "GraphQLClientGraphQLMultiError", + "GraphQLClientHttpError", + "GraphQLClientInvalidResponseError", + "ListAnimals", + "ListAnimalsFragment", + "ListAnimalsFragmentListAnimals", + "ListAnimalsListAnimalsAnimal", + "ListAnimalsListAnimalsCat", + "ListAnimalsListAnimalsDog", + "ListStrings1", + "ListStrings2", + "ListStrings3", + "ListStrings4", + "ListTypeA", + "ListTypeAListOptionalTypeA", + "SubscribeStrings", + "UnwrapFragment", + "Upload", +] diff --git a/tests/main/clients/client_forward_refs/expected_client/async_base_client.py b/tests/main/clients/client_forward_refs/expected_client/async_base_client.py new file mode 100644 index 00000000..5358ced6 --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/async_base_client.py @@ -0,0 +1,370 @@ +import enum +import json +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from uuid import uuid4 + +import httpx +from pydantic import BaseModel +from pydantic_core import to_jsonable_python + +from .base_model import UNSET, Upload +from .exceptions import ( + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidMessageFormat, + GraphQLClientInvalidResponseError, +) + +try: + from websockets.client import ( # type: ignore[import-not-found,unused-ignore] + WebSocketClientProtocol, + connect as ws_connect, + ) + from websockets.typing import ( # type: ignore[import-not-found,unused-ignore] + Data, + Origin, + Subprotocol, + ) +except ImportError: + from contextlib import asynccontextmanager + + @asynccontextmanager # type: ignore + async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument + raise NotImplementedError("Subscriptions require 'websockets' package.") + yield # pylint: disable=unreachable + + WebSocketClientProtocol = Any # type: ignore[misc,assignment,unused-ignore] + Data = Any # type: ignore[misc,assignment,unused-ignore] + Origin = Any # type: ignore[misc,assignment,unused-ignore] + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") + + +Self = TypeVar("Self", bound="AsyncBaseClient") + +GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" + + +class GraphQLTransportWSMessageType(str, enum.Enum): + CONNECTION_INIT = "connection_init" + CONNECTION_ACK = "connection_ack" + PING = "ping" + PONG = "pong" + SUBSCRIBE = "subscribe" + NEXT = "next" + ERROR = "error" + COMPLETE = "complete" + + +class AsyncBaseClient: + def __init__( + self, + url: str = "", + headers: Optional[Dict[str, str]] = None, + http_client: Optional[httpx.AsyncClient] = None, + ws_url: str = "", + ws_headers: Optional[Dict[str, Any]] = None, + ws_origin: Optional[str] = None, + ws_connection_init_payload: Optional[Dict[str, Any]] = None, + ) -> None: + self.url = url + self.headers = headers + self.http_client = ( + http_client if http_client else httpx.AsyncClient(headers=headers) + ) + + self.ws_url = ws_url + self.ws_headers = ws_headers or {} + self.ws_origin = Origin(ws_origin) if ws_origin else None + self.ws_connection_init_payload = ws_connection_init_payload + + async def __aenter__(self: Self) -> Self: + return self + + async def __aexit__( + self, + exc_type: object, + exc_val: object, + exc_tb: object, + ) -> None: + await self.http_client.aclose() + + async def execute( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + operation_name=operation_name, + variables=processed_variables, + files=files, + files_map=files_map, + **kwargs, + ) + + return await self._execute_json( + query=query, + operation_name=operation_name, + variables=processed_variables, + **kwargs, + ) + + def get_data(self, response: httpx.Response) -> Dict[str, Any]: + if not response.is_success: + raise GraphQLClientHttpError( + status_code=response.status_code, response=response + ) + + try: + response_json = response.json() + except ValueError as exc: + raise GraphQLClientInvalidResponseError(response=response) from exc + + if (not isinstance(response_json, dict)) or ( + "data" not in response_json and "errors" not in response_json + ): + raise GraphQLClientInvalidResponseError(response=response) + + data = response_json.get("data") + errors = response_json.get("errors") + + if errors: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=errors, data=data + ) + + return cast(Dict[str, Any], data) + + async def execute_ws( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + **merged_kwargs, + ) as websocket: + await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + operation_name=operation_name, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + + def _process_variables( + self, variables: Optional[Dict[str, Any]] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + if not variables: + return {}, {}, {} + + serializable_variables = self._convert_dict_to_json_serializable(variables) + return self._get_files_from_variables(serializable_variables) + + def _convert_dict_to_json_serializable( + self, dict_: Dict[str, Any] + ) -> Dict[str, Any]: + return { + key: self._convert_value(value) + for key, value in dict_.items() + if value is not UNSET + } + + def _convert_value(self, value: Any) -> Any: + if isinstance(value, BaseModel): + return value.model_dump(by_alias=True, exclude_unset=True) + if isinstance(value, list): + return [self._convert_value(item) for item in value] + return value + + def _get_files_from_variables( + self, variables: Dict[str, Any] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + files_map: Dict[str, List[str]] = {} + files_list: List[Upload] = [] + + def separate_files(path: str, obj: Any) -> Any: + if isinstance(obj, list): + nulled_list = [] + for index, value in enumerate(obj): + value = separate_files(f"{path}.{index}", value) + nulled_list.append(value) + return nulled_list + + if isinstance(obj, dict): + nulled_dict = {} + for key, value in obj.items(): + value = separate_files(f"{path}.{key}", value) + nulled_dict[key] = value + return nulled_dict + + if isinstance(obj, Upload): + if obj in files_list: + file_index = files_list.index(obj) + files_map[str(file_index)].append(path) + else: + file_index = len(files_list) + files_list.append(obj) + files_map[str(file_index)] = [path] + return None + + return obj + + nulled_variables = separate_files("variables", variables) + files: Dict[str, Tuple[str, IO[bytes], str]] = { + str(i): (file_.filename, cast(IO[bytes], file_.content), file_.content_type) + for i, file_ in enumerate(files_list) + } + return nulled_variables, files, files_map + + async def _execute_multipart( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + **kwargs: Any, + ) -> httpx.Response: + data = { + "operations": json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + "map": json.dumps(files_map, default=to_jsonable_python), + } + + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) + + async def _execute_json( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + **kwargs: Any, + ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + + return await self.http_client.post( + url=self.url, + content=json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + **merged_kwargs, + ) + + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: + payload: Dict[str, Any] = { + "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value + } + if self.ws_connection_init_payload: + payload["payload"] = self.ws_connection_init_payload + await websocket.send(json.dumps(payload)) + + async def _send_subscribe( + self, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + payload: Dict[str, Any] = { + "id": operation_id, + "type": GraphQLTransportWSMessageType.SUBSCRIBE.value, + "payload": {"query": query, "operationName": operation_name}, + } + if variables: + payload["payload"]["variables"] = self._convert_dict_to_json_serializable( + variables + ) + await websocket.send(json.dumps(payload)) + + async def _handle_ws_message( + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, + ) -> Optional[Dict[str, Any]]: + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: + raise GraphQLClientInvalidMessageFormat(message=message) + + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/clients/client_forward_refs/expected_client/base_model.py b/tests/main/clients/client_forward_refs/expected_client/base_model.py new file mode 100644 index 00000000..ccde3975 --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/base_model.py @@ -0,0 +1,27 @@ +from io import IOBase + +from pydantic import BaseModel as PydanticBaseModel, ConfigDict + + +class UnsetType: + def __bool__(self) -> bool: + return False + + +UNSET = UnsetType() + + +class BaseModel(PydanticBaseModel): + model_config = ConfigDict( + populate_by_name=True, + validate_assignment=True, + arbitrary_types_allowed=True, + protected_namespaces=(), + ) + + +class Upload: + def __init__(self, filename: str, content: IOBase, content_type: str): + self.filename = filename + self.content = content + self.content_type = content_type diff --git a/tests/main/clients/client_forward_refs/expected_client/client.py b/tests/main/clients/client_forward_refs/expected_client/client.py new file mode 100644 index 00000000..9e80af45 --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/client.py @@ -0,0 +1,296 @@ +from typing import TYPE_CHECKING, Any, AsyncIterator, Dict + +from .async_base_client import AsyncBaseClient + +if TYPE_CHECKING: + from .get_animal_by_name import GetAnimalByName + from .get_animal_fragment_with_extra import GetAnimalFragmentWithExtra + from .get_authenticated_user import GetAuthenticatedUser + from .get_complex_scalar import GetComplexScalar + from .get_simple_scalar import GetSimpleScalar + from .list_animals import ListAnimals + from .list_strings_1 import ListStrings1 + from .list_strings_2 import ListStrings2 + from .list_strings_3 import ListStrings3 + from .list_strings_4 import ListStrings4 + from .list_type_a import ListTypeA + from .subscribe_strings import SubscribeStrings + from .unwrap_fragment import UnwrapFragment + + +def gql(q: str) -> str: + return q + + +class Client(AsyncBaseClient): + async def get_authenticated_user(self, **kwargs: Any) -> "GetAuthenticatedUser": + from .get_authenticated_user import GetAuthenticatedUser + + query = gql( + """ + query GetAuthenticatedUser { + me { + id + username + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, + operation_name="GetAuthenticatedUser", + variables=variables, + **kwargs + ) + data = self.get_data(response) + return GetAuthenticatedUser.model_validate(data) + + async def list_strings_1(self, **kwargs: Any) -> "ListStrings1": + from .list_strings_1 import ListStrings1 + + query = gql( + """ + query ListStrings_1 { + optionalListOptionalString + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListStrings_1", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListStrings1.model_validate(data) + + async def list_strings_2(self, **kwargs: Any) -> "ListStrings2": + from .list_strings_2 import ListStrings2 + + query = gql( + """ + query ListStrings_2 { + optionalListString + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListStrings_2", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListStrings2.model_validate(data) + + async def list_strings_3(self, **kwargs: Any) -> "ListStrings3": + from .list_strings_3 import ListStrings3 + + query = gql( + """ + query ListStrings_3 { + listOptionalString + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListStrings_3", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListStrings3.model_validate(data) + + async def list_strings_4(self, **kwargs: Any) -> "ListStrings4": + from .list_strings_4 import ListStrings4 + + query = gql( + """ + query ListStrings_4 { + listString + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListStrings_4", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListStrings4.model_validate(data) + + async def list_type_a(self, **kwargs: Any) -> "ListTypeA": + from .list_type_a import ListTypeA + + query = gql( + """ + query ListTypeA { + listOptionalTypeA { + id + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListTypeA", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListTypeA.model_validate(data) + + async def get_animal_by_name(self, name: str, **kwargs: Any) -> "GetAnimalByName": + from .get_animal_by_name import GetAnimalByName + + query = gql( + """ + query GetAnimalByName($name: String!) { + animalByName(name: $name) { + __typename + name + ... on Cat { + kittens + } + ... on Dog { + puppies + } + } + } + """ + ) + variables: Dict[str, object] = {"name": name} + response = await self.execute( + query=query, operation_name="GetAnimalByName", variables=variables, **kwargs + ) + data = self.get_data(response) + return GetAnimalByName.model_validate(data) + + async def list_animals(self, **kwargs: Any) -> "ListAnimals": + from .list_animals import ListAnimals + + query = gql( + """ + query ListAnimals { + listAnimals { + __typename + name + ... on Cat { + kittens + } + ... on Dog { + puppies + } + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListAnimals", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListAnimals.model_validate(data) + + async def get_animal_fragment_with_extra( + self, **kwargs: Any + ) -> "GetAnimalFragmentWithExtra": + from .get_animal_fragment_with_extra import GetAnimalFragmentWithExtra + + query = gql( + """ + query GetAnimalFragmentWithExtra { + ...ListAnimalsFragment + listString + } + + fragment ListAnimalsFragment on Query { + listAnimals { + name + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, + operation_name="GetAnimalFragmentWithExtra", + variables=variables, + **kwargs + ) + data = self.get_data(response) + return GetAnimalFragmentWithExtra.model_validate(data) + + async def get_simple_scalar(self, **kwargs: Any) -> "GetSimpleScalar": + from .get_simple_scalar import GetSimpleScalar + + query = gql( + """ + query GetSimpleScalar { + justSimpleScalar + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="GetSimpleScalar", variables=variables, **kwargs + ) + data = self.get_data(response) + return GetSimpleScalar.model_validate(data) + + async def get_complex_scalar(self, **kwargs: Any) -> "GetComplexScalar": + from .get_complex_scalar import GetComplexScalar + + query = gql( + """ + query GetComplexScalar { + justComplexScalar + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, + operation_name="GetComplexScalar", + variables=variables, + **kwargs + ) + data = self.get_data(response) + return GetComplexScalar.model_validate(data) + + async def subscribe_strings( + self, **kwargs: Any + ) -> AsyncIterator["SubscribeStrings"]: + from .subscribe_strings import SubscribeStrings + + query = gql( + """ + subscription SubscribeStrings { + optionalListString + } + """ + ) + variables: Dict[str, object] = {} + async for data in self.execute_ws( + query=query, + operation_name="SubscribeStrings", + variables=variables, + **kwargs + ): + yield SubscribeStrings.model_validate(data) + + async def unwrap_fragment(self, **kwargs: Any) -> "UnwrapFragment": + from .unwrap_fragment import UnwrapFragment + + query = gql( + """ + query UnwrapFragment { + ...FragmentWithSingleField + } + + fragment FragmentWithSingleField on Query { + queryUnwrapFragment { + id + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="UnwrapFragment", variables=variables, **kwargs + ) + data = self.get_data(response) + return UnwrapFragment.model_validate(data) diff --git a/tests/main/clients/client_forward_refs/expected_client/client_forward_refs_fragments.py b/tests/main/clients/client_forward_refs/expected_client/client_forward_refs_fragments.py new file mode 100644 index 00000000..4b103a44 --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/client_forward_refs_fragments.py @@ -0,0 +1,28 @@ +from typing import List, Literal + +from pydantic import Field + +from .base_model import BaseModel + + +class FragmentWithSingleField(BaseModel): + query_unwrap_fragment: "FragmentWithSingleFieldQueryUnwrapFragment" = Field( + alias="queryUnwrapFragment" + ) + + +class FragmentWithSingleFieldQueryUnwrapFragment(BaseModel): + id: int + + +class ListAnimalsFragment(BaseModel): + list_animals: List["ListAnimalsFragmentListAnimals"] = Field(alias="listAnimals") + + +class ListAnimalsFragmentListAnimals(BaseModel): + typename__: Literal["Animal", "Cat", "Dog"] = Field(alias="__typename") + name: str + + +FragmentWithSingleField.model_rebuild() +ListAnimalsFragment.model_rebuild() diff --git a/tests/main/clients/client_forward_refs/expected_client/custom_scalars.py b/tests/main/clients/client_forward_refs/expected_client/custom_scalars.py new file mode 100644 index 00000000..5feb063a --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/custom_scalars.py @@ -0,0 +1,14 @@ +SimpleScalar = str + + +class ComplexScalar: + def __init__(self, value: str) -> None: + self.value = value + + +def parse_complex_scalar(value: str) -> ComplexScalar: + return ComplexScalar(value) + + +def serialize_complex_scalar(value: ComplexScalar) -> str: + return value.value diff --git a/tests/main/clients/client_forward_refs/expected_client/enums.py b/tests/main/clients/client_forward_refs/expected_client/enums.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/main/clients/client_forward_refs/expected_client/exceptions.py b/tests/main/clients/client_forward_refs/expected_client/exceptions.py new file mode 100644 index 00000000..b34acfe1 --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/exceptions.py @@ -0,0 +1,83 @@ +from typing import Any, Dict, List, Optional, Union + +import httpx + + +class GraphQLClientError(Exception): + """Base exception.""" + + +class GraphQLClientHttpError(GraphQLClientError): + def __init__(self, status_code: int, response: httpx.Response) -> None: + self.status_code = status_code + self.response = response + + def __str__(self) -> str: + return f"HTTP status code: {self.status_code}" + + +class GraphQLClientInvalidResponseError(GraphQLClientError): + def __init__(self, response: httpx.Response) -> None: + self.response = response + + def __str__(self) -> str: + return "Invalid response format." + + +class GraphQLClientGraphQLError(GraphQLClientError): + def __init__( + self, + message: str, + locations: Optional[List[Dict[str, int]]] = None, + path: Optional[List[str]] = None, + extensions: Optional[Dict[str, object]] = None, + orginal: Optional[Dict[str, object]] = None, + ): + self.message = message + self.locations = locations + self.path = path + self.extensions = extensions + self.orginal = orginal + + def __str__(self) -> str: + return self.message + + @classmethod + def from_dict(cls, error: Dict[str, Any]) -> "GraphQLClientGraphQLError": + return cls( + message=error["message"], + locations=error.get("locations"), + path=error.get("path"), + extensions=error.get("extensions"), + orginal=error, + ) + + +class GraphQLClientGraphQLMultiError(GraphQLClientError): + def __init__( + self, + errors: List[GraphQLClientGraphQLError], + data: Optional[Dict[str, Any]] = None, + ): + self.errors = errors + self.data = data + + def __str__(self) -> str: + return "; ".join(str(e) for e in self.errors) + + @classmethod + def from_errors_dicts( + cls, errors_dicts: List[Dict[str, Any]], data: Optional[Dict[str, Any]] = None + ) -> "GraphQLClientGraphQLMultiError": + return cls( + errors=[GraphQLClientGraphQLError.from_dict(e) for e in errors_dicts], + data=data, + ) + + +class GraphQLClientInvalidMessageFormat(GraphQLClientError): + def __init__(self, message: Union[str, bytes]) -> None: + self.message = message + + def __str__(self) -> str: + return "Invalid message format." diff --git a/tests/main/clients/client_forward_refs/expected_client/get_animal_by_name.py b/tests/main/clients/client_forward_refs/expected_client/get_animal_by_name.py new file mode 100644 index 00000000..e97e12ac --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/get_animal_by_name.py @@ -0,0 +1,33 @@ +from typing import Literal, Union + +from pydantic import Field + +from .base_model import BaseModel + + +class GetAnimalByName(BaseModel): + animal_by_name: Union[ + "GetAnimalByNameAnimalByNameAnimal", + "GetAnimalByNameAnimalByNameCat", + "GetAnimalByNameAnimalByNameDog", + ] = Field(alias="animalByName", discriminator="typename__") + + +class GetAnimalByNameAnimalByNameAnimal(BaseModel): + typename__: Literal["Animal"] = Field(alias="__typename") + name: str + + +class GetAnimalByNameAnimalByNameCat(BaseModel): + typename__: Literal["Cat"] = Field(alias="__typename") + name: str + kittens: int + + +class GetAnimalByNameAnimalByNameDog(BaseModel): + typename__: Literal["Dog"] = Field(alias="__typename") + name: str + puppies: int + + +GetAnimalByName.model_rebuild() diff --git a/tests/main/clients/client_forward_refs/expected_client/get_animal_fragment_with_extra.py b/tests/main/clients/client_forward_refs/expected_client/get_animal_fragment_with_extra.py new file mode 100644 index 00000000..bf363706 --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/get_animal_fragment_with_extra.py @@ -0,0 +1,9 @@ +from typing import List + +from pydantic import Field + +from .client_forward_refs_fragments import ListAnimalsFragment + + +class GetAnimalFragmentWithExtra(ListAnimalsFragment): + list_string: List[str] = Field(alias="listString") diff --git a/tests/main/clients/client_forward_refs/expected_client/get_authenticated_user.py b/tests/main/clients/client_forward_refs/expected_client/get_authenticated_user.py new file mode 100644 index 00000000..ce0e2cf7 --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/get_authenticated_user.py @@ -0,0 +1,13 @@ +from .base_model import BaseModel + + +class GetAuthenticatedUser(BaseModel): + me: "GetAuthenticatedUserMe" + + +class GetAuthenticatedUserMe(BaseModel): + id: str + username: str + + +GetAuthenticatedUser.model_rebuild() diff --git a/tests/main/clients/client_forward_refs/expected_client/get_complex_scalar.py b/tests/main/clients/client_forward_refs/expected_client/get_complex_scalar.py new file mode 100644 index 00000000..78563772 --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/get_complex_scalar.py @@ -0,0 +1,12 @@ +from typing import Annotated + +from pydantic import BeforeValidator, Field + +from .base_model import BaseModel +from .custom_scalars import ComplexScalar, parse_complex_scalar + + +class GetComplexScalar(BaseModel): + just_complex_scalar: Annotated[ + ComplexScalar, BeforeValidator(parse_complex_scalar) + ] = Field(alias="justComplexScalar") diff --git a/tests/main/clients/client_forward_refs/expected_client/get_simple_scalar.py b/tests/main/clients/client_forward_refs/expected_client/get_simple_scalar.py new file mode 100644 index 00000000..365f7eaa --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/get_simple_scalar.py @@ -0,0 +1,8 @@ +from pydantic import Field + +from .base_model import BaseModel +from .custom_scalars import SimpleScalar + + +class GetSimpleScalar(BaseModel): + just_simple_scalar: SimpleScalar = Field(alias="justSimpleScalar") diff --git a/tests/main/clients/client_forward_refs/expected_client/input_types.py b/tests/main/clients/client_forward_refs/expected_client/input_types.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/main/clients/client_forward_refs/expected_client/list_animals.py b/tests/main/clients/client_forward_refs/expected_client/list_animals.py new file mode 100644 index 00000000..bcab1c45 --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/list_animals.py @@ -0,0 +1,38 @@ +from typing import Annotated, List, Literal, Union + +from pydantic import Field + +from .base_model import BaseModel + + +class ListAnimals(BaseModel): + list_animals: List[ + Annotated[ + Union[ + "ListAnimalsListAnimalsAnimal", + "ListAnimalsListAnimalsCat", + "ListAnimalsListAnimalsDog", + ], + Field(discriminator="typename__"), + ] + ] = Field(alias="listAnimals") + + +class ListAnimalsListAnimalsAnimal(BaseModel): + typename__: Literal["Animal"] = Field(alias="__typename") + name: str + + +class ListAnimalsListAnimalsCat(BaseModel): + typename__: Literal["Cat"] = Field(alias="__typename") + name: str + kittens: int + + +class ListAnimalsListAnimalsDog(BaseModel): + typename__: Literal["Dog"] = Field(alias="__typename") + name: str + puppies: int + + +ListAnimals.model_rebuild() diff --git a/tests/main/clients/client_forward_refs/expected_client/list_strings_1.py b/tests/main/clients/client_forward_refs/expected_client/list_strings_1.py new file mode 100644 index 00000000..fd8c06de --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/list_strings_1.py @@ -0,0 +1,11 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class ListStrings1(BaseModel): + optional_list_optional_string: Optional[List[Optional[str]]] = Field( + alias="optionalListOptionalString" + ) diff --git a/tests/main/clients/client_forward_refs/expected_client/list_strings_2.py b/tests/main/clients/client_forward_refs/expected_client/list_strings_2.py new file mode 100644 index 00000000..d91ec117 --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/list_strings_2.py @@ -0,0 +1,9 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class ListStrings2(BaseModel): + optional_list_string: Optional[List[str]] = Field(alias="optionalListString") diff --git a/tests/main/clients/client_forward_refs/expected_client/list_strings_3.py b/tests/main/clients/client_forward_refs/expected_client/list_strings_3.py new file mode 100644 index 00000000..88f6e2cf --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/list_strings_3.py @@ -0,0 +1,9 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class ListStrings3(BaseModel): + list_optional_string: List[Optional[str]] = Field(alias="listOptionalString") diff --git a/tests/main/clients/client_forward_refs/expected_client/list_strings_4.py b/tests/main/clients/client_forward_refs/expected_client/list_strings_4.py new file mode 100644 index 00000000..15872b23 --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/list_strings_4.py @@ -0,0 +1,9 @@ +from typing import List + +from pydantic import Field + +from .base_model import BaseModel + + +class ListStrings4(BaseModel): + list_string: List[str] = Field(alias="listString") diff --git a/tests/main/clients/client_forward_refs/expected_client/list_type_a.py b/tests/main/clients/client_forward_refs/expected_client/list_type_a.py new file mode 100644 index 00000000..9e2c8f04 --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/list_type_a.py @@ -0,0 +1,18 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class ListTypeA(BaseModel): + list_optional_type_a: List[Optional["ListTypeAListOptionalTypeA"]] = Field( + alias="listOptionalTypeA" + ) + + +class ListTypeAListOptionalTypeA(BaseModel): + id: int + + +ListTypeA.model_rebuild() diff --git a/tests/main/clients/client_forward_refs/expected_client/subscribe_strings.py b/tests/main/clients/client_forward_refs/expected_client/subscribe_strings.py new file mode 100644 index 00000000..a1a42d4f --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/subscribe_strings.py @@ -0,0 +1,9 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class SubscribeStrings(BaseModel): + optional_list_string: Optional[List[str]] = Field(alias="optionalListString") diff --git a/tests/main/clients/client_forward_refs/expected_client/unwrap_fragment.py b/tests/main/clients/client_forward_refs/expected_client/unwrap_fragment.py new file mode 100644 index 00000000..83423282 --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/unwrap_fragment.py @@ -0,0 +1,5 @@ +from .client_forward_refs_fragments import FragmentWithSingleField + + +class UnwrapFragment(FragmentWithSingleField): + pass diff --git a/tests/main/clients/client_forward_refs/pyproject.toml b/tests/main/clients/client_forward_refs/pyproject.toml new file mode 100644 index 00000000..6f4570ba --- /dev/null +++ b/tests/main/clients/client_forward_refs/pyproject.toml @@ -0,0 +1,16 @@ +[tool.ariadne-codegen] +schema_path = "schema.graphql" +queries_path = "queries.graphql" +include_comments = "none" +target_package_name = "client_forward_refs" +files_to_include = ["custom_scalars.py"] +fragments_module_name = "client_forward_refs_fragments" +plugins = ["ariadne_codegen.contrib.client_forward_refs.ClientForwardRefsPlugin"] + +[tool.ariadne-codegen.scalars.SimpleScalar] +type = ".custom_scalars.SimpleScalar" + +[tool.ariadne-codegen.scalars.ComplexScalar] +type = ".custom_scalars.ComplexScalar" +parse = ".custom_scalars.parse_complex_scalar" +serialize = ".custom_scalars.serialize_complex_scalar" diff --git a/tests/main/clients/client_forward_refs/queries.graphql b/tests/main/clients/client_forward_refs/queries.graphql new file mode 100644 index 00000000..d407732c --- /dev/null +++ b/tests/main/clients/client_forward_refs/queries.graphql @@ -0,0 +1,86 @@ +query GetAuthenticatedUser { + me { + id + username + } +} + +query ListStrings_1 { + optionalListOptionalString +} + +query ListStrings_2 { + optionalListString +} + +query ListStrings_3 { + listOptionalString +} + +query ListStrings_4 { + listString +} + +query ListTypeA { + listOptionalTypeA { + id + } +} + +query GetAnimalByName($name: String!) { + animalByName(name: $name) { + name + ... on Cat { + kittens + } + ... on Dog { + puppies + } + } +} + +query ListAnimals { + listAnimals { + name + ... on Cat { + kittens + } + ... on Dog { + puppies + } + } +} + +query GetAnimalFragmentWithExtra { + ...ListAnimalsFragment + listString +} + +query GetSimpleScalar { + justSimpleScalar +} + +query GetComplexScalar { + justComplexScalar +} + +fragment ListAnimalsFragment on Query { + listAnimals { + name + } +} + +subscription SubscribeStrings { + optionalListString +} + + +fragment FragmentWithSingleField on Query { + queryUnwrapFragment { + id + } +} + +query UnwrapFragment { + ...FragmentWithSingleField +} diff --git a/tests/main/clients/client_forward_refs/schema.graphql b/tests/main/clients/client_forward_refs/schema.graphql new file mode 100644 index 00000000..7acc7f92 --- /dev/null +++ b/tests/main/clients/client_forward_refs/schema.graphql @@ -0,0 +1,48 @@ +scalar SimpleScalar +scalar ComplexScalar + +type TypeA { + id: Int! +} + +type User { + id: ID! + username: String! +} + +interface Animal { + name: String! +} + +type Cat implements Animal { + name: String! + kittens: Int! +} + +type Dog implements Animal { + name: String! + puppies: Int! +} + +type Query { + optionalListOptionalString: [String] + optionalListString: [String!] + listOptionalString: [String]! + listString: [String!]! + + listOptionalTypeA: [TypeA]! + + me: User! + + listAnimals: [Animal!]! + animalByName(name: String!): Animal! + + justSimpleScalar: SimpleScalar! + justComplexScalar: ComplexScalar! + + queryUnwrapFragment: TypeA! +} + +type Subscription { + optionalListString: [String!] +} diff --git a/tests/main/clients/client_forward_refs_shorter_results/custom_scalars.py b/tests/main/clients/client_forward_refs_shorter_results/custom_scalars.py new file mode 100644 index 00000000..5feb063a --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/custom_scalars.py @@ -0,0 +1,14 @@ +SimpleScalar = str + + +class ComplexScalar: + def __init__(self, value: str) -> None: + self.value = value + + +def parse_complex_scalar(value: str) -> ComplexScalar: + return ComplexScalar(value) + + +def serialize_complex_scalar(value: ComplexScalar) -> str: + return value.value diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/__init__.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/__init__.py new file mode 100644 index 00000000..51ff9a6c --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/__init__.py @@ -0,0 +1,76 @@ +from .async_base_client import AsyncBaseClient +from .base_model import BaseModel, Upload +from .client import Client +from .client_forward_refs_shorter_resultsfragments import ( + FragmentWithSingleField, + FragmentWithSingleFieldQueryUnwrapFragment, + ListAnimalsFragment, + ListAnimalsFragmentListAnimals, +) +from .exceptions import ( + GraphQLClientError, + GraphQLClientGraphQLError, + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidResponseError, +) +from .get_animal_by_name import ( + GetAnimalByName, + GetAnimalByNameAnimalByNameAnimal, + GetAnimalByNameAnimalByNameCat, + GetAnimalByNameAnimalByNameDog, +) +from .get_animal_fragment_with_extra import GetAnimalFragmentWithExtra +from .get_authenticated_user import GetAuthenticatedUser, GetAuthenticatedUserMe +from .get_complex_scalar import GetComplexScalar +from .get_simple_scalar import GetSimpleScalar +from .list_animals import ( + ListAnimals, + ListAnimalsListAnimalsAnimal, + ListAnimalsListAnimalsCat, + ListAnimalsListAnimalsDog, +) +from .list_strings_1 import ListStrings1 +from .list_strings_2 import ListStrings2 +from .list_strings_3 import ListStrings3 +from .list_strings_4 import ListStrings4 +from .list_type_a import ListTypeA, ListTypeAListOptionalTypeA +from .subscribe_strings import SubscribeStrings +from .unwrap_fragment import UnwrapFragment + +__all__ = [ + "AsyncBaseClient", + "BaseModel", + "Client", + "FragmentWithSingleField", + "FragmentWithSingleFieldQueryUnwrapFragment", + "GetAnimalByName", + "GetAnimalByNameAnimalByNameAnimal", + "GetAnimalByNameAnimalByNameCat", + "GetAnimalByNameAnimalByNameDog", + "GetAnimalFragmentWithExtra", + "GetAuthenticatedUser", + "GetAuthenticatedUserMe", + "GetComplexScalar", + "GetSimpleScalar", + "GraphQLClientError", + "GraphQLClientGraphQLError", + "GraphQLClientGraphQLMultiError", + "GraphQLClientHttpError", + "GraphQLClientInvalidResponseError", + "ListAnimals", + "ListAnimalsFragment", + "ListAnimalsFragmentListAnimals", + "ListAnimalsListAnimalsAnimal", + "ListAnimalsListAnimalsCat", + "ListAnimalsListAnimalsDog", + "ListStrings1", + "ListStrings2", + "ListStrings3", + "ListStrings4", + "ListTypeA", + "ListTypeAListOptionalTypeA", + "SubscribeStrings", + "UnwrapFragment", + "Upload", +] diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/async_base_client.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/async_base_client.py new file mode 100644 index 00000000..5358ced6 --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/async_base_client.py @@ -0,0 +1,370 @@ +import enum +import json +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from uuid import uuid4 + +import httpx +from pydantic import BaseModel +from pydantic_core import to_jsonable_python + +from .base_model import UNSET, Upload +from .exceptions import ( + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidMessageFormat, + GraphQLClientInvalidResponseError, +) + +try: + from websockets.client import ( # type: ignore[import-not-found,unused-ignore] + WebSocketClientProtocol, + connect as ws_connect, + ) + from websockets.typing import ( # type: ignore[import-not-found,unused-ignore] + Data, + Origin, + Subprotocol, + ) +except ImportError: + from contextlib import asynccontextmanager + + @asynccontextmanager # type: ignore + async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument + raise NotImplementedError("Subscriptions require 'websockets' package.") + yield # pylint: disable=unreachable + + WebSocketClientProtocol = Any # type: ignore[misc,assignment,unused-ignore] + Data = Any # type: ignore[misc,assignment,unused-ignore] + Origin = Any # type: ignore[misc,assignment,unused-ignore] + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") + + +Self = TypeVar("Self", bound="AsyncBaseClient") + +GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" + + +class GraphQLTransportWSMessageType(str, enum.Enum): + CONNECTION_INIT = "connection_init" + CONNECTION_ACK = "connection_ack" + PING = "ping" + PONG = "pong" + SUBSCRIBE = "subscribe" + NEXT = "next" + ERROR = "error" + COMPLETE = "complete" + + +class AsyncBaseClient: + def __init__( + self, + url: str = "", + headers: Optional[Dict[str, str]] = None, + http_client: Optional[httpx.AsyncClient] = None, + ws_url: str = "", + ws_headers: Optional[Dict[str, Any]] = None, + ws_origin: Optional[str] = None, + ws_connection_init_payload: Optional[Dict[str, Any]] = None, + ) -> None: + self.url = url + self.headers = headers + self.http_client = ( + http_client if http_client else httpx.AsyncClient(headers=headers) + ) + + self.ws_url = ws_url + self.ws_headers = ws_headers or {} + self.ws_origin = Origin(ws_origin) if ws_origin else None + self.ws_connection_init_payload = ws_connection_init_payload + + async def __aenter__(self: Self) -> Self: + return self + + async def __aexit__( + self, + exc_type: object, + exc_val: object, + exc_tb: object, + ) -> None: + await self.http_client.aclose() + + async def execute( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + operation_name=operation_name, + variables=processed_variables, + files=files, + files_map=files_map, + **kwargs, + ) + + return await self._execute_json( + query=query, + operation_name=operation_name, + variables=processed_variables, + **kwargs, + ) + + def get_data(self, response: httpx.Response) -> Dict[str, Any]: + if not response.is_success: + raise GraphQLClientHttpError( + status_code=response.status_code, response=response + ) + + try: + response_json = response.json() + except ValueError as exc: + raise GraphQLClientInvalidResponseError(response=response) from exc + + if (not isinstance(response_json, dict)) or ( + "data" not in response_json and "errors" not in response_json + ): + raise GraphQLClientInvalidResponseError(response=response) + + data = response_json.get("data") + errors = response_json.get("errors") + + if errors: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=errors, data=data + ) + + return cast(Dict[str, Any], data) + + async def execute_ws( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + **merged_kwargs, + ) as websocket: + await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + operation_name=operation_name, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + + def _process_variables( + self, variables: Optional[Dict[str, Any]] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + if not variables: + return {}, {}, {} + + serializable_variables = self._convert_dict_to_json_serializable(variables) + return self._get_files_from_variables(serializable_variables) + + def _convert_dict_to_json_serializable( + self, dict_: Dict[str, Any] + ) -> Dict[str, Any]: + return { + key: self._convert_value(value) + for key, value in dict_.items() + if value is not UNSET + } + + def _convert_value(self, value: Any) -> Any: + if isinstance(value, BaseModel): + return value.model_dump(by_alias=True, exclude_unset=True) + if isinstance(value, list): + return [self._convert_value(item) for item in value] + return value + + def _get_files_from_variables( + self, variables: Dict[str, Any] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + files_map: Dict[str, List[str]] = {} + files_list: List[Upload] = [] + + def separate_files(path: str, obj: Any) -> Any: + if isinstance(obj, list): + nulled_list = [] + for index, value in enumerate(obj): + value = separate_files(f"{path}.{index}", value) + nulled_list.append(value) + return nulled_list + + if isinstance(obj, dict): + nulled_dict = {} + for key, value in obj.items(): + value = separate_files(f"{path}.{key}", value) + nulled_dict[key] = value + return nulled_dict + + if isinstance(obj, Upload): + if obj in files_list: + file_index = files_list.index(obj) + files_map[str(file_index)].append(path) + else: + file_index = len(files_list) + files_list.append(obj) + files_map[str(file_index)] = [path] + return None + + return obj + + nulled_variables = separate_files("variables", variables) + files: Dict[str, Tuple[str, IO[bytes], str]] = { + str(i): (file_.filename, cast(IO[bytes], file_.content), file_.content_type) + for i, file_ in enumerate(files_list) + } + return nulled_variables, files, files_map + + async def _execute_multipart( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + **kwargs: Any, + ) -> httpx.Response: + data = { + "operations": json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + "map": json.dumps(files_map, default=to_jsonable_python), + } + + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) + + async def _execute_json( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + **kwargs: Any, + ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + + return await self.http_client.post( + url=self.url, + content=json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + **merged_kwargs, + ) + + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: + payload: Dict[str, Any] = { + "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value + } + if self.ws_connection_init_payload: + payload["payload"] = self.ws_connection_init_payload + await websocket.send(json.dumps(payload)) + + async def _send_subscribe( + self, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + payload: Dict[str, Any] = { + "id": operation_id, + "type": GraphQLTransportWSMessageType.SUBSCRIBE.value, + "payload": {"query": query, "operationName": operation_name}, + } + if variables: + payload["payload"]["variables"] = self._convert_dict_to_json_serializable( + variables + ) + await websocket.send(json.dumps(payload)) + + async def _handle_ws_message( + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, + ) -> Optional[Dict[str, Any]]: + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: + raise GraphQLClientInvalidMessageFormat(message=message) + + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/base_model.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/base_model.py new file mode 100644 index 00000000..ccde3975 --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/base_model.py @@ -0,0 +1,27 @@ +from io import IOBase + +from pydantic import BaseModel as PydanticBaseModel, ConfigDict + + +class UnsetType: + def __bool__(self) -> bool: + return False + + +UNSET = UnsetType() + + +class BaseModel(PydanticBaseModel): + model_config = ConfigDict( + populate_by_name=True, + validate_assignment=True, + arbitrary_types_allowed=True, + protected_namespaces=(), + ) + + +class Upload: + def __init__(self, filename: str, content: IOBase, content_type: str): + self.filename = filename + self.content = content + self.content_type = content_type diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/client.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/client.py new file mode 100644 index 00000000..3b388490 --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/client.py @@ -0,0 +1,314 @@ +from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Union + +from .async_base_client import AsyncBaseClient + +if TYPE_CHECKING: + from .client_forward_refs_shorter_resultsfragments import ( + FragmentWithSingleFieldQueryUnwrapFragment, + ) + from .custom_scalars import ComplexScalar, SimpleScalar + from .get_animal_by_name import ( + GetAnimalByNameAnimalByNameAnimal, + GetAnimalByNameAnimalByNameCat, + GetAnimalByNameAnimalByNameDog, + ) + from .get_animal_fragment_with_extra import GetAnimalFragmentWithExtra + from .get_authenticated_user import GetAuthenticatedUserMe + from .list_animals import ( + ListAnimalsListAnimalsAnimal, + ListAnimalsListAnimalsCat, + ListAnimalsListAnimalsDog, + ) + from .list_type_a import ListTypeAListOptionalTypeA + + +def gql(q: str) -> str: + return q + + +class Client(AsyncBaseClient): + async def get_authenticated_user(self, **kwargs: Any) -> "GetAuthenticatedUserMe": + from .get_authenticated_user import GetAuthenticatedUser + + query = gql( + """ + query GetAuthenticatedUser { + me { + id + username + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, + operation_name="GetAuthenticatedUser", + variables=variables, + **kwargs + ) + data = self.get_data(response) + return GetAuthenticatedUser.model_validate(data).me + + async def list_strings_1(self, **kwargs: Any) -> Optional[List[Optional[str]]]: + from .list_strings_1 import ListStrings1 + + query = gql( + """ + query ListStrings_1 { + optionalListOptionalString + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListStrings_1", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListStrings1.model_validate(data).optional_list_optional_string + + async def list_strings_2(self, **kwargs: Any) -> Optional[List[str]]: + from .list_strings_2 import ListStrings2 + + query = gql( + """ + query ListStrings_2 { + optionalListString + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListStrings_2", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListStrings2.model_validate(data).optional_list_string + + async def list_strings_3(self, **kwargs: Any) -> List[Optional[str]]: + from .list_strings_3 import ListStrings3 + + query = gql( + """ + query ListStrings_3 { + listOptionalString + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListStrings_3", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListStrings3.model_validate(data).list_optional_string + + async def list_strings_4(self, **kwargs: Any) -> List[str]: + from .list_strings_4 import ListStrings4 + + query = gql( + """ + query ListStrings_4 { + listString + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListStrings_4", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListStrings4.model_validate(data).list_string + + async def list_type_a( + self, **kwargs: Any + ) -> List[Optional["ListTypeAListOptionalTypeA"]]: + from .list_type_a import ListTypeA + + query = gql( + """ + query ListTypeA { + listOptionalTypeA { + id + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListTypeA", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListTypeA.model_validate(data).list_optional_type_a + + async def get_animal_by_name(self, name: str, **kwargs: Any) -> Union[ + "GetAnimalByNameAnimalByNameAnimal", + "GetAnimalByNameAnimalByNameCat", + "GetAnimalByNameAnimalByNameDog", + ]: + from .get_animal_by_name import GetAnimalByName + + query = gql( + """ + query GetAnimalByName($name: String!) { + animalByName(name: $name) { + __typename + name + ... on Cat { + kittens + } + ... on Dog { + puppies + } + } + } + """ + ) + variables: Dict[str, object] = {"name": name} + response = await self.execute( + query=query, operation_name="GetAnimalByName", variables=variables, **kwargs + ) + data = self.get_data(response) + return GetAnimalByName.model_validate(data).animal_by_name + + async def list_animals(self, **kwargs: Any) -> List[ + Union[ + "ListAnimalsListAnimalsAnimal", + "ListAnimalsListAnimalsCat", + "ListAnimalsListAnimalsDog", + ] + ]: + from .list_animals import ListAnimals + + query = gql( + """ + query ListAnimals { + listAnimals { + __typename + name + ... on Cat { + kittens + } + ... on Dog { + puppies + } + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListAnimals", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListAnimals.model_validate(data).list_animals + + async def get_animal_fragment_with_extra( + self, **kwargs: Any + ) -> "GetAnimalFragmentWithExtra": + from .get_animal_fragment_with_extra import GetAnimalFragmentWithExtra + + query = gql( + """ + query GetAnimalFragmentWithExtra { + ...ListAnimalsFragment + listString + } + + fragment ListAnimalsFragment on Query { + listAnimals { + name + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, + operation_name="GetAnimalFragmentWithExtra", + variables=variables, + **kwargs + ) + data = self.get_data(response) + return GetAnimalFragmentWithExtra.model_validate(data) + + async def get_simple_scalar(self, **kwargs: Any) -> "SimpleScalar": + from .get_simple_scalar import GetSimpleScalar + + query = gql( + """ + query GetSimpleScalar { + justSimpleScalar + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="GetSimpleScalar", variables=variables, **kwargs + ) + data = self.get_data(response) + return GetSimpleScalar.model_validate(data).just_simple_scalar + + async def get_complex_scalar(self, **kwargs: Any) -> "ComplexScalar": + from .get_complex_scalar import GetComplexScalar + + query = gql( + """ + query GetComplexScalar { + justComplexScalar + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, + operation_name="GetComplexScalar", + variables=variables, + **kwargs + ) + data = self.get_data(response) + return GetComplexScalar.model_validate(data).just_complex_scalar + + async def subscribe_strings( + self, **kwargs: Any + ) -> AsyncIterator[Optional[List[str]]]: + from .subscribe_strings import SubscribeStrings + + query = gql( + """ + subscription SubscribeStrings { + optionalListString + } + """ + ) + variables: Dict[str, object] = {} + async for data in self.execute_ws( + query=query, + operation_name="SubscribeStrings", + variables=variables, + **kwargs + ): + yield SubscribeStrings.model_validate(data).optional_list_string + + async def unwrap_fragment( + self, **kwargs: Any + ) -> "FragmentWithSingleFieldQueryUnwrapFragment": + from .unwrap_fragment import UnwrapFragment + + query = gql( + """ + query UnwrapFragment { + ...FragmentWithSingleField + } + + fragment FragmentWithSingleField on Query { + queryUnwrapFragment { + id + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="UnwrapFragment", variables=variables, **kwargs + ) + data = self.get_data(response) + return UnwrapFragment.model_validate(data).query_unwrap_fragment diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/client_forward_refs_shorter_resultsfragments.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/client_forward_refs_shorter_resultsfragments.py new file mode 100644 index 00000000..4b103a44 --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/client_forward_refs_shorter_resultsfragments.py @@ -0,0 +1,28 @@ +from typing import List, Literal + +from pydantic import Field + +from .base_model import BaseModel + + +class FragmentWithSingleField(BaseModel): + query_unwrap_fragment: "FragmentWithSingleFieldQueryUnwrapFragment" = Field( + alias="queryUnwrapFragment" + ) + + +class FragmentWithSingleFieldQueryUnwrapFragment(BaseModel): + id: int + + +class ListAnimalsFragment(BaseModel): + list_animals: List["ListAnimalsFragmentListAnimals"] = Field(alias="listAnimals") + + +class ListAnimalsFragmentListAnimals(BaseModel): + typename__: Literal["Animal", "Cat", "Dog"] = Field(alias="__typename") + name: str + + +FragmentWithSingleField.model_rebuild() +ListAnimalsFragment.model_rebuild() diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/custom_scalars.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/custom_scalars.py new file mode 100644 index 00000000..5feb063a --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/custom_scalars.py @@ -0,0 +1,14 @@ +SimpleScalar = str + + +class ComplexScalar: + def __init__(self, value: str) -> None: + self.value = value + + +def parse_complex_scalar(value: str) -> ComplexScalar: + return ComplexScalar(value) + + +def serialize_complex_scalar(value: ComplexScalar) -> str: + return value.value diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/enums.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/enums.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/exceptions.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/exceptions.py new file mode 100644 index 00000000..b34acfe1 --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/exceptions.py @@ -0,0 +1,83 @@ +from typing import Any, Dict, List, Optional, Union + +import httpx + + +class GraphQLClientError(Exception): + """Base exception.""" + + +class GraphQLClientHttpError(GraphQLClientError): + def __init__(self, status_code: int, response: httpx.Response) -> None: + self.status_code = status_code + self.response = response + + def __str__(self) -> str: + return f"HTTP status code: {self.status_code}" + + +class GraphQLClientInvalidResponseError(GraphQLClientError): + def __init__(self, response: httpx.Response) -> None: + self.response = response + + def __str__(self) -> str: + return "Invalid response format." + + +class GraphQLClientGraphQLError(GraphQLClientError): + def __init__( + self, + message: str, + locations: Optional[List[Dict[str, int]]] = None, + path: Optional[List[str]] = None, + extensions: Optional[Dict[str, object]] = None, + orginal: Optional[Dict[str, object]] = None, + ): + self.message = message + self.locations = locations + self.path = path + self.extensions = extensions + self.orginal = orginal + + def __str__(self) -> str: + return self.message + + @classmethod + def from_dict(cls, error: Dict[str, Any]) -> "GraphQLClientGraphQLError": + return cls( + message=error["message"], + locations=error.get("locations"), + path=error.get("path"), + extensions=error.get("extensions"), + orginal=error, + ) + + +class GraphQLClientGraphQLMultiError(GraphQLClientError): + def __init__( + self, + errors: List[GraphQLClientGraphQLError], + data: Optional[Dict[str, Any]] = None, + ): + self.errors = errors + self.data = data + + def __str__(self) -> str: + return "; ".join(str(e) for e in self.errors) + + @classmethod + def from_errors_dicts( + cls, errors_dicts: List[Dict[str, Any]], data: Optional[Dict[str, Any]] = None + ) -> "GraphQLClientGraphQLMultiError": + return cls( + errors=[GraphQLClientGraphQLError.from_dict(e) for e in errors_dicts], + data=data, + ) + + +class GraphQLClientInvalidMessageFormat(GraphQLClientError): + def __init__(self, message: Union[str, bytes]) -> None: + self.message = message + + def __str__(self) -> str: + return "Invalid message format." diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_animal_by_name.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_animal_by_name.py new file mode 100644 index 00000000..e97e12ac --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_animal_by_name.py @@ -0,0 +1,33 @@ +from typing import Literal, Union + +from pydantic import Field + +from .base_model import BaseModel + + +class GetAnimalByName(BaseModel): + animal_by_name: Union[ + "GetAnimalByNameAnimalByNameAnimal", + "GetAnimalByNameAnimalByNameCat", + "GetAnimalByNameAnimalByNameDog", + ] = Field(alias="animalByName", discriminator="typename__") + + +class GetAnimalByNameAnimalByNameAnimal(BaseModel): + typename__: Literal["Animal"] = Field(alias="__typename") + name: str + + +class GetAnimalByNameAnimalByNameCat(BaseModel): + typename__: Literal["Cat"] = Field(alias="__typename") + name: str + kittens: int + + +class GetAnimalByNameAnimalByNameDog(BaseModel): + typename__: Literal["Dog"] = Field(alias="__typename") + name: str + puppies: int + + +GetAnimalByName.model_rebuild() diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_animal_fragment_with_extra.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_animal_fragment_with_extra.py new file mode 100644 index 00000000..2c56da12 --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_animal_fragment_with_extra.py @@ -0,0 +1,9 @@ +from typing import List + +from pydantic import Field + +from .client_forward_refs_shorter_resultsfragments import ListAnimalsFragment + + +class GetAnimalFragmentWithExtra(ListAnimalsFragment): + list_string: List[str] = Field(alias="listString") diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_authenticated_user.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_authenticated_user.py new file mode 100644 index 00000000..ce0e2cf7 --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_authenticated_user.py @@ -0,0 +1,13 @@ +from .base_model import BaseModel + + +class GetAuthenticatedUser(BaseModel): + me: "GetAuthenticatedUserMe" + + +class GetAuthenticatedUserMe(BaseModel): + id: str + username: str + + +GetAuthenticatedUser.model_rebuild() diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_complex_scalar.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_complex_scalar.py new file mode 100644 index 00000000..78563772 --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_complex_scalar.py @@ -0,0 +1,12 @@ +from typing import Annotated + +from pydantic import BeforeValidator, Field + +from .base_model import BaseModel +from .custom_scalars import ComplexScalar, parse_complex_scalar + + +class GetComplexScalar(BaseModel): + just_complex_scalar: Annotated[ + ComplexScalar, BeforeValidator(parse_complex_scalar) + ] = Field(alias="justComplexScalar") diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_simple_scalar.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_simple_scalar.py new file mode 100644 index 00000000..365f7eaa --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_simple_scalar.py @@ -0,0 +1,8 @@ +from pydantic import Field + +from .base_model import BaseModel +from .custom_scalars import SimpleScalar + + +class GetSimpleScalar(BaseModel): + just_simple_scalar: SimpleScalar = Field(alias="justSimpleScalar") diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/input_types.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/input_types.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_animals.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_animals.py new file mode 100644 index 00000000..bcab1c45 --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_animals.py @@ -0,0 +1,38 @@ +from typing import Annotated, List, Literal, Union + +from pydantic import Field + +from .base_model import BaseModel + + +class ListAnimals(BaseModel): + list_animals: List[ + Annotated[ + Union[ + "ListAnimalsListAnimalsAnimal", + "ListAnimalsListAnimalsCat", + "ListAnimalsListAnimalsDog", + ], + Field(discriminator="typename__"), + ] + ] = Field(alias="listAnimals") + + +class ListAnimalsListAnimalsAnimal(BaseModel): + typename__: Literal["Animal"] = Field(alias="__typename") + name: str + + +class ListAnimalsListAnimalsCat(BaseModel): + typename__: Literal["Cat"] = Field(alias="__typename") + name: str + kittens: int + + +class ListAnimalsListAnimalsDog(BaseModel): + typename__: Literal["Dog"] = Field(alias="__typename") + name: str + puppies: int + + +ListAnimals.model_rebuild() diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_1.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_1.py new file mode 100644 index 00000000..fd8c06de --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_1.py @@ -0,0 +1,11 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class ListStrings1(BaseModel): + optional_list_optional_string: Optional[List[Optional[str]]] = Field( + alias="optionalListOptionalString" + ) diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_2.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_2.py new file mode 100644 index 00000000..d91ec117 --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_2.py @@ -0,0 +1,9 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class ListStrings2(BaseModel): + optional_list_string: Optional[List[str]] = Field(alias="optionalListString") diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_3.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_3.py new file mode 100644 index 00000000..88f6e2cf --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_3.py @@ -0,0 +1,9 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class ListStrings3(BaseModel): + list_optional_string: List[Optional[str]] = Field(alias="listOptionalString") diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_4.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_4.py new file mode 100644 index 00000000..15872b23 --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_4.py @@ -0,0 +1,9 @@ +from typing import List + +from pydantic import Field + +from .base_model import BaseModel + + +class ListStrings4(BaseModel): + list_string: List[str] = Field(alias="listString") diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_type_a.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_type_a.py new file mode 100644 index 00000000..9e2c8f04 --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_type_a.py @@ -0,0 +1,18 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class ListTypeA(BaseModel): + list_optional_type_a: List[Optional["ListTypeAListOptionalTypeA"]] = Field( + alias="listOptionalTypeA" + ) + + +class ListTypeAListOptionalTypeA(BaseModel): + id: int + + +ListTypeA.model_rebuild() diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/subscribe_strings.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/subscribe_strings.py new file mode 100644 index 00000000..a1a42d4f --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/subscribe_strings.py @@ -0,0 +1,9 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class SubscribeStrings(BaseModel): + optional_list_string: Optional[List[str]] = Field(alias="optionalListString") diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/unwrap_fragment.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/unwrap_fragment.py new file mode 100644 index 00000000..e3512512 --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/unwrap_fragment.py @@ -0,0 +1,5 @@ +from .client_forward_refs_shorter_resultsfragments import FragmentWithSingleField + + +class UnwrapFragment(FragmentWithSingleField): + pass diff --git a/tests/main/clients/client_forward_refs_shorter_results/pyproject.toml b/tests/main/clients/client_forward_refs_shorter_results/pyproject.toml new file mode 100644 index 00000000..1bc91372 --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/pyproject.toml @@ -0,0 +1,19 @@ +[tool.ariadne-codegen] +schema_path = "schema.graphql" +queries_path = "queries.graphql" +include_comments = "none" +target_package_name = "client_forward_refs_shorter_results" +files_to_include = ["custom_scalars.py"] +fragments_module_name = "client_forward_refs_shorter_resultsfragments" +plugins = [ + "ariadne_codegen.contrib.shorter_results.ShorterResultsPlugin", + "ariadne_codegen.contrib.client_forward_refs.ClientForwardRefsPlugin" +] + +[tool.ariadne-codegen.scalars.SimpleScalar] +type = ".custom_scalars.SimpleScalar" + +[tool.ariadne-codegen.scalars.ComplexScalar] +type = ".custom_scalars.ComplexScalar" +parse = ".custom_scalars.parse_complex_scalar" +serialize = ".custom_scalars.serialize_complex_scalar" diff --git a/tests/main/clients/client_forward_refs_shorter_results/queries.graphql b/tests/main/clients/client_forward_refs_shorter_results/queries.graphql new file mode 100644 index 00000000..d407732c --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/queries.graphql @@ -0,0 +1,86 @@ +query GetAuthenticatedUser { + me { + id + username + } +} + +query ListStrings_1 { + optionalListOptionalString +} + +query ListStrings_2 { + optionalListString +} + +query ListStrings_3 { + listOptionalString +} + +query ListStrings_4 { + listString +} + +query ListTypeA { + listOptionalTypeA { + id + } +} + +query GetAnimalByName($name: String!) { + animalByName(name: $name) { + name + ... on Cat { + kittens + } + ... on Dog { + puppies + } + } +} + +query ListAnimals { + listAnimals { + name + ... on Cat { + kittens + } + ... on Dog { + puppies + } + } +} + +query GetAnimalFragmentWithExtra { + ...ListAnimalsFragment + listString +} + +query GetSimpleScalar { + justSimpleScalar +} + +query GetComplexScalar { + justComplexScalar +} + +fragment ListAnimalsFragment on Query { + listAnimals { + name + } +} + +subscription SubscribeStrings { + optionalListString +} + + +fragment FragmentWithSingleField on Query { + queryUnwrapFragment { + id + } +} + +query UnwrapFragment { + ...FragmentWithSingleField +} diff --git a/tests/main/clients/client_forward_refs_shorter_results/schema.graphql b/tests/main/clients/client_forward_refs_shorter_results/schema.graphql new file mode 100644 index 00000000..7acc7f92 --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/schema.graphql @@ -0,0 +1,48 @@ +scalar SimpleScalar +scalar ComplexScalar + +type TypeA { + id: Int! +} + +type User { + id: ID! + username: String! +} + +interface Animal { + name: String! +} + +type Cat implements Animal { + name: String! + kittens: Int! +} + +type Dog implements Animal { + name: String! + puppies: Int! +} + +type Query { + optionalListOptionalString: [String] + optionalListString: [String!] + listOptionalString: [String]! + listString: [String!]! + + listOptionalTypeA: [TypeA]! + + me: User! + + listAnimals: [Animal!]! + animalByName(name: String!): Animal! + + justSimpleScalar: SimpleScalar! + justComplexScalar: ComplexScalar! + + queryUnwrapFragment: TypeA! +} + +type Subscription { + optionalListString: [String!] +}