From 794743bd0ad0ecaab53f0f8f6ab5f81fb59d69e9 Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Wed, 13 Mar 2024 19:00:26 +0100 Subject: [PATCH 01/11] Add plugin to reduce client imports * Convert input types and return types to `ast.Constant` * Import `TYPE_CHECKING` and import types if set * Import required types inside each method * Add tests * Update CHANGELOG * Update README --- CHANGELOG.md | 1 + README.md | 2 + ariadne_codegen/contrib/no_global_imports.py | 359 +++++++++++++++++ .../no_global_imports/custom_scalars.py | 14 + .../expected_client/__init__.py | 76 ++++ .../expected_client/async_base_client.py | 370 ++++++++++++++++++ .../expected_client/base_model.py | 27 ++ .../expected_client/client.py | 296 ++++++++++++++ .../expected_client/custom_scalars.py | 14 + .../expected_client/enums.py | 0 .../expected_client/exceptions.py | 83 ++++ .../expected_client/get_animal_by_name.py | 33 ++ .../get_animal_fragment_with_extra.py | 9 + .../expected_client/get_authenticated_user.py | 13 + .../expected_client/get_complex_scalar.py | 12 + .../expected_client/get_simple_scalar.py | 8 + .../expected_client/input_types.py | 0 .../expected_client/list_animals.py | 38 ++ .../expected_client/list_strings_1.py | 11 + .../expected_client/list_strings_2.py | 9 + .../expected_client/list_strings_3.py | 9 + .../expected_client/list_strings_4.py | 9 + .../expected_client/list_type_a.py | 18 + .../no_global_imports_fragments.py | 28 ++ .../expected_client/subscribe_strings.py | 9 + .../expected_client/unwrap_fragment.py | 5 + .../clients/no_global_imports/pyproject.toml | 16 + .../clients/no_global_imports/queries.graphql | 86 ++++ .../clients/no_global_imports/schema.graphql | 48 +++ .../custom_scalars.py | 14 + .../expected_client/__init__.py | 76 ++++ .../expected_client/async_base_client.py | 370 ++++++++++++++++++ .../expected_client/base_model.py | 27 ++ .../expected_client/client.py | 314 +++++++++++++++ .../expected_client/custom_scalars.py | 14 + .../expected_client/enums.py | 0 .../expected_client/exceptions.py | 83 ++++ .../expected_client/get_animal_by_name.py | 33 ++ .../get_animal_fragment_with_extra.py | 9 + .../expected_client/get_authenticated_user.py | 13 + .../expected_client/get_complex_scalar.py | 12 + .../expected_client/get_simple_scalar.py | 8 + .../expected_client/input_types.py | 0 .../expected_client/list_animals.py | 38 ++ .../expected_client/list_strings_1.py | 11 + .../expected_client/list_strings_2.py | 9 + .../expected_client/list_strings_3.py | 9 + .../expected_client/list_strings_4.py | 9 + .../expected_client/list_type_a.py | 18 + ...global_imports_shorter_resultsfragments.py | 28 ++ .../expected_client/subscribe_strings.py | 9 + .../expected_client/unwrap_fragment.py | 5 + .../pyproject.toml | 19 + .../queries.graphql | 86 ++++ .../schema.graphql | 48 +++ 55 files changed, 2865 insertions(+) create mode 100644 ariadne_codegen/contrib/no_global_imports.py create mode 100644 tests/main/clients/no_global_imports/custom_scalars.py create mode 100644 tests/main/clients/no_global_imports/expected_client/__init__.py create mode 100644 tests/main/clients/no_global_imports/expected_client/async_base_client.py create mode 100644 tests/main/clients/no_global_imports/expected_client/base_model.py create mode 100644 tests/main/clients/no_global_imports/expected_client/client.py create mode 100644 tests/main/clients/no_global_imports/expected_client/custom_scalars.py create mode 100644 tests/main/clients/no_global_imports/expected_client/enums.py create mode 100644 tests/main/clients/no_global_imports/expected_client/exceptions.py create mode 100644 tests/main/clients/no_global_imports/expected_client/get_animal_by_name.py create mode 100644 tests/main/clients/no_global_imports/expected_client/get_animal_fragment_with_extra.py create mode 100644 tests/main/clients/no_global_imports/expected_client/get_authenticated_user.py create mode 100644 tests/main/clients/no_global_imports/expected_client/get_complex_scalar.py create mode 100644 tests/main/clients/no_global_imports/expected_client/get_simple_scalar.py create mode 100644 tests/main/clients/no_global_imports/expected_client/input_types.py create mode 100644 tests/main/clients/no_global_imports/expected_client/list_animals.py create mode 100644 tests/main/clients/no_global_imports/expected_client/list_strings_1.py create mode 100644 tests/main/clients/no_global_imports/expected_client/list_strings_2.py create mode 100644 tests/main/clients/no_global_imports/expected_client/list_strings_3.py create mode 100644 tests/main/clients/no_global_imports/expected_client/list_strings_4.py create mode 100644 tests/main/clients/no_global_imports/expected_client/list_type_a.py create mode 100644 tests/main/clients/no_global_imports/expected_client/no_global_imports_fragments.py create mode 100644 tests/main/clients/no_global_imports/expected_client/subscribe_strings.py create mode 100644 tests/main/clients/no_global_imports/expected_client/unwrap_fragment.py create mode 100644 tests/main/clients/no_global_imports/pyproject.toml create mode 100644 tests/main/clients/no_global_imports/queries.graphql create mode 100644 tests/main/clients/no_global_imports/schema.graphql create mode 100644 tests/main/clients/no_global_imports_shorter_results/custom_scalars.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/__init__.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/async_base_client.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/base_model.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/client.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/custom_scalars.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/enums.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/exceptions.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/get_animal_by_name.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/get_animal_fragment_with_extra.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/get_authenticated_user.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/get_complex_scalar.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/get_simple_scalar.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/input_types.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/list_animals.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_1.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_2.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_3.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_4.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/list_type_a.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/no_global_imports_shorter_resultsfragments.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/subscribe_strings.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/unwrap_fragment.py create mode 100644 tests/main/clients/no_global_imports_shorter_results/pyproject.toml create mode 100644 tests/main/clients/no_global_imports_shorter_results/queries.graphql create mode 100644 tests/main/clients/no_global_imports_shorter_results/schema.graphql diff --git a/CHANGELOG.md b/CHANGELOG.md index ee50112b..3b20daa1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## 0.14.0 (Unreleased) +- Added `NoGlobalImportsPlugin` to standard plugins. - Re-added `model_rebuild` calls for input types with forward references. diff --git a/README.md b/README.md index 999748c3..2c570295 100644 --- a/README.md +++ b/README.md @@ -96,6 +96,8 @@ Ariadne Codegen ships with optional plugins importable from the `ariadne_codegen - [`ariadne_codegen.contrib.extract_operations.ExtractOperationsPlugin`](ariadne_codegen/contrib/extract_operations.py) - This extracts query strings from generated client's methods into separate `operations.py` module. It also modifies the generated client to import these definitions. Generated module name can be customized by adding `operations_module_name="custom_name"` to the `[tool.ariadne-codegen.operations]` section in config. Eg.: +- [`ariadne_codegen.contrib.no_global_imports.NoGlobalImportsPlugin`](ariadne_codegen/contrib/no_global_imports.py) - This plugin processes generated client module and convert all input arguments and return types to strings. The types will be imported only for type checking. + ```toml [tool.ariadne-codegen] ... diff --git a/ariadne_codegen/contrib/no_global_imports.py b/ariadne_codegen/contrib/no_global_imports.py new file mode 100644 index 00000000..eb69e40e --- /dev/null +++ b/ariadne_codegen/contrib/no_global_imports.py @@ -0,0 +1,359 @@ +""" +Plugin to only import types when you call methods + +This will massively reduce import times for larger projects since you only have +to load the input types when loading the client. + +All result types that's used to process the server response will only be +imported when the method is called. +""" + +import ast +from typing import Dict + +from graphql import GraphQLSchema + +from ariadne_codegen import Plugin + + +class NoGlobalImportsPlugin(Plugin): + """Only import types when you call an endpoint needing it""" + + def __init__(self, schema: GraphQLSchema, config_dict: Dict) -> None: + """Constructor""" + # Types that should only be imported in a `TYPE_CHECKING` context. This + # is all the types used as arguments to a method or as a return type, + # i.e. for type checking. + self.input_and_return_types: set[str] = set() + + # Imported classes are classes imported from local imports. We keep a + # map between name and module so we know how to import them in each + # method. + self.imported_classes: dict[str, str] = {} + + # Imported classes in each method definition. + self.imported_in_method: set[str] = set() + + super().__init__(schema, config_dict) + + def generate_client_module(self, module: ast.Module) -> ast.Module: + """ + Update the generated client. + + This will parse all current imports to map them to a path. It will then + traverse all methods and look for the actual return type. The return + node will be converted to an `ast.Constant` if it's an `ast.Name` and + the return type will be imported only under `if TYPE_CHECKING` + conditions. + + It will also move all imports of the types used to parse the response + inside each method since that's the only place where they're used. The + result will be that we end up with imports in the global scope only for + types used as input types. + + :param module: The ast for the module + """ + self._store_imported_classes(module.body) + + # Find the actual client class so we can grab all input and output + # types. We also ensure to manipulate the ast while we do this. + client_class_def = next( + filter(lambda o: isinstance(o, ast.ClassDef), module.body), None + ) + if not client_class_def or not isinstance(client_class_def, ast.ClassDef): + return super().generate_client_module(module) + + for method_def in [ + m + for m in client_class_def.body + if isinstance(m, (ast.FunctionDef, ast.AsyncFunctionDef)) + ]: + method_def = self._rewrite_input_args_to_constants(method_def) + + # If the method returns anything, update whatever it returns. + if method_def.returns: + method_def.returns = self._update_name_to_constant(method_def.returns) + + self._insert_import_statement_in_method(method_def) + + self._update_imports(module) + + return super().generate_client_module(module) + + def _store_imported_classes(self, module_body: list[ast.stmt]): + """Fetch and store imported classes. + + Grab all imported classes with level 1 or starting with `.` because + these are the ones generated by us. We store a map between the class and + which module it was imported from so we can easily import it when + needed. This can be in a `TYPE_CHECKING` condition or inside a method. + + :param module_body: The body of an `ast.Module` + """ + for node in module_body: + if not isinstance(node, ast.ImportFrom): + continue + + if node.module is None: + continue + + # We only care about local imports from our generated code. + if node.level != 1 and not node.module.startswith("."): + continue + + for name in node.names: + from_ = "." * node.level + node.module + if isinstance(name, ast.alias): + self.imported_classes[name.name] = from_ + + def _rewrite_input_args_to_constants( + self, method_def: ast.FunctionDef | ast.AsyncFunctionDef + ) -> ast.FunctionDef | ast.AsyncFunctionDef: + """Rewrite the arguments to a method. + + For any `ast.Name` that requires an import convert it to an + `ast.Constant` instead. The actual class will be noted and imported + in a `TYPE_CHECKING` context. + + :param method_def: Method definition + :returns: The same definition but updated + """ + if not isinstance(method_def, (ast.FunctionDef, ast.AsyncFunctionDef)): + return method_def + + for i, input_arg in enumerate(method_def.args.args): + annotation = input_arg.annotation + if isinstance(annotation, (ast.Name, ast.Subscript, ast.Tuple)): + method_def.args.args[i].annotation = self._update_name_to_constant( + annotation + ) + + return method_def + + def _insert_import_statement_in_method( + self, method_def: ast.FunctionDef | ast.AsyncFunctionDef + ): + """Insert import statement in method. + + Each method will eventually pass the returned value to a class we've + generated. Since we only need it in the scope of the method ensure we + add it at the top of the method only. It will be removed from the global + scope. + + :param method_def: The method definition to updated + """ + # Find the last statement in the body, the call to this class is + # what we need to import first. + return_stmt = method_def.body[-1] + if isinstance(return_stmt, ast.Return): + call = self._get_call_arg_from_return(return_stmt) + elif isinstance(return_stmt, ast.AsyncFor): + call = self._get_call_arg_from_async_for(return_stmt) + else: + return + + if call is None: + return + + import_class = self._get_class_from_call(call) + if import_class is None: + return + + import_class_id = import_class.id + + # We add the class to our set of imported in methods - these classes + # don't need to be imported at all in the global scope. + self.imported_in_method.add(import_class.id) + method_def.body.insert( + 0, + ast.ImportFrom( + module=self.imported_classes[import_class_id], + names=[import_class], + ), + ) + + def _get_call_arg_from_return(self, return_stmt: ast.Return) -> ast.Call | None: + """Get the class used in the return statement. + + :param return_stmt: The statement used for return + """ + # If it's a call of the class like produced by + # `ShorterResultsPlugin` we have an attribute. + if isinstance(return_stmt.value, ast.Attribute) and isinstance( + return_stmt.value.value, ast.Call + ): + return return_stmt.value.value + + # If not it's just a call statement to the generated class. + if isinstance(return_stmt.value, ast.Call): + return return_stmt.value + + return None + + def _get_call_arg_from_async_for(self, last_stmt: ast.AsyncFor) -> ast.Call | None: + """Get the class used in the yield expression. + + :param last_stmt: The statement used in `ast.AsyncFor` + """ + if isinstance(last_stmt.body, list) and isinstance(last_stmt.body[0], ast.Expr): + body = last_stmt.body[0] + elif isinstance(last_stmt.body, ast.Expr): + body = last_stmt.body + else: + return None + + if not isinstance(body, ast.Expr): + return None + + if not isinstance(body.value, ast.Yield): + return None + + # If it's a call of the class like produced by + # `ShorterResultsPlugin` we have an attribute. + if isinstance(body.value.value, ast.Attribute) and isinstance( + body.value.value.value, ast.Call + ): + return body.value.value.value + + # If not it's just a call statement to the generated class. + if isinstance(body.value.value, ast.Call): + return body.value.value + + return None + + def _get_class_from_call(self, call: ast.Call) -> ast.Name | None: + """Get the class from an `ast.Call`. + + :param call: The `ast.Call` arg + :returns: `ast.Name` or `None` + """ + if not isinstance(call.func, ast.Attribute): + return None + + if not isinstance(call.func.value, ast.Name): + return None + + return call.func.value + + def _update_imports(self, module: ast.Module) -> ast.Name | None: + """Update all imports. + + Iterate over all imports and remove the aliases that we use as input or + return value. These will be moved and added to an `if TYPE_CHECKING` + block. + + **NOTE** If an `ast.ImportFrom` ends up without any names we must remove + it completely otherwise formatting will not work (it would remove the + empty `import from` but not format the rest of the code without running + it twice). + + We do this by storing all imports that we want to keep in an array, we + then drop all from the body and re-insert the ones to keep. Lastly we + import `TYPE_CHECKING` and add all our imports in the `if TYPE_CHECKING` + block. + + :param module: The ast for the whole module. + """ + # We now know all our input types and all our return types. The return + # types that are _not_ used as import types should be in an `if + # TYPE_CHECKING` import block. + return_types_not_used_as_input = set(self.input_and_return_types) + + # The ones we import in the method don't need to be imported at all - + # unless that's the type we return. This behaviour can differ if you use + # a plugin such as `ShorterResultsPlugin` that will import a type that + # is different from the type returned. + return_types_not_used_as_input.update( + {k for k in self.imported_in_method if k not in self.input_and_return_types} + ) + + if len(return_types_not_used_as_input) == 0: + return None + + # We sadly have to iterate over all imports again and remove the imports + # we will do conditionally. + # It's very important that we get this right, if we keep any + # `ImportFrom` that ends up without any names, the formatting will not + # work! It will only remove the empty `import from` but not other unused + # imports. + non_empty_imports: list[ast.Import | ast.ImportFrom] = [] + last_import_at = 0 + for i, node in enumerate(module.body): + if isinstance(node, ast.Import): + last_import_at = i + non_empty_imports.append(node) + + if not isinstance(node, ast.ImportFrom): + continue + + last_import_at = i + reduced_names = [] + for name in node.names: + if name.name not in return_types_not_used_as_input: + reduced_names.append(name) + + node.names = reduced_names + + if len(reduced_names) > 0: + non_empty_imports.append(node) + + # We can now remove all imports and re-insert the ones that's not empty. + module.body = non_empty_imports + module.body[last_import_at + 1 :] + + # Create import to use for type checking. These will be put in an `if + # TYPE_CHECKING` block. + type_checking_imports = {} + for cls in self.input_and_return_types: + module_name = self.imported_classes[cls] + if module_name not in type_checking_imports: + type_checking_imports[module_name] = ast.ImportFrom( + module=module_name, names=[] + ) + + type_checking_imports[module_name].names.append(ast.alias(cls)) + + import_if_type_checking = ast.If( + test=ast.Name(id="TYPE_CHECKING"), + body=list(type_checking_imports.values()), + orelse=[], + ) + + module.body.insert(len(non_empty_imports), import_if_type_checking) + + # Import `TYPE_CHECKING`. + module.body.insert( + len(non_empty_imports), + ast.ImportFrom( + module="typing", + names=[ast.Name("TYPE_CHECKING")], + ), + ) + + return None + + def _update_name_to_constant(self, node: ast.expr) -> ast.expr: + """Update return types. + + If the return type contains any type that resolves to an `ast.Name`, + convert it to an `ast.Constant`. We only need the type for type checking + and can avoid importing the type in the global scope unless needed. + + :param node: The ast node used as return type + :returns: A modified ast node + """ + if isinstance(node, ast.Name): + if node.id in self.imported_classes: + self.input_and_return_types.add(node.id) + return ast.Constant(value=node.id) + + if isinstance(node, ast.Subscript): + node.slice = self._update_name_to_constant(node.slice) + return node + + if isinstance(node, ast.Tuple): + for i, _ in enumerate(node.elts): + node.elts[i] = self._update_name_to_constant(node.elts[i]) + + return node + + return node diff --git a/tests/main/clients/no_global_imports/custom_scalars.py b/tests/main/clients/no_global_imports/custom_scalars.py new file mode 100644 index 00000000..5feb063a --- /dev/null +++ b/tests/main/clients/no_global_imports/custom_scalars.py @@ -0,0 +1,14 @@ +SimpleScalar = str + + +class ComplexScalar: + def __init__(self, value: str) -> None: + self.value = value + + +def parse_complex_scalar(value: str) -> ComplexScalar: + return ComplexScalar(value) + + +def serialize_complex_scalar(value: ComplexScalar) -> str: + return value.value diff --git a/tests/main/clients/no_global_imports/expected_client/__init__.py b/tests/main/clients/no_global_imports/expected_client/__init__.py new file mode 100644 index 00000000..41b2df92 --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/__init__.py @@ -0,0 +1,76 @@ +from .async_base_client import AsyncBaseClient +from .base_model import BaseModel, Upload +from .client import Client +from .exceptions import ( + GraphQLClientError, + GraphQLClientGraphQLError, + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidResponseError, +) +from .get_animal_by_name import ( + GetAnimalByName, + GetAnimalByNameAnimalByNameAnimal, + GetAnimalByNameAnimalByNameCat, + GetAnimalByNameAnimalByNameDog, +) +from .get_animal_fragment_with_extra import GetAnimalFragmentWithExtra +from .get_authenticated_user import GetAuthenticatedUser, GetAuthenticatedUserMe +from .get_complex_scalar import GetComplexScalar +from .get_simple_scalar import GetSimpleScalar +from .list_animals import ( + ListAnimals, + ListAnimalsListAnimalsAnimal, + ListAnimalsListAnimalsCat, + ListAnimalsListAnimalsDog, +) +from .list_strings_1 import ListStrings1 +from .list_strings_2 import ListStrings2 +from .list_strings_3 import ListStrings3 +from .list_strings_4 import ListStrings4 +from .list_type_a import ListTypeA, ListTypeAListOptionalTypeA +from .no_global_imports_fragments import ( + FragmentWithSingleField, + FragmentWithSingleFieldQueryUnwrapFragment, + ListAnimalsFragment, + ListAnimalsFragmentListAnimals, +) +from .subscribe_strings import SubscribeStrings +from .unwrap_fragment import UnwrapFragment + +__all__ = [ + "AsyncBaseClient", + "BaseModel", + "Client", + "FragmentWithSingleField", + "FragmentWithSingleFieldQueryUnwrapFragment", + "GetAnimalByName", + "GetAnimalByNameAnimalByNameAnimal", + "GetAnimalByNameAnimalByNameCat", + "GetAnimalByNameAnimalByNameDog", + "GetAnimalFragmentWithExtra", + "GetAuthenticatedUser", + "GetAuthenticatedUserMe", + "GetComplexScalar", + "GetSimpleScalar", + "GraphQLClientError", + "GraphQLClientGraphQLError", + "GraphQLClientGraphQLMultiError", + "GraphQLClientHttpError", + "GraphQLClientInvalidResponseError", + "ListAnimals", + "ListAnimalsFragment", + "ListAnimalsFragmentListAnimals", + "ListAnimalsListAnimalsAnimal", + "ListAnimalsListAnimalsCat", + "ListAnimalsListAnimalsDog", + "ListStrings1", + "ListStrings2", + "ListStrings3", + "ListStrings4", + "ListTypeA", + "ListTypeAListOptionalTypeA", + "SubscribeStrings", + "UnwrapFragment", + "Upload", +] diff --git a/tests/main/clients/no_global_imports/expected_client/async_base_client.py b/tests/main/clients/no_global_imports/expected_client/async_base_client.py new file mode 100644 index 00000000..5358ced6 --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/async_base_client.py @@ -0,0 +1,370 @@ +import enum +import json +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from uuid import uuid4 + +import httpx +from pydantic import BaseModel +from pydantic_core import to_jsonable_python + +from .base_model import UNSET, Upload +from .exceptions import ( + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidMessageFormat, + GraphQLClientInvalidResponseError, +) + +try: + from websockets.client import ( # type: ignore[import-not-found,unused-ignore] + WebSocketClientProtocol, + connect as ws_connect, + ) + from websockets.typing import ( # type: ignore[import-not-found,unused-ignore] + Data, + Origin, + Subprotocol, + ) +except ImportError: + from contextlib import asynccontextmanager + + @asynccontextmanager # type: ignore + async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument + raise NotImplementedError("Subscriptions require 'websockets' package.") + yield # pylint: disable=unreachable + + WebSocketClientProtocol = Any # type: ignore[misc,assignment,unused-ignore] + Data = Any # type: ignore[misc,assignment,unused-ignore] + Origin = Any # type: ignore[misc,assignment,unused-ignore] + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") + + +Self = TypeVar("Self", bound="AsyncBaseClient") + +GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" + + +class GraphQLTransportWSMessageType(str, enum.Enum): + CONNECTION_INIT = "connection_init" + CONNECTION_ACK = "connection_ack" + PING = "ping" + PONG = "pong" + SUBSCRIBE = "subscribe" + NEXT = "next" + ERROR = "error" + COMPLETE = "complete" + + +class AsyncBaseClient: + def __init__( + self, + url: str = "", + headers: Optional[Dict[str, str]] = None, + http_client: Optional[httpx.AsyncClient] = None, + ws_url: str = "", + ws_headers: Optional[Dict[str, Any]] = None, + ws_origin: Optional[str] = None, + ws_connection_init_payload: Optional[Dict[str, Any]] = None, + ) -> None: + self.url = url + self.headers = headers + self.http_client = ( + http_client if http_client else httpx.AsyncClient(headers=headers) + ) + + self.ws_url = ws_url + self.ws_headers = ws_headers or {} + self.ws_origin = Origin(ws_origin) if ws_origin else None + self.ws_connection_init_payload = ws_connection_init_payload + + async def __aenter__(self: Self) -> Self: + return self + + async def __aexit__( + self, + exc_type: object, + exc_val: object, + exc_tb: object, + ) -> None: + await self.http_client.aclose() + + async def execute( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + operation_name=operation_name, + variables=processed_variables, + files=files, + files_map=files_map, + **kwargs, + ) + + return await self._execute_json( + query=query, + operation_name=operation_name, + variables=processed_variables, + **kwargs, + ) + + def get_data(self, response: httpx.Response) -> Dict[str, Any]: + if not response.is_success: + raise GraphQLClientHttpError( + status_code=response.status_code, response=response + ) + + try: + response_json = response.json() + except ValueError as exc: + raise GraphQLClientInvalidResponseError(response=response) from exc + + if (not isinstance(response_json, dict)) or ( + "data" not in response_json and "errors" not in response_json + ): + raise GraphQLClientInvalidResponseError(response=response) + + data = response_json.get("data") + errors = response_json.get("errors") + + if errors: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=errors, data=data + ) + + return cast(Dict[str, Any], data) + + async def execute_ws( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + **merged_kwargs, + ) as websocket: + await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + operation_name=operation_name, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + + def _process_variables( + self, variables: Optional[Dict[str, Any]] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + if not variables: + return {}, {}, {} + + serializable_variables = self._convert_dict_to_json_serializable(variables) + return self._get_files_from_variables(serializable_variables) + + def _convert_dict_to_json_serializable( + self, dict_: Dict[str, Any] + ) -> Dict[str, Any]: + return { + key: self._convert_value(value) + for key, value in dict_.items() + if value is not UNSET + } + + def _convert_value(self, value: Any) -> Any: + if isinstance(value, BaseModel): + return value.model_dump(by_alias=True, exclude_unset=True) + if isinstance(value, list): + return [self._convert_value(item) for item in value] + return value + + def _get_files_from_variables( + self, variables: Dict[str, Any] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + files_map: Dict[str, List[str]] = {} + files_list: List[Upload] = [] + + def separate_files(path: str, obj: Any) -> Any: + if isinstance(obj, list): + nulled_list = [] + for index, value in enumerate(obj): + value = separate_files(f"{path}.{index}", value) + nulled_list.append(value) + return nulled_list + + if isinstance(obj, dict): + nulled_dict = {} + for key, value in obj.items(): + value = separate_files(f"{path}.{key}", value) + nulled_dict[key] = value + return nulled_dict + + if isinstance(obj, Upload): + if obj in files_list: + file_index = files_list.index(obj) + files_map[str(file_index)].append(path) + else: + file_index = len(files_list) + files_list.append(obj) + files_map[str(file_index)] = [path] + return None + + return obj + + nulled_variables = separate_files("variables", variables) + files: Dict[str, Tuple[str, IO[bytes], str]] = { + str(i): (file_.filename, cast(IO[bytes], file_.content), file_.content_type) + for i, file_ in enumerate(files_list) + } + return nulled_variables, files, files_map + + async def _execute_multipart( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + **kwargs: Any, + ) -> httpx.Response: + data = { + "operations": json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + "map": json.dumps(files_map, default=to_jsonable_python), + } + + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) + + async def _execute_json( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + **kwargs: Any, + ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + + return await self.http_client.post( + url=self.url, + content=json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + **merged_kwargs, + ) + + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: + payload: Dict[str, Any] = { + "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value + } + if self.ws_connection_init_payload: + payload["payload"] = self.ws_connection_init_payload + await websocket.send(json.dumps(payload)) + + async def _send_subscribe( + self, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + payload: Dict[str, Any] = { + "id": operation_id, + "type": GraphQLTransportWSMessageType.SUBSCRIBE.value, + "payload": {"query": query, "operationName": operation_name}, + } + if variables: + payload["payload"]["variables"] = self._convert_dict_to_json_serializable( + variables + ) + await websocket.send(json.dumps(payload)) + + async def _handle_ws_message( + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, + ) -> Optional[Dict[str, Any]]: + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: + raise GraphQLClientInvalidMessageFormat(message=message) + + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/clients/no_global_imports/expected_client/base_model.py b/tests/main/clients/no_global_imports/expected_client/base_model.py new file mode 100644 index 00000000..ccde3975 --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/base_model.py @@ -0,0 +1,27 @@ +from io import IOBase + +from pydantic import BaseModel as PydanticBaseModel, ConfigDict + + +class UnsetType: + def __bool__(self) -> bool: + return False + + +UNSET = UnsetType() + + +class BaseModel(PydanticBaseModel): + model_config = ConfigDict( + populate_by_name=True, + validate_assignment=True, + arbitrary_types_allowed=True, + protected_namespaces=(), + ) + + +class Upload: + def __init__(self, filename: str, content: IOBase, content_type: str): + self.filename = filename + self.content = content + self.content_type = content_type diff --git a/tests/main/clients/no_global_imports/expected_client/client.py b/tests/main/clients/no_global_imports/expected_client/client.py new file mode 100644 index 00000000..9e80af45 --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/client.py @@ -0,0 +1,296 @@ +from typing import TYPE_CHECKING, Any, AsyncIterator, Dict + +from .async_base_client import AsyncBaseClient + +if TYPE_CHECKING: + from .get_animal_by_name import GetAnimalByName + from .get_animal_fragment_with_extra import GetAnimalFragmentWithExtra + from .get_authenticated_user import GetAuthenticatedUser + from .get_complex_scalar import GetComplexScalar + from .get_simple_scalar import GetSimpleScalar + from .list_animals import ListAnimals + from .list_strings_1 import ListStrings1 + from .list_strings_2 import ListStrings2 + from .list_strings_3 import ListStrings3 + from .list_strings_4 import ListStrings4 + from .list_type_a import ListTypeA + from .subscribe_strings import SubscribeStrings + from .unwrap_fragment import UnwrapFragment + + +def gql(q: str) -> str: + return q + + +class Client(AsyncBaseClient): + async def get_authenticated_user(self, **kwargs: Any) -> "GetAuthenticatedUser": + from .get_authenticated_user import GetAuthenticatedUser + + query = gql( + """ + query GetAuthenticatedUser { + me { + id + username + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, + operation_name="GetAuthenticatedUser", + variables=variables, + **kwargs + ) + data = self.get_data(response) + return GetAuthenticatedUser.model_validate(data) + + async def list_strings_1(self, **kwargs: Any) -> "ListStrings1": + from .list_strings_1 import ListStrings1 + + query = gql( + """ + query ListStrings_1 { + optionalListOptionalString + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListStrings_1", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListStrings1.model_validate(data) + + async def list_strings_2(self, **kwargs: Any) -> "ListStrings2": + from .list_strings_2 import ListStrings2 + + query = gql( + """ + query ListStrings_2 { + optionalListString + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListStrings_2", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListStrings2.model_validate(data) + + async def list_strings_3(self, **kwargs: Any) -> "ListStrings3": + from .list_strings_3 import ListStrings3 + + query = gql( + """ + query ListStrings_3 { + listOptionalString + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListStrings_3", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListStrings3.model_validate(data) + + async def list_strings_4(self, **kwargs: Any) -> "ListStrings4": + from .list_strings_4 import ListStrings4 + + query = gql( + """ + query ListStrings_4 { + listString + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListStrings_4", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListStrings4.model_validate(data) + + async def list_type_a(self, **kwargs: Any) -> "ListTypeA": + from .list_type_a import ListTypeA + + query = gql( + """ + query ListTypeA { + listOptionalTypeA { + id + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListTypeA", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListTypeA.model_validate(data) + + async def get_animal_by_name(self, name: str, **kwargs: Any) -> "GetAnimalByName": + from .get_animal_by_name import GetAnimalByName + + query = gql( + """ + query GetAnimalByName($name: String!) { + animalByName(name: $name) { + __typename + name + ... on Cat { + kittens + } + ... on Dog { + puppies + } + } + } + """ + ) + variables: Dict[str, object] = {"name": name} + response = await self.execute( + query=query, operation_name="GetAnimalByName", variables=variables, **kwargs + ) + data = self.get_data(response) + return GetAnimalByName.model_validate(data) + + async def list_animals(self, **kwargs: Any) -> "ListAnimals": + from .list_animals import ListAnimals + + query = gql( + """ + query ListAnimals { + listAnimals { + __typename + name + ... on Cat { + kittens + } + ... on Dog { + puppies + } + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListAnimals", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListAnimals.model_validate(data) + + async def get_animal_fragment_with_extra( + self, **kwargs: Any + ) -> "GetAnimalFragmentWithExtra": + from .get_animal_fragment_with_extra import GetAnimalFragmentWithExtra + + query = gql( + """ + query GetAnimalFragmentWithExtra { + ...ListAnimalsFragment + listString + } + + fragment ListAnimalsFragment on Query { + listAnimals { + name + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, + operation_name="GetAnimalFragmentWithExtra", + variables=variables, + **kwargs + ) + data = self.get_data(response) + return GetAnimalFragmentWithExtra.model_validate(data) + + async def get_simple_scalar(self, **kwargs: Any) -> "GetSimpleScalar": + from .get_simple_scalar import GetSimpleScalar + + query = gql( + """ + query GetSimpleScalar { + justSimpleScalar + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="GetSimpleScalar", variables=variables, **kwargs + ) + data = self.get_data(response) + return GetSimpleScalar.model_validate(data) + + async def get_complex_scalar(self, **kwargs: Any) -> "GetComplexScalar": + from .get_complex_scalar import GetComplexScalar + + query = gql( + """ + query GetComplexScalar { + justComplexScalar + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, + operation_name="GetComplexScalar", + variables=variables, + **kwargs + ) + data = self.get_data(response) + return GetComplexScalar.model_validate(data) + + async def subscribe_strings( + self, **kwargs: Any + ) -> AsyncIterator["SubscribeStrings"]: + from .subscribe_strings import SubscribeStrings + + query = gql( + """ + subscription SubscribeStrings { + optionalListString + } + """ + ) + variables: Dict[str, object] = {} + async for data in self.execute_ws( + query=query, + operation_name="SubscribeStrings", + variables=variables, + **kwargs + ): + yield SubscribeStrings.model_validate(data) + + async def unwrap_fragment(self, **kwargs: Any) -> "UnwrapFragment": + from .unwrap_fragment import UnwrapFragment + + query = gql( + """ + query UnwrapFragment { + ...FragmentWithSingleField + } + + fragment FragmentWithSingleField on Query { + queryUnwrapFragment { + id + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="UnwrapFragment", variables=variables, **kwargs + ) + data = self.get_data(response) + return UnwrapFragment.model_validate(data) diff --git a/tests/main/clients/no_global_imports/expected_client/custom_scalars.py b/tests/main/clients/no_global_imports/expected_client/custom_scalars.py new file mode 100644 index 00000000..5feb063a --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/custom_scalars.py @@ -0,0 +1,14 @@ +SimpleScalar = str + + +class ComplexScalar: + def __init__(self, value: str) -> None: + self.value = value + + +def parse_complex_scalar(value: str) -> ComplexScalar: + return ComplexScalar(value) + + +def serialize_complex_scalar(value: ComplexScalar) -> str: + return value.value diff --git a/tests/main/clients/no_global_imports/expected_client/enums.py b/tests/main/clients/no_global_imports/expected_client/enums.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/main/clients/no_global_imports/expected_client/exceptions.py b/tests/main/clients/no_global_imports/expected_client/exceptions.py new file mode 100644 index 00000000..b34acfe1 --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/exceptions.py @@ -0,0 +1,83 @@ +from typing import Any, Dict, List, Optional, Union + +import httpx + + +class GraphQLClientError(Exception): + """Base exception.""" + + +class GraphQLClientHttpError(GraphQLClientError): + def __init__(self, status_code: int, response: httpx.Response) -> None: + self.status_code = status_code + self.response = response + + def __str__(self) -> str: + return f"HTTP status code: {self.status_code}" + + +class GraphQLClientInvalidResponseError(GraphQLClientError): + def __init__(self, response: httpx.Response) -> None: + self.response = response + + def __str__(self) -> str: + return "Invalid response format." + + +class GraphQLClientGraphQLError(GraphQLClientError): + def __init__( + self, + message: str, + locations: Optional[List[Dict[str, int]]] = None, + path: Optional[List[str]] = None, + extensions: Optional[Dict[str, object]] = None, + orginal: Optional[Dict[str, object]] = None, + ): + self.message = message + self.locations = locations + self.path = path + self.extensions = extensions + self.orginal = orginal + + def __str__(self) -> str: + return self.message + + @classmethod + def from_dict(cls, error: Dict[str, Any]) -> "GraphQLClientGraphQLError": + return cls( + message=error["message"], + locations=error.get("locations"), + path=error.get("path"), + extensions=error.get("extensions"), + orginal=error, + ) + + +class GraphQLClientGraphQLMultiError(GraphQLClientError): + def __init__( + self, + errors: List[GraphQLClientGraphQLError], + data: Optional[Dict[str, Any]] = None, + ): + self.errors = errors + self.data = data + + def __str__(self) -> str: + return "; ".join(str(e) for e in self.errors) + + @classmethod + def from_errors_dicts( + cls, errors_dicts: List[Dict[str, Any]], data: Optional[Dict[str, Any]] = None + ) -> "GraphQLClientGraphQLMultiError": + return cls( + errors=[GraphQLClientGraphQLError.from_dict(e) for e in errors_dicts], + data=data, + ) + + +class GraphQLClientInvalidMessageFormat(GraphQLClientError): + def __init__(self, message: Union[str, bytes]) -> None: + self.message = message + + def __str__(self) -> str: + return "Invalid message format." diff --git a/tests/main/clients/no_global_imports/expected_client/get_animal_by_name.py b/tests/main/clients/no_global_imports/expected_client/get_animal_by_name.py new file mode 100644 index 00000000..e97e12ac --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/get_animal_by_name.py @@ -0,0 +1,33 @@ +from typing import Literal, Union + +from pydantic import Field + +from .base_model import BaseModel + + +class GetAnimalByName(BaseModel): + animal_by_name: Union[ + "GetAnimalByNameAnimalByNameAnimal", + "GetAnimalByNameAnimalByNameCat", + "GetAnimalByNameAnimalByNameDog", + ] = Field(alias="animalByName", discriminator="typename__") + + +class GetAnimalByNameAnimalByNameAnimal(BaseModel): + typename__: Literal["Animal"] = Field(alias="__typename") + name: str + + +class GetAnimalByNameAnimalByNameCat(BaseModel): + typename__: Literal["Cat"] = Field(alias="__typename") + name: str + kittens: int + + +class GetAnimalByNameAnimalByNameDog(BaseModel): + typename__: Literal["Dog"] = Field(alias="__typename") + name: str + puppies: int + + +GetAnimalByName.model_rebuild() diff --git a/tests/main/clients/no_global_imports/expected_client/get_animal_fragment_with_extra.py b/tests/main/clients/no_global_imports/expected_client/get_animal_fragment_with_extra.py new file mode 100644 index 00000000..d7d15914 --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/get_animal_fragment_with_extra.py @@ -0,0 +1,9 @@ +from typing import List + +from pydantic import Field + +from .no_global_imports_fragments import ListAnimalsFragment + + +class GetAnimalFragmentWithExtra(ListAnimalsFragment): + list_string: List[str] = Field(alias="listString") diff --git a/tests/main/clients/no_global_imports/expected_client/get_authenticated_user.py b/tests/main/clients/no_global_imports/expected_client/get_authenticated_user.py new file mode 100644 index 00000000..ce0e2cf7 --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/get_authenticated_user.py @@ -0,0 +1,13 @@ +from .base_model import BaseModel + + +class GetAuthenticatedUser(BaseModel): + me: "GetAuthenticatedUserMe" + + +class GetAuthenticatedUserMe(BaseModel): + id: str + username: str + + +GetAuthenticatedUser.model_rebuild() diff --git a/tests/main/clients/no_global_imports/expected_client/get_complex_scalar.py b/tests/main/clients/no_global_imports/expected_client/get_complex_scalar.py new file mode 100644 index 00000000..78563772 --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/get_complex_scalar.py @@ -0,0 +1,12 @@ +from typing import Annotated + +from pydantic import BeforeValidator, Field + +from .base_model import BaseModel +from .custom_scalars import ComplexScalar, parse_complex_scalar + + +class GetComplexScalar(BaseModel): + just_complex_scalar: Annotated[ + ComplexScalar, BeforeValidator(parse_complex_scalar) + ] = Field(alias="justComplexScalar") diff --git a/tests/main/clients/no_global_imports/expected_client/get_simple_scalar.py b/tests/main/clients/no_global_imports/expected_client/get_simple_scalar.py new file mode 100644 index 00000000..365f7eaa --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/get_simple_scalar.py @@ -0,0 +1,8 @@ +from pydantic import Field + +from .base_model import BaseModel +from .custom_scalars import SimpleScalar + + +class GetSimpleScalar(BaseModel): + just_simple_scalar: SimpleScalar = Field(alias="justSimpleScalar") diff --git a/tests/main/clients/no_global_imports/expected_client/input_types.py b/tests/main/clients/no_global_imports/expected_client/input_types.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/main/clients/no_global_imports/expected_client/list_animals.py b/tests/main/clients/no_global_imports/expected_client/list_animals.py new file mode 100644 index 00000000..bcab1c45 --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/list_animals.py @@ -0,0 +1,38 @@ +from typing import Annotated, List, Literal, Union + +from pydantic import Field + +from .base_model import BaseModel + + +class ListAnimals(BaseModel): + list_animals: List[ + Annotated[ + Union[ + "ListAnimalsListAnimalsAnimal", + "ListAnimalsListAnimalsCat", + "ListAnimalsListAnimalsDog", + ], + Field(discriminator="typename__"), + ] + ] = Field(alias="listAnimals") + + +class ListAnimalsListAnimalsAnimal(BaseModel): + typename__: Literal["Animal"] = Field(alias="__typename") + name: str + + +class ListAnimalsListAnimalsCat(BaseModel): + typename__: Literal["Cat"] = Field(alias="__typename") + name: str + kittens: int + + +class ListAnimalsListAnimalsDog(BaseModel): + typename__: Literal["Dog"] = Field(alias="__typename") + name: str + puppies: int + + +ListAnimals.model_rebuild() diff --git a/tests/main/clients/no_global_imports/expected_client/list_strings_1.py b/tests/main/clients/no_global_imports/expected_client/list_strings_1.py new file mode 100644 index 00000000..fd8c06de --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/list_strings_1.py @@ -0,0 +1,11 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class ListStrings1(BaseModel): + optional_list_optional_string: Optional[List[Optional[str]]] = Field( + alias="optionalListOptionalString" + ) diff --git a/tests/main/clients/no_global_imports/expected_client/list_strings_2.py b/tests/main/clients/no_global_imports/expected_client/list_strings_2.py new file mode 100644 index 00000000..d91ec117 --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/list_strings_2.py @@ -0,0 +1,9 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class ListStrings2(BaseModel): + optional_list_string: Optional[List[str]] = Field(alias="optionalListString") diff --git a/tests/main/clients/no_global_imports/expected_client/list_strings_3.py b/tests/main/clients/no_global_imports/expected_client/list_strings_3.py new file mode 100644 index 00000000..88f6e2cf --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/list_strings_3.py @@ -0,0 +1,9 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class ListStrings3(BaseModel): + list_optional_string: List[Optional[str]] = Field(alias="listOptionalString") diff --git a/tests/main/clients/no_global_imports/expected_client/list_strings_4.py b/tests/main/clients/no_global_imports/expected_client/list_strings_4.py new file mode 100644 index 00000000..15872b23 --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/list_strings_4.py @@ -0,0 +1,9 @@ +from typing import List + +from pydantic import Field + +from .base_model import BaseModel + + +class ListStrings4(BaseModel): + list_string: List[str] = Field(alias="listString") diff --git a/tests/main/clients/no_global_imports/expected_client/list_type_a.py b/tests/main/clients/no_global_imports/expected_client/list_type_a.py new file mode 100644 index 00000000..9e2c8f04 --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/list_type_a.py @@ -0,0 +1,18 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class ListTypeA(BaseModel): + list_optional_type_a: List[Optional["ListTypeAListOptionalTypeA"]] = Field( + alias="listOptionalTypeA" + ) + + +class ListTypeAListOptionalTypeA(BaseModel): + id: int + + +ListTypeA.model_rebuild() diff --git a/tests/main/clients/no_global_imports/expected_client/no_global_imports_fragments.py b/tests/main/clients/no_global_imports/expected_client/no_global_imports_fragments.py new file mode 100644 index 00000000..4b103a44 --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/no_global_imports_fragments.py @@ -0,0 +1,28 @@ +from typing import List, Literal + +from pydantic import Field + +from .base_model import BaseModel + + +class FragmentWithSingleField(BaseModel): + query_unwrap_fragment: "FragmentWithSingleFieldQueryUnwrapFragment" = Field( + alias="queryUnwrapFragment" + ) + + +class FragmentWithSingleFieldQueryUnwrapFragment(BaseModel): + id: int + + +class ListAnimalsFragment(BaseModel): + list_animals: List["ListAnimalsFragmentListAnimals"] = Field(alias="listAnimals") + + +class ListAnimalsFragmentListAnimals(BaseModel): + typename__: Literal["Animal", "Cat", "Dog"] = Field(alias="__typename") + name: str + + +FragmentWithSingleField.model_rebuild() +ListAnimalsFragment.model_rebuild() diff --git a/tests/main/clients/no_global_imports/expected_client/subscribe_strings.py b/tests/main/clients/no_global_imports/expected_client/subscribe_strings.py new file mode 100644 index 00000000..a1a42d4f --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/subscribe_strings.py @@ -0,0 +1,9 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class SubscribeStrings(BaseModel): + optional_list_string: Optional[List[str]] = Field(alias="optionalListString") diff --git a/tests/main/clients/no_global_imports/expected_client/unwrap_fragment.py b/tests/main/clients/no_global_imports/expected_client/unwrap_fragment.py new file mode 100644 index 00000000..ad17ac32 --- /dev/null +++ b/tests/main/clients/no_global_imports/expected_client/unwrap_fragment.py @@ -0,0 +1,5 @@ +from .no_global_imports_fragments import FragmentWithSingleField + + +class UnwrapFragment(FragmentWithSingleField): + pass diff --git a/tests/main/clients/no_global_imports/pyproject.toml b/tests/main/clients/no_global_imports/pyproject.toml new file mode 100644 index 00000000..f3dc5453 --- /dev/null +++ b/tests/main/clients/no_global_imports/pyproject.toml @@ -0,0 +1,16 @@ +[tool.ariadne-codegen] +schema_path = "schema.graphql" +queries_path = "queries.graphql" +include_comments = "none" +target_package_name = "no_global_imports" +files_to_include = ["custom_scalars.py"] +fragments_module_name = "no_global_imports_fragments" +plugins = ["ariadne_codegen.contrib.no_global_imports.NoGlobalImportsPlugin"] + +[tool.ariadne-codegen.scalars.SimpleScalar] +type = ".custom_scalars.SimpleScalar" + +[tool.ariadne-codegen.scalars.ComplexScalar] +type = ".custom_scalars.ComplexScalar" +parse = ".custom_scalars.parse_complex_scalar" +serialize = ".custom_scalars.serialize_complex_scalar" diff --git a/tests/main/clients/no_global_imports/queries.graphql b/tests/main/clients/no_global_imports/queries.graphql new file mode 100644 index 00000000..d407732c --- /dev/null +++ b/tests/main/clients/no_global_imports/queries.graphql @@ -0,0 +1,86 @@ +query GetAuthenticatedUser { + me { + id + username + } +} + +query ListStrings_1 { + optionalListOptionalString +} + +query ListStrings_2 { + optionalListString +} + +query ListStrings_3 { + listOptionalString +} + +query ListStrings_4 { + listString +} + +query ListTypeA { + listOptionalTypeA { + id + } +} + +query GetAnimalByName($name: String!) { + animalByName(name: $name) { + name + ... on Cat { + kittens + } + ... on Dog { + puppies + } + } +} + +query ListAnimals { + listAnimals { + name + ... on Cat { + kittens + } + ... on Dog { + puppies + } + } +} + +query GetAnimalFragmentWithExtra { + ...ListAnimalsFragment + listString +} + +query GetSimpleScalar { + justSimpleScalar +} + +query GetComplexScalar { + justComplexScalar +} + +fragment ListAnimalsFragment on Query { + listAnimals { + name + } +} + +subscription SubscribeStrings { + optionalListString +} + + +fragment FragmentWithSingleField on Query { + queryUnwrapFragment { + id + } +} + +query UnwrapFragment { + ...FragmentWithSingleField +} diff --git a/tests/main/clients/no_global_imports/schema.graphql b/tests/main/clients/no_global_imports/schema.graphql new file mode 100644 index 00000000..7acc7f92 --- /dev/null +++ b/tests/main/clients/no_global_imports/schema.graphql @@ -0,0 +1,48 @@ +scalar SimpleScalar +scalar ComplexScalar + +type TypeA { + id: Int! +} + +type User { + id: ID! + username: String! +} + +interface Animal { + name: String! +} + +type Cat implements Animal { + name: String! + kittens: Int! +} + +type Dog implements Animal { + name: String! + puppies: Int! +} + +type Query { + optionalListOptionalString: [String] + optionalListString: [String!] + listOptionalString: [String]! + listString: [String!]! + + listOptionalTypeA: [TypeA]! + + me: User! + + listAnimals: [Animal!]! + animalByName(name: String!): Animal! + + justSimpleScalar: SimpleScalar! + justComplexScalar: ComplexScalar! + + queryUnwrapFragment: TypeA! +} + +type Subscription { + optionalListString: [String!] +} diff --git a/tests/main/clients/no_global_imports_shorter_results/custom_scalars.py b/tests/main/clients/no_global_imports_shorter_results/custom_scalars.py new file mode 100644 index 00000000..5feb063a --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/custom_scalars.py @@ -0,0 +1,14 @@ +SimpleScalar = str + + +class ComplexScalar: + def __init__(self, value: str) -> None: + self.value = value + + +def parse_complex_scalar(value: str) -> ComplexScalar: + return ComplexScalar(value) + + +def serialize_complex_scalar(value: ComplexScalar) -> str: + return value.value diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/__init__.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/__init__.py new file mode 100644 index 00000000..6c15fb9c --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/__init__.py @@ -0,0 +1,76 @@ +from .async_base_client import AsyncBaseClient +from .base_model import BaseModel, Upload +from .client import Client +from .exceptions import ( + GraphQLClientError, + GraphQLClientGraphQLError, + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidResponseError, +) +from .get_animal_by_name import ( + GetAnimalByName, + GetAnimalByNameAnimalByNameAnimal, + GetAnimalByNameAnimalByNameCat, + GetAnimalByNameAnimalByNameDog, +) +from .get_animal_fragment_with_extra import GetAnimalFragmentWithExtra +from .get_authenticated_user import GetAuthenticatedUser, GetAuthenticatedUserMe +from .get_complex_scalar import GetComplexScalar +from .get_simple_scalar import GetSimpleScalar +from .list_animals import ( + ListAnimals, + ListAnimalsListAnimalsAnimal, + ListAnimalsListAnimalsCat, + ListAnimalsListAnimalsDog, +) +from .list_strings_1 import ListStrings1 +from .list_strings_2 import ListStrings2 +from .list_strings_3 import ListStrings3 +from .list_strings_4 import ListStrings4 +from .list_type_a import ListTypeA, ListTypeAListOptionalTypeA +from .no_global_imports_shorter_resultsfragments import ( + FragmentWithSingleField, + FragmentWithSingleFieldQueryUnwrapFragment, + ListAnimalsFragment, + ListAnimalsFragmentListAnimals, +) +from .subscribe_strings import SubscribeStrings +from .unwrap_fragment import UnwrapFragment + +__all__ = [ + "AsyncBaseClient", + "BaseModel", + "Client", + "FragmentWithSingleField", + "FragmentWithSingleFieldQueryUnwrapFragment", + "GetAnimalByName", + "GetAnimalByNameAnimalByNameAnimal", + "GetAnimalByNameAnimalByNameCat", + "GetAnimalByNameAnimalByNameDog", + "GetAnimalFragmentWithExtra", + "GetAuthenticatedUser", + "GetAuthenticatedUserMe", + "GetComplexScalar", + "GetSimpleScalar", + "GraphQLClientError", + "GraphQLClientGraphQLError", + "GraphQLClientGraphQLMultiError", + "GraphQLClientHttpError", + "GraphQLClientInvalidResponseError", + "ListAnimals", + "ListAnimalsFragment", + "ListAnimalsFragmentListAnimals", + "ListAnimalsListAnimalsAnimal", + "ListAnimalsListAnimalsCat", + "ListAnimalsListAnimalsDog", + "ListStrings1", + "ListStrings2", + "ListStrings3", + "ListStrings4", + "ListTypeA", + "ListTypeAListOptionalTypeA", + "SubscribeStrings", + "UnwrapFragment", + "Upload", +] diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/async_base_client.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/async_base_client.py new file mode 100644 index 00000000..5358ced6 --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/async_base_client.py @@ -0,0 +1,370 @@ +import enum +import json +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from uuid import uuid4 + +import httpx +from pydantic import BaseModel +from pydantic_core import to_jsonable_python + +from .base_model import UNSET, Upload +from .exceptions import ( + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidMessageFormat, + GraphQLClientInvalidResponseError, +) + +try: + from websockets.client import ( # type: ignore[import-not-found,unused-ignore] + WebSocketClientProtocol, + connect as ws_connect, + ) + from websockets.typing import ( # type: ignore[import-not-found,unused-ignore] + Data, + Origin, + Subprotocol, + ) +except ImportError: + from contextlib import asynccontextmanager + + @asynccontextmanager # type: ignore + async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument + raise NotImplementedError("Subscriptions require 'websockets' package.") + yield # pylint: disable=unreachable + + WebSocketClientProtocol = Any # type: ignore[misc,assignment,unused-ignore] + Data = Any # type: ignore[misc,assignment,unused-ignore] + Origin = Any # type: ignore[misc,assignment,unused-ignore] + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") + + +Self = TypeVar("Self", bound="AsyncBaseClient") + +GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" + + +class GraphQLTransportWSMessageType(str, enum.Enum): + CONNECTION_INIT = "connection_init" + CONNECTION_ACK = "connection_ack" + PING = "ping" + PONG = "pong" + SUBSCRIBE = "subscribe" + NEXT = "next" + ERROR = "error" + COMPLETE = "complete" + + +class AsyncBaseClient: + def __init__( + self, + url: str = "", + headers: Optional[Dict[str, str]] = None, + http_client: Optional[httpx.AsyncClient] = None, + ws_url: str = "", + ws_headers: Optional[Dict[str, Any]] = None, + ws_origin: Optional[str] = None, + ws_connection_init_payload: Optional[Dict[str, Any]] = None, + ) -> None: + self.url = url + self.headers = headers + self.http_client = ( + http_client if http_client else httpx.AsyncClient(headers=headers) + ) + + self.ws_url = ws_url + self.ws_headers = ws_headers or {} + self.ws_origin = Origin(ws_origin) if ws_origin else None + self.ws_connection_init_payload = ws_connection_init_payload + + async def __aenter__(self: Self) -> Self: + return self + + async def __aexit__( + self, + exc_type: object, + exc_val: object, + exc_tb: object, + ) -> None: + await self.http_client.aclose() + + async def execute( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + operation_name=operation_name, + variables=processed_variables, + files=files, + files_map=files_map, + **kwargs, + ) + + return await self._execute_json( + query=query, + operation_name=operation_name, + variables=processed_variables, + **kwargs, + ) + + def get_data(self, response: httpx.Response) -> Dict[str, Any]: + if not response.is_success: + raise GraphQLClientHttpError( + status_code=response.status_code, response=response + ) + + try: + response_json = response.json() + except ValueError as exc: + raise GraphQLClientInvalidResponseError(response=response) from exc + + if (not isinstance(response_json, dict)) or ( + "data" not in response_json and "errors" not in response_json + ): + raise GraphQLClientInvalidResponseError(response=response) + + data = response_json.get("data") + errors = response_json.get("errors") + + if errors: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=errors, data=data + ) + + return cast(Dict[str, Any], data) + + async def execute_ws( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + **merged_kwargs, + ) as websocket: + await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + operation_name=operation_name, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + + def _process_variables( + self, variables: Optional[Dict[str, Any]] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + if not variables: + return {}, {}, {} + + serializable_variables = self._convert_dict_to_json_serializable(variables) + return self._get_files_from_variables(serializable_variables) + + def _convert_dict_to_json_serializable( + self, dict_: Dict[str, Any] + ) -> Dict[str, Any]: + return { + key: self._convert_value(value) + for key, value in dict_.items() + if value is not UNSET + } + + def _convert_value(self, value: Any) -> Any: + if isinstance(value, BaseModel): + return value.model_dump(by_alias=True, exclude_unset=True) + if isinstance(value, list): + return [self._convert_value(item) for item in value] + return value + + def _get_files_from_variables( + self, variables: Dict[str, Any] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + files_map: Dict[str, List[str]] = {} + files_list: List[Upload] = [] + + def separate_files(path: str, obj: Any) -> Any: + if isinstance(obj, list): + nulled_list = [] + for index, value in enumerate(obj): + value = separate_files(f"{path}.{index}", value) + nulled_list.append(value) + return nulled_list + + if isinstance(obj, dict): + nulled_dict = {} + for key, value in obj.items(): + value = separate_files(f"{path}.{key}", value) + nulled_dict[key] = value + return nulled_dict + + if isinstance(obj, Upload): + if obj in files_list: + file_index = files_list.index(obj) + files_map[str(file_index)].append(path) + else: + file_index = len(files_list) + files_list.append(obj) + files_map[str(file_index)] = [path] + return None + + return obj + + nulled_variables = separate_files("variables", variables) + files: Dict[str, Tuple[str, IO[bytes], str]] = { + str(i): (file_.filename, cast(IO[bytes], file_.content), file_.content_type) + for i, file_ in enumerate(files_list) + } + return nulled_variables, files, files_map + + async def _execute_multipart( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + **kwargs: Any, + ) -> httpx.Response: + data = { + "operations": json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + "map": json.dumps(files_map, default=to_jsonable_python), + } + + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) + + async def _execute_json( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + **kwargs: Any, + ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + + return await self.http_client.post( + url=self.url, + content=json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + **merged_kwargs, + ) + + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: + payload: Dict[str, Any] = { + "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value + } + if self.ws_connection_init_payload: + payload["payload"] = self.ws_connection_init_payload + await websocket.send(json.dumps(payload)) + + async def _send_subscribe( + self, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + payload: Dict[str, Any] = { + "id": operation_id, + "type": GraphQLTransportWSMessageType.SUBSCRIBE.value, + "payload": {"query": query, "operationName": operation_name}, + } + if variables: + payload["payload"]["variables"] = self._convert_dict_to_json_serializable( + variables + ) + await websocket.send(json.dumps(payload)) + + async def _handle_ws_message( + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, + ) -> Optional[Dict[str, Any]]: + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: + raise GraphQLClientInvalidMessageFormat(message=message) + + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/base_model.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/base_model.py new file mode 100644 index 00000000..ccde3975 --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/base_model.py @@ -0,0 +1,27 @@ +from io import IOBase + +from pydantic import BaseModel as PydanticBaseModel, ConfigDict + + +class UnsetType: + def __bool__(self) -> bool: + return False + + +UNSET = UnsetType() + + +class BaseModel(PydanticBaseModel): + model_config = ConfigDict( + populate_by_name=True, + validate_assignment=True, + arbitrary_types_allowed=True, + protected_namespaces=(), + ) + + +class Upload: + def __init__(self, filename: str, content: IOBase, content_type: str): + self.filename = filename + self.content = content + self.content_type = content_type diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/client.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/client.py new file mode 100644 index 00000000..655dbb50 --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/client.py @@ -0,0 +1,314 @@ +from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Union + +from .async_base_client import AsyncBaseClient + +if TYPE_CHECKING: + from .custom_scalars import ComplexScalar, SimpleScalar + from .get_animal_by_name import ( + GetAnimalByNameAnimalByNameAnimal, + GetAnimalByNameAnimalByNameCat, + GetAnimalByNameAnimalByNameDog, + ) + from .get_animal_fragment_with_extra import GetAnimalFragmentWithExtra + from .get_authenticated_user import GetAuthenticatedUserMe + from .list_animals import ( + ListAnimalsListAnimalsAnimal, + ListAnimalsListAnimalsCat, + ListAnimalsListAnimalsDog, + ) + from .list_type_a import ListTypeAListOptionalTypeA + from .no_global_imports_shorter_resultsfragments import ( + FragmentWithSingleFieldQueryUnwrapFragment, + ) + + +def gql(q: str) -> str: + return q + + +class Client(AsyncBaseClient): + async def get_authenticated_user(self, **kwargs: Any) -> "GetAuthenticatedUserMe": + from .get_authenticated_user import GetAuthenticatedUser + + query = gql( + """ + query GetAuthenticatedUser { + me { + id + username + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, + operation_name="GetAuthenticatedUser", + variables=variables, + **kwargs + ) + data = self.get_data(response) + return GetAuthenticatedUser.model_validate(data).me + + async def list_strings_1(self, **kwargs: Any) -> Optional[List[Optional[str]]]: + from .list_strings_1 import ListStrings1 + + query = gql( + """ + query ListStrings_1 { + optionalListOptionalString + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListStrings_1", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListStrings1.model_validate(data).optional_list_optional_string + + async def list_strings_2(self, **kwargs: Any) -> Optional[List[str]]: + from .list_strings_2 import ListStrings2 + + query = gql( + """ + query ListStrings_2 { + optionalListString + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListStrings_2", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListStrings2.model_validate(data).optional_list_string + + async def list_strings_3(self, **kwargs: Any) -> List[Optional[str]]: + from .list_strings_3 import ListStrings3 + + query = gql( + """ + query ListStrings_3 { + listOptionalString + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListStrings_3", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListStrings3.model_validate(data).list_optional_string + + async def list_strings_4(self, **kwargs: Any) -> List[str]: + from .list_strings_4 import ListStrings4 + + query = gql( + """ + query ListStrings_4 { + listString + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListStrings_4", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListStrings4.model_validate(data).list_string + + async def list_type_a( + self, **kwargs: Any + ) -> List[Optional["ListTypeAListOptionalTypeA"]]: + from .list_type_a import ListTypeA + + query = gql( + """ + query ListTypeA { + listOptionalTypeA { + id + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListTypeA", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListTypeA.model_validate(data).list_optional_type_a + + async def get_animal_by_name(self, name: str, **kwargs: Any) -> Union[ + "GetAnimalByNameAnimalByNameAnimal", + "GetAnimalByNameAnimalByNameCat", + "GetAnimalByNameAnimalByNameDog", + ]: + from .get_animal_by_name import GetAnimalByName + + query = gql( + """ + query GetAnimalByName($name: String!) { + animalByName(name: $name) { + __typename + name + ... on Cat { + kittens + } + ... on Dog { + puppies + } + } + } + """ + ) + variables: Dict[str, object] = {"name": name} + response = await self.execute( + query=query, operation_name="GetAnimalByName", variables=variables, **kwargs + ) + data = self.get_data(response) + return GetAnimalByName.model_validate(data).animal_by_name + + async def list_animals(self, **kwargs: Any) -> List[ + Union[ + "ListAnimalsListAnimalsAnimal", + "ListAnimalsListAnimalsCat", + "ListAnimalsListAnimalsDog", + ] + ]: + from .list_animals import ListAnimals + + query = gql( + """ + query ListAnimals { + listAnimals { + __typename + name + ... on Cat { + kittens + } + ... on Dog { + puppies + } + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="ListAnimals", variables=variables, **kwargs + ) + data = self.get_data(response) + return ListAnimals.model_validate(data).list_animals + + async def get_animal_fragment_with_extra( + self, **kwargs: Any + ) -> "GetAnimalFragmentWithExtra": + from .get_animal_fragment_with_extra import GetAnimalFragmentWithExtra + + query = gql( + """ + query GetAnimalFragmentWithExtra { + ...ListAnimalsFragment + listString + } + + fragment ListAnimalsFragment on Query { + listAnimals { + name + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, + operation_name="GetAnimalFragmentWithExtra", + variables=variables, + **kwargs + ) + data = self.get_data(response) + return GetAnimalFragmentWithExtra.model_validate(data) + + async def get_simple_scalar(self, **kwargs: Any) -> "SimpleScalar": + from .get_simple_scalar import GetSimpleScalar + + query = gql( + """ + query GetSimpleScalar { + justSimpleScalar + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="GetSimpleScalar", variables=variables, **kwargs + ) + data = self.get_data(response) + return GetSimpleScalar.model_validate(data).just_simple_scalar + + async def get_complex_scalar(self, **kwargs: Any) -> "ComplexScalar": + from .get_complex_scalar import GetComplexScalar + + query = gql( + """ + query GetComplexScalar { + justComplexScalar + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, + operation_name="GetComplexScalar", + variables=variables, + **kwargs + ) + data = self.get_data(response) + return GetComplexScalar.model_validate(data).just_complex_scalar + + async def subscribe_strings( + self, **kwargs: Any + ) -> AsyncIterator[Optional[List[str]]]: + from .subscribe_strings import SubscribeStrings + + query = gql( + """ + subscription SubscribeStrings { + optionalListString + } + """ + ) + variables: Dict[str, object] = {} + async for data in self.execute_ws( + query=query, + operation_name="SubscribeStrings", + variables=variables, + **kwargs + ): + yield SubscribeStrings.model_validate(data).optional_list_string + + async def unwrap_fragment( + self, **kwargs: Any + ) -> "FragmentWithSingleFieldQueryUnwrapFragment": + from .unwrap_fragment import UnwrapFragment + + query = gql( + """ + query UnwrapFragment { + ...FragmentWithSingleField + } + + fragment FragmentWithSingleField on Query { + queryUnwrapFragment { + id + } + } + """ + ) + variables: Dict[str, object] = {} + response = await self.execute( + query=query, operation_name="UnwrapFragment", variables=variables, **kwargs + ) + data = self.get_data(response) + return UnwrapFragment.model_validate(data).query_unwrap_fragment diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/custom_scalars.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/custom_scalars.py new file mode 100644 index 00000000..5feb063a --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/custom_scalars.py @@ -0,0 +1,14 @@ +SimpleScalar = str + + +class ComplexScalar: + def __init__(self, value: str) -> None: + self.value = value + + +def parse_complex_scalar(value: str) -> ComplexScalar: + return ComplexScalar(value) + + +def serialize_complex_scalar(value: ComplexScalar) -> str: + return value.value diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/enums.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/enums.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/exceptions.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/exceptions.py new file mode 100644 index 00000000..b34acfe1 --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/exceptions.py @@ -0,0 +1,83 @@ +from typing import Any, Dict, List, Optional, Union + +import httpx + + +class GraphQLClientError(Exception): + """Base exception.""" + + +class GraphQLClientHttpError(GraphQLClientError): + def __init__(self, status_code: int, response: httpx.Response) -> None: + self.status_code = status_code + self.response = response + + def __str__(self) -> str: + return f"HTTP status code: {self.status_code}" + + +class GraphQLClientInvalidResponseError(GraphQLClientError): + def __init__(self, response: httpx.Response) -> None: + self.response = response + + def __str__(self) -> str: + return "Invalid response format." + + +class GraphQLClientGraphQLError(GraphQLClientError): + def __init__( + self, + message: str, + locations: Optional[List[Dict[str, int]]] = None, + path: Optional[List[str]] = None, + extensions: Optional[Dict[str, object]] = None, + orginal: Optional[Dict[str, object]] = None, + ): + self.message = message + self.locations = locations + self.path = path + self.extensions = extensions + self.orginal = orginal + + def __str__(self) -> str: + return self.message + + @classmethod + def from_dict(cls, error: Dict[str, Any]) -> "GraphQLClientGraphQLError": + return cls( + message=error["message"], + locations=error.get("locations"), + path=error.get("path"), + extensions=error.get("extensions"), + orginal=error, + ) + + +class GraphQLClientGraphQLMultiError(GraphQLClientError): + def __init__( + self, + errors: List[GraphQLClientGraphQLError], + data: Optional[Dict[str, Any]] = None, + ): + self.errors = errors + self.data = data + + def __str__(self) -> str: + return "; ".join(str(e) for e in self.errors) + + @classmethod + def from_errors_dicts( + cls, errors_dicts: List[Dict[str, Any]], data: Optional[Dict[str, Any]] = None + ) -> "GraphQLClientGraphQLMultiError": + return cls( + errors=[GraphQLClientGraphQLError.from_dict(e) for e in errors_dicts], + data=data, + ) + + +class GraphQLClientInvalidMessageFormat(GraphQLClientError): + def __init__(self, message: Union[str, bytes]) -> None: + self.message = message + + def __str__(self) -> str: + return "Invalid message format." diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/get_animal_by_name.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/get_animal_by_name.py new file mode 100644 index 00000000..e97e12ac --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/get_animal_by_name.py @@ -0,0 +1,33 @@ +from typing import Literal, Union + +from pydantic import Field + +from .base_model import BaseModel + + +class GetAnimalByName(BaseModel): + animal_by_name: Union[ + "GetAnimalByNameAnimalByNameAnimal", + "GetAnimalByNameAnimalByNameCat", + "GetAnimalByNameAnimalByNameDog", + ] = Field(alias="animalByName", discriminator="typename__") + + +class GetAnimalByNameAnimalByNameAnimal(BaseModel): + typename__: Literal["Animal"] = Field(alias="__typename") + name: str + + +class GetAnimalByNameAnimalByNameCat(BaseModel): + typename__: Literal["Cat"] = Field(alias="__typename") + name: str + kittens: int + + +class GetAnimalByNameAnimalByNameDog(BaseModel): + typename__: Literal["Dog"] = Field(alias="__typename") + name: str + puppies: int + + +GetAnimalByName.model_rebuild() diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/get_animal_fragment_with_extra.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/get_animal_fragment_with_extra.py new file mode 100644 index 00000000..90ffbdd1 --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/get_animal_fragment_with_extra.py @@ -0,0 +1,9 @@ +from typing import List + +from pydantic import Field + +from .no_global_imports_shorter_resultsfragments import ListAnimalsFragment + + +class GetAnimalFragmentWithExtra(ListAnimalsFragment): + list_string: List[str] = Field(alias="listString") diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/get_authenticated_user.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/get_authenticated_user.py new file mode 100644 index 00000000..ce0e2cf7 --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/get_authenticated_user.py @@ -0,0 +1,13 @@ +from .base_model import BaseModel + + +class GetAuthenticatedUser(BaseModel): + me: "GetAuthenticatedUserMe" + + +class GetAuthenticatedUserMe(BaseModel): + id: str + username: str + + +GetAuthenticatedUser.model_rebuild() diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/get_complex_scalar.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/get_complex_scalar.py new file mode 100644 index 00000000..78563772 --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/get_complex_scalar.py @@ -0,0 +1,12 @@ +from typing import Annotated + +from pydantic import BeforeValidator, Field + +from .base_model import BaseModel +from .custom_scalars import ComplexScalar, parse_complex_scalar + + +class GetComplexScalar(BaseModel): + just_complex_scalar: Annotated[ + ComplexScalar, BeforeValidator(parse_complex_scalar) + ] = Field(alias="justComplexScalar") diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/get_simple_scalar.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/get_simple_scalar.py new file mode 100644 index 00000000..365f7eaa --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/get_simple_scalar.py @@ -0,0 +1,8 @@ +from pydantic import Field + +from .base_model import BaseModel +from .custom_scalars import SimpleScalar + + +class GetSimpleScalar(BaseModel): + just_simple_scalar: SimpleScalar = Field(alias="justSimpleScalar") diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/input_types.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/input_types.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/list_animals.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/list_animals.py new file mode 100644 index 00000000..bcab1c45 --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/list_animals.py @@ -0,0 +1,38 @@ +from typing import Annotated, List, Literal, Union + +from pydantic import Field + +from .base_model import BaseModel + + +class ListAnimals(BaseModel): + list_animals: List[ + Annotated[ + Union[ + "ListAnimalsListAnimalsAnimal", + "ListAnimalsListAnimalsCat", + "ListAnimalsListAnimalsDog", + ], + Field(discriminator="typename__"), + ] + ] = Field(alias="listAnimals") + + +class ListAnimalsListAnimalsAnimal(BaseModel): + typename__: Literal["Animal"] = Field(alias="__typename") + name: str + + +class ListAnimalsListAnimalsCat(BaseModel): + typename__: Literal["Cat"] = Field(alias="__typename") + name: str + kittens: int + + +class ListAnimalsListAnimalsDog(BaseModel): + typename__: Literal["Dog"] = Field(alias="__typename") + name: str + puppies: int + + +ListAnimals.model_rebuild() diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_1.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_1.py new file mode 100644 index 00000000..fd8c06de --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_1.py @@ -0,0 +1,11 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class ListStrings1(BaseModel): + optional_list_optional_string: Optional[List[Optional[str]]] = Field( + alias="optionalListOptionalString" + ) diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_2.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_2.py new file mode 100644 index 00000000..d91ec117 --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_2.py @@ -0,0 +1,9 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class ListStrings2(BaseModel): + optional_list_string: Optional[List[str]] = Field(alias="optionalListString") diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_3.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_3.py new file mode 100644 index 00000000..88f6e2cf --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_3.py @@ -0,0 +1,9 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class ListStrings3(BaseModel): + list_optional_string: List[Optional[str]] = Field(alias="listOptionalString") diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_4.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_4.py new file mode 100644 index 00000000..15872b23 --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_4.py @@ -0,0 +1,9 @@ +from typing import List + +from pydantic import Field + +from .base_model import BaseModel + + +class ListStrings4(BaseModel): + list_string: List[str] = Field(alias="listString") diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/list_type_a.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/list_type_a.py new file mode 100644 index 00000000..9e2c8f04 --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/list_type_a.py @@ -0,0 +1,18 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class ListTypeA(BaseModel): + list_optional_type_a: List[Optional["ListTypeAListOptionalTypeA"]] = Field( + alias="listOptionalTypeA" + ) + + +class ListTypeAListOptionalTypeA(BaseModel): + id: int + + +ListTypeA.model_rebuild() diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/no_global_imports_shorter_resultsfragments.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/no_global_imports_shorter_resultsfragments.py new file mode 100644 index 00000000..4b103a44 --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/no_global_imports_shorter_resultsfragments.py @@ -0,0 +1,28 @@ +from typing import List, Literal + +from pydantic import Field + +from .base_model import BaseModel + + +class FragmentWithSingleField(BaseModel): + query_unwrap_fragment: "FragmentWithSingleFieldQueryUnwrapFragment" = Field( + alias="queryUnwrapFragment" + ) + + +class FragmentWithSingleFieldQueryUnwrapFragment(BaseModel): + id: int + + +class ListAnimalsFragment(BaseModel): + list_animals: List["ListAnimalsFragmentListAnimals"] = Field(alias="listAnimals") + + +class ListAnimalsFragmentListAnimals(BaseModel): + typename__: Literal["Animal", "Cat", "Dog"] = Field(alias="__typename") + name: str + + +FragmentWithSingleField.model_rebuild() +ListAnimalsFragment.model_rebuild() diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/subscribe_strings.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/subscribe_strings.py new file mode 100644 index 00000000..a1a42d4f --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/subscribe_strings.py @@ -0,0 +1,9 @@ +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class SubscribeStrings(BaseModel): + optional_list_string: Optional[List[str]] = Field(alias="optionalListString") diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/unwrap_fragment.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/unwrap_fragment.py new file mode 100644 index 00000000..8189e91f --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/expected_client/unwrap_fragment.py @@ -0,0 +1,5 @@ +from .no_global_imports_shorter_resultsfragments import FragmentWithSingleField + + +class UnwrapFragment(FragmentWithSingleField): + pass diff --git a/tests/main/clients/no_global_imports_shorter_results/pyproject.toml b/tests/main/clients/no_global_imports_shorter_results/pyproject.toml new file mode 100644 index 00000000..9379191a --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/pyproject.toml @@ -0,0 +1,19 @@ +[tool.ariadne-codegen] +schema_path = "schema.graphql" +queries_path = "queries.graphql" +include_comments = "none" +target_package_name = "no_global_imports_shorter_results" +files_to_include = ["custom_scalars.py"] +fragments_module_name = "no_global_imports_shorter_resultsfragments" +plugins = [ + "ariadne_codegen.contrib.shorter_results.ShorterResultsPlugin", + "ariadne_codegen.contrib.no_global_imports.NoGlobalImportsPlugin" +] + +[tool.ariadne-codegen.scalars.SimpleScalar] +type = ".custom_scalars.SimpleScalar" + +[tool.ariadne-codegen.scalars.ComplexScalar] +type = ".custom_scalars.ComplexScalar" +parse = ".custom_scalars.parse_complex_scalar" +serialize = ".custom_scalars.serialize_complex_scalar" diff --git a/tests/main/clients/no_global_imports_shorter_results/queries.graphql b/tests/main/clients/no_global_imports_shorter_results/queries.graphql new file mode 100644 index 00000000..d407732c --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/queries.graphql @@ -0,0 +1,86 @@ +query GetAuthenticatedUser { + me { + id + username + } +} + +query ListStrings_1 { + optionalListOptionalString +} + +query ListStrings_2 { + optionalListString +} + +query ListStrings_3 { + listOptionalString +} + +query ListStrings_4 { + listString +} + +query ListTypeA { + listOptionalTypeA { + id + } +} + +query GetAnimalByName($name: String!) { + animalByName(name: $name) { + name + ... on Cat { + kittens + } + ... on Dog { + puppies + } + } +} + +query ListAnimals { + listAnimals { + name + ... on Cat { + kittens + } + ... on Dog { + puppies + } + } +} + +query GetAnimalFragmentWithExtra { + ...ListAnimalsFragment + listString +} + +query GetSimpleScalar { + justSimpleScalar +} + +query GetComplexScalar { + justComplexScalar +} + +fragment ListAnimalsFragment on Query { + listAnimals { + name + } +} + +subscription SubscribeStrings { + optionalListString +} + + +fragment FragmentWithSingleField on Query { + queryUnwrapFragment { + id + } +} + +query UnwrapFragment { + ...FragmentWithSingleField +} diff --git a/tests/main/clients/no_global_imports_shorter_results/schema.graphql b/tests/main/clients/no_global_imports_shorter_results/schema.graphql new file mode 100644 index 00000000..7acc7f92 --- /dev/null +++ b/tests/main/clients/no_global_imports_shorter_results/schema.graphql @@ -0,0 +1,48 @@ +scalar SimpleScalar +scalar ComplexScalar + +type TypeA { + id: Int! +} + +type User { + id: ID! + username: String! +} + +interface Animal { + name: String! +} + +type Cat implements Animal { + name: String! + kittens: Int! +} + +type Dog implements Animal { + name: String! + puppies: Int! +} + +type Query { + optionalListOptionalString: [String] + optionalListString: [String!] + listOptionalString: [String]! + listString: [String!]! + + listOptionalTypeA: [TypeA]! + + me: User! + + listAnimals: [Animal!]! + animalByName(name: String!): Animal! + + justSimpleScalar: SimpleScalar! + justComplexScalar: ComplexScalar! + + queryUnwrapFragment: TypeA! +} + +type Subscription { + optionalListString: [String!] +} From 26d089bdc42db8eb9ec6d00bcdc01a5a723e762a Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Wed, 13 Mar 2024 22:30:40 +0100 Subject: [PATCH 02/11] Use constants for `TYPE_CHECKING` --- ariadne_codegen/contrib/no_global_imports.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ariadne_codegen/contrib/no_global_imports.py b/ariadne_codegen/contrib/no_global_imports.py index eb69e40e..f3fc08ec 100644 --- a/ariadne_codegen/contrib/no_global_imports.py +++ b/ariadne_codegen/contrib/no_global_imports.py @@ -15,6 +15,9 @@ from ariadne_codegen import Plugin +TYPE_CHECKING_MODULE: str = "typing" +TYPE_CHECKING_FLAG: str = "TYPE_CHECKING" + class NoGlobalImportsPlugin(Plugin): """Only import types when you call an endpoint needing it""" @@ -313,7 +316,7 @@ def _update_imports(self, module: ast.Module) -> ast.Name | None: type_checking_imports[module_name].names.append(ast.alias(cls)) import_if_type_checking = ast.If( - test=ast.Name(id="TYPE_CHECKING"), + test=ast.Name(id=TYPE_CHECKING_FLAG), body=list(type_checking_imports.values()), orelse=[], ) @@ -324,8 +327,8 @@ def _update_imports(self, module: ast.Module) -> ast.Name | None: module.body.insert( len(non_empty_imports), ast.ImportFrom( - module="typing", - names=[ast.Name("TYPE_CHECKING")], + module=TYPE_CHECKING_MODULE, + names=[ast.Name(TYPE_CHECKING_FLAG)], ), ) From b16070d9658f2163b266932f50735478cd71042c Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Thu, 14 Mar 2024 09:53:12 +0100 Subject: [PATCH 03/11] Use types from `typing` for hints, remove faulty return, update docs --- ariadne_codegen/contrib/no_global_imports.py | 35 +++++++++++--------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/ariadne_codegen/contrib/no_global_imports.py b/ariadne_codegen/contrib/no_global_imports.py index f3fc08ec..b4281161 100644 --- a/ariadne_codegen/contrib/no_global_imports.py +++ b/ariadne_codegen/contrib/no_global_imports.py @@ -1,15 +1,15 @@ """ -Plugin to only import types when you call methods +Plugin to only import types for GraphQL responses when you call methods. This will massively reduce import times for larger projects since you only have to load the input types when loading the client. -All result types that's used to process the server response will only be -imported when the method is called. +All input and return types that's used to process the server response will +only be imported when the method is called. """ import ast -from typing import Dict +from typing import Dict, List, Optional, Set, Union from graphql import GraphQLSchema @@ -27,15 +27,15 @@ def __init__(self, schema: GraphQLSchema, config_dict: Dict) -> None: # Types that should only be imported in a `TYPE_CHECKING` context. This # is all the types used as arguments to a method or as a return type, # i.e. for type checking. - self.input_and_return_types: set[str] = set() + self.input_and_return_types: Set[str] = set() # Imported classes are classes imported from local imports. We keep a # map between name and module so we know how to import them in each # method. - self.imported_classes: dict[str, str] = {} + self.imported_classes: Dict[str, str] = {} # Imported classes in each method definition. - self.imported_in_method: set[str] = set() + self.imported_in_method: Set[str] = set() super().__init__(schema, config_dict) @@ -55,6 +55,7 @@ def generate_client_module(self, module: ast.Module) -> ast.Module: types used as input types. :param module: The ast for the module + :returns: A modified `ast.Module` """ self._store_imported_classes(module.body) @@ -83,7 +84,7 @@ def generate_client_module(self, module: ast.Module) -> ast.Module: return super().generate_client_module(module) - def _store_imported_classes(self, module_body: list[ast.stmt]): + def _store_imported_classes(self, module_body: List[ast.stmt]): """Fetch and store imported classes. Grab all imported classes with level 1 or starting with `.` because @@ -110,8 +111,8 @@ def _store_imported_classes(self, module_body: list[ast.stmt]): self.imported_classes[name.name] = from_ def _rewrite_input_args_to_constants( - self, method_def: ast.FunctionDef | ast.AsyncFunctionDef - ) -> ast.FunctionDef | ast.AsyncFunctionDef: + self, method_def: Union[ast.FunctionDef, ast.AsyncFunctionDef] + ) -> Union[ast.FunctionDef, ast.AsyncFunctionDef]: """Rewrite the arguments to a method. For any `ast.Name` that requires an import convert it to an @@ -134,7 +135,7 @@ def _rewrite_input_args_to_constants( return method_def def _insert_import_statement_in_method( - self, method_def: ast.FunctionDef | ast.AsyncFunctionDef + self, method_def: Union[ast.FunctionDef, ast.AsyncFunctionDef] ): """Insert import statement in method. @@ -175,7 +176,7 @@ def _insert_import_statement_in_method( ), ) - def _get_call_arg_from_return(self, return_stmt: ast.Return) -> ast.Call | None: + def _get_call_arg_from_return(self, return_stmt: ast.Return) -> Optional[ast.Call]: """Get the class used in the return statement. :param return_stmt: The statement used for return @@ -193,7 +194,9 @@ def _get_call_arg_from_return(self, return_stmt: ast.Return) -> ast.Call | None: return None - def _get_call_arg_from_async_for(self, last_stmt: ast.AsyncFor) -> ast.Call | None: + def _get_call_arg_from_async_for( + self, last_stmt: ast.AsyncFor + ) -> Optional[ast.Call]: """Get the class used in the yield expression. :param last_stmt: The statement used in `ast.AsyncFor` @@ -224,7 +227,7 @@ def _get_call_arg_from_async_for(self, last_stmt: ast.AsyncFor) -> ast.Call | No return None - def _get_class_from_call(self, call: ast.Call) -> ast.Name | None: + def _get_class_from_call(self, call: ast.Call) -> Optional[ast.Name]: """Get the class from an `ast.Call`. :param call: The `ast.Call` arg @@ -238,7 +241,7 @@ def _get_class_from_call(self, call: ast.Call) -> ast.Name | None: return call.func.value - def _update_imports(self, module: ast.Module) -> ast.Name | None: + def _update_imports(self, module: ast.Module): """Update all imports. Iterate over all imports and remove the aliases that we use as input or @@ -279,7 +282,7 @@ def _update_imports(self, module: ast.Module) -> ast.Name | None: # `ImportFrom` that ends up without any names, the formatting will not # work! It will only remove the empty `import from` but not other unused # imports. - non_empty_imports: list[ast.Import | ast.ImportFrom] = [] + non_empty_imports: List[Union[ast.Import, ast.ImportFrom]] = [] last_import_at = 0 for i, node in enumerate(module.body): if isinstance(node, ast.Import): From 1013f892ce3d207ce50b0893bfa4b8985fd8f326 Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Fri, 15 Mar 2024 17:14:10 +0100 Subject: [PATCH 04/11] Re-export plugin in --- ariadne_codegen/contrib/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ariadne_codegen/contrib/__init__.py b/ariadne_codegen/contrib/__init__.py index 63e5565f..437689ac 100644 --- a/ariadne_codegen/contrib/__init__.py +++ b/ariadne_codegen/contrib/__init__.py @@ -1,5 +1,11 @@ from .extract_operations import ExtractOperationsPlugin +from .no_global_imports import NoGlobalImportsPlugin from .no_reimports import NoReimportsPlugin from .shorter_results import ShorterResultsPlugin -__all__ = ["ExtractOperationsPlugin", "NoReimportsPlugin", "ShorterResultsPlugin"] +__all__ = [ + "ExtractOperationsPlugin", + "NoReimportsPlugin", + "ShorterResultsPlugin", + "NoGlobalImportsPlugin", +] From 777b3a51fd5575de5a82f1501e93ea5d2e98e48e Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Mon, 25 Mar 2024 10:00:08 +0100 Subject: [PATCH 05/11] Update README.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Rafał Pitoń --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2c570295..09abdb50 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ Ariadne Codegen ships with optional plugins importable from the `ariadne_codegen - [`ariadne_codegen.contrib.extract_operations.ExtractOperationsPlugin`](ariadne_codegen/contrib/extract_operations.py) - This extracts query strings from generated client's methods into separate `operations.py` module. It also modifies the generated client to import these definitions. Generated module name can be customized by adding `operations_module_name="custom_name"` to the `[tool.ariadne-codegen.operations]` section in config. Eg.: -- [`ariadne_codegen.contrib.no_global_imports.NoGlobalImportsPlugin`](ariadne_codegen/contrib/no_global_imports.py) - This plugin processes generated client module and convert all input arguments and return types to strings. The types will be imported only for type checking. +- [`ariadne_codegen.contrib.no_global_imports.NoGlobalImportsPlugin`](ariadne_codegen/contrib/no_global_imports.py) - This plugin changes generated client module moving all Pydantic models imports under the `TYPE_CHECKING` condition, making them forward references. This greatly improves the import performance for `client` module. ```toml [tool.ariadne-codegen] From 2325e45930a13f2e338aa356e4fd616a69f7fd2c Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Mon, 25 Mar 2024 10:12:59 +0100 Subject: [PATCH 06/11] Rename plugin to `ClientForwardRefsPlugin` --- CHANGELOG.md | 2 +- README.md | 2 +- ariadne_codegen/contrib/__init__.py | 4 ++-- .../{no_global_imports.py => client_forward_refs.py} | 2 +- .../custom_scalars.py | 0 .../expected_client/__init__.py | 12 ++++++------ .../expected_client/async_base_client.py | 0 .../expected_client/base_model.py | 0 .../expected_client/client.py | 0 .../client_forward_refs_fragments.py} | 0 .../expected_client/custom_scalars.py | 0 .../expected_client/enums.py | 0 .../expected_client/exceptions.py | 0 .../expected_client/get_animal_by_name.py | 0 .../get_animal_fragment_with_extra.py | 2 +- .../expected_client/get_authenticated_user.py | 0 .../expected_client/get_complex_scalar.py | 0 .../expected_client/get_simple_scalar.py | 0 .../expected_client/input_types.py | 0 .../expected_client/list_animals.py | 0 .../expected_client/list_strings_1.py | 0 .../expected_client/list_strings_2.py | 0 .../expected_client/list_strings_3.py | 0 .../expected_client/list_strings_4.py | 0 .../expected_client/list_type_a.py | 0 .../expected_client/subscribe_strings.py | 0 .../expected_client/unwrap_fragment.py | 5 +++++ .../pyproject.toml | 6 +++--- .../queries.graphql | 0 .../schema.graphql | 0 .../custom_scalars.py | 0 .../expected_client/__init__.py | 12 ++++++------ .../expected_client/async_base_client.py | 0 .../expected_client/base_model.py | 0 .../expected_client/client.py | 6 +++--- .../client_forward_refs_shorter_resultsfragments.py} | 0 .../expected_client/custom_scalars.py | 0 .../expected_client/enums.py | 0 .../expected_client/exceptions.py | 0 .../expected_client/get_animal_by_name.py | 0 .../get_animal_fragment_with_extra.py | 2 +- .../expected_client/get_authenticated_user.py | 0 .../expected_client/get_complex_scalar.py | 0 .../expected_client/get_simple_scalar.py | 0 .../expected_client/input_types.py | 0 .../expected_client/list_animals.py | 0 .../expected_client/list_strings_1.py | 0 .../expected_client/list_strings_2.py | 0 .../expected_client/list_strings_3.py | 0 .../expected_client/list_strings_4.py | 0 .../expected_client/list_type_a.py | 0 .../expected_client/subscribe_strings.py | 0 .../expected_client/unwrap_fragment.py | 5 +++++ .../pyproject.toml | 6 +++--- .../queries.graphql | 0 .../schema.graphql | 0 .../expected_client/unwrap_fragment.py | 5 ----- .../expected_client/unwrap_fragment.py | 5 ----- 58 files changed, 38 insertions(+), 38 deletions(-) rename ariadne_codegen/contrib/{no_global_imports.py => client_forward_refs.py} (99%) rename tests/main/clients/{no_global_imports => client_forward_refs}/custom_scalars.py (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/__init__.py (98%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/async_base_client.py (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/base_model.py (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/client.py (100%) rename tests/main/clients/{no_global_imports/expected_client/no_global_imports_fragments.py => client_forward_refs/expected_client/client_forward_refs_fragments.py} (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/custom_scalars.py (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/enums.py (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/exceptions.py (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/get_animal_by_name.py (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/get_animal_fragment_with_extra.py (72%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/get_authenticated_user.py (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/get_complex_scalar.py (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/get_simple_scalar.py (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/input_types.py (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/list_animals.py (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/list_strings_1.py (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/list_strings_2.py (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/list_strings_3.py (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/list_strings_4.py (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/list_type_a.py (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/expected_client/subscribe_strings.py (100%) create mode 100644 tests/main/clients/client_forward_refs/expected_client/unwrap_fragment.py rename tests/main/clients/{no_global_imports => client_forward_refs}/pyproject.toml (69%) rename tests/main/clients/{no_global_imports => client_forward_refs}/queries.graphql (100%) rename tests/main/clients/{no_global_imports => client_forward_refs}/schema.graphql (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/custom_scalars.py (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/__init__.py (97%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/async_base_client.py (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/base_model.py (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/client.py (99%) rename tests/main/clients/{no_global_imports_shorter_results/expected_client/no_global_imports_shorter_resultsfragments.py => client_forward_refs_shorter_results/expected_client/client_forward_refs_shorter_resultsfragments.py} (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/custom_scalars.py (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/enums.py (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/exceptions.py (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/get_animal_by_name.py (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/get_animal_fragment_with_extra.py (67%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/get_authenticated_user.py (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/get_complex_scalar.py (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/get_simple_scalar.py (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/input_types.py (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/list_animals.py (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/list_strings_1.py (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/list_strings_2.py (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/list_strings_3.py (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/list_strings_4.py (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/list_type_a.py (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/expected_client/subscribe_strings.py (100%) create mode 100644 tests/main/clients/client_forward_refs_shorter_results/expected_client/unwrap_fragment.py rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/pyproject.toml (71%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/queries.graphql (100%) rename tests/main/clients/{no_global_imports_shorter_results => client_forward_refs_shorter_results}/schema.graphql (100%) delete mode 100644 tests/main/clients/no_global_imports/expected_client/unwrap_fragment.py delete mode 100644 tests/main/clients/no_global_imports_shorter_results/expected_client/unwrap_fragment.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b20daa1..c5ed968d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ ## 0.14.0 (Unreleased) -- Added `NoGlobalImportsPlugin` to standard plugins. +- Added `ClientForwardRefsPlugin` to standard plugins. - Re-added `model_rebuild` calls for input types with forward references. diff --git a/README.md b/README.md index 09abdb50..79969b90 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ Ariadne Codegen ships with optional plugins importable from the `ariadne_codegen - [`ariadne_codegen.contrib.extract_operations.ExtractOperationsPlugin`](ariadne_codegen/contrib/extract_operations.py) - This extracts query strings from generated client's methods into separate `operations.py` module. It also modifies the generated client to import these definitions. Generated module name can be customized by adding `operations_module_name="custom_name"` to the `[tool.ariadne-codegen.operations]` section in config. Eg.: -- [`ariadne_codegen.contrib.no_global_imports.NoGlobalImportsPlugin`](ariadne_codegen/contrib/no_global_imports.py) - This plugin changes generated client module moving all Pydantic models imports under the `TYPE_CHECKING` condition, making them forward references. This greatly improves the import performance for `client` module. +- [`ariadne_codegen.contrib.client_forward_refs.ClientForwardRefsPlugin`](ariadne_codegen/contrib/client_forward_refs.py) - This plugin changes generated client module moving all Pydantic models imports under the `TYPE_CHECKING` condition, making them forward references. This greatly improves the import performance for `client` module. ```toml [tool.ariadne-codegen] diff --git a/ariadne_codegen/contrib/__init__.py b/ariadne_codegen/contrib/__init__.py index 437689ac..6b3d6b2a 100644 --- a/ariadne_codegen/contrib/__init__.py +++ b/ariadne_codegen/contrib/__init__.py @@ -1,5 +1,5 @@ +from .client_forward_refs import ClientForwardRefsPlugin from .extract_operations import ExtractOperationsPlugin -from .no_global_imports import NoGlobalImportsPlugin from .no_reimports import NoReimportsPlugin from .shorter_results import ShorterResultsPlugin @@ -7,5 +7,5 @@ "ExtractOperationsPlugin", "NoReimportsPlugin", "ShorterResultsPlugin", - "NoGlobalImportsPlugin", + "ClientForwardRefsPlugin", ] diff --git a/ariadne_codegen/contrib/no_global_imports.py b/ariadne_codegen/contrib/client_forward_refs.py similarity index 99% rename from ariadne_codegen/contrib/no_global_imports.py rename to ariadne_codegen/contrib/client_forward_refs.py index b4281161..f7855f0b 100644 --- a/ariadne_codegen/contrib/no_global_imports.py +++ b/ariadne_codegen/contrib/client_forward_refs.py @@ -19,7 +19,7 @@ TYPE_CHECKING_FLAG: str = "TYPE_CHECKING" -class NoGlobalImportsPlugin(Plugin): +class ClientForwardRefsPlugin(Plugin): """Only import types when you call an endpoint needing it""" def __init__(self, schema: GraphQLSchema, config_dict: Dict) -> None: diff --git a/tests/main/clients/no_global_imports/custom_scalars.py b/tests/main/clients/client_forward_refs/custom_scalars.py similarity index 100% rename from tests/main/clients/no_global_imports/custom_scalars.py rename to tests/main/clients/client_forward_refs/custom_scalars.py diff --git a/tests/main/clients/no_global_imports/expected_client/__init__.py b/tests/main/clients/client_forward_refs/expected_client/__init__.py similarity index 98% rename from tests/main/clients/no_global_imports/expected_client/__init__.py rename to tests/main/clients/client_forward_refs/expected_client/__init__.py index 41b2df92..d518989d 100644 --- a/tests/main/clients/no_global_imports/expected_client/__init__.py +++ b/tests/main/clients/client_forward_refs/expected_client/__init__.py @@ -1,6 +1,12 @@ from .async_base_client import AsyncBaseClient from .base_model import BaseModel, Upload from .client import Client +from .client_forward_refs_fragments import ( + FragmentWithSingleField, + FragmentWithSingleFieldQueryUnwrapFragment, + ListAnimalsFragment, + ListAnimalsFragmentListAnimals, +) from .exceptions import ( GraphQLClientError, GraphQLClientGraphQLError, @@ -29,12 +35,6 @@ from .list_strings_3 import ListStrings3 from .list_strings_4 import ListStrings4 from .list_type_a import ListTypeA, ListTypeAListOptionalTypeA -from .no_global_imports_fragments import ( - FragmentWithSingleField, - FragmentWithSingleFieldQueryUnwrapFragment, - ListAnimalsFragment, - ListAnimalsFragmentListAnimals, -) from .subscribe_strings import SubscribeStrings from .unwrap_fragment import UnwrapFragment diff --git a/tests/main/clients/no_global_imports/expected_client/async_base_client.py b/tests/main/clients/client_forward_refs/expected_client/async_base_client.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/async_base_client.py rename to tests/main/clients/client_forward_refs/expected_client/async_base_client.py diff --git a/tests/main/clients/no_global_imports/expected_client/base_model.py b/tests/main/clients/client_forward_refs/expected_client/base_model.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/base_model.py rename to tests/main/clients/client_forward_refs/expected_client/base_model.py diff --git a/tests/main/clients/no_global_imports/expected_client/client.py b/tests/main/clients/client_forward_refs/expected_client/client.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/client.py rename to tests/main/clients/client_forward_refs/expected_client/client.py diff --git a/tests/main/clients/no_global_imports/expected_client/no_global_imports_fragments.py b/tests/main/clients/client_forward_refs/expected_client/client_forward_refs_fragments.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/no_global_imports_fragments.py rename to tests/main/clients/client_forward_refs/expected_client/client_forward_refs_fragments.py diff --git a/tests/main/clients/no_global_imports/expected_client/custom_scalars.py b/tests/main/clients/client_forward_refs/expected_client/custom_scalars.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/custom_scalars.py rename to tests/main/clients/client_forward_refs/expected_client/custom_scalars.py diff --git a/tests/main/clients/no_global_imports/expected_client/enums.py b/tests/main/clients/client_forward_refs/expected_client/enums.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/enums.py rename to tests/main/clients/client_forward_refs/expected_client/enums.py diff --git a/tests/main/clients/no_global_imports/expected_client/exceptions.py b/tests/main/clients/client_forward_refs/expected_client/exceptions.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/exceptions.py rename to tests/main/clients/client_forward_refs/expected_client/exceptions.py diff --git a/tests/main/clients/no_global_imports/expected_client/get_animal_by_name.py b/tests/main/clients/client_forward_refs/expected_client/get_animal_by_name.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/get_animal_by_name.py rename to tests/main/clients/client_forward_refs/expected_client/get_animal_by_name.py diff --git a/tests/main/clients/no_global_imports/expected_client/get_animal_fragment_with_extra.py b/tests/main/clients/client_forward_refs/expected_client/get_animal_fragment_with_extra.py similarity index 72% rename from tests/main/clients/no_global_imports/expected_client/get_animal_fragment_with_extra.py rename to tests/main/clients/client_forward_refs/expected_client/get_animal_fragment_with_extra.py index d7d15914..bf363706 100644 --- a/tests/main/clients/no_global_imports/expected_client/get_animal_fragment_with_extra.py +++ b/tests/main/clients/client_forward_refs/expected_client/get_animal_fragment_with_extra.py @@ -2,7 +2,7 @@ from pydantic import Field -from .no_global_imports_fragments import ListAnimalsFragment +from .client_forward_refs_fragments import ListAnimalsFragment class GetAnimalFragmentWithExtra(ListAnimalsFragment): diff --git a/tests/main/clients/no_global_imports/expected_client/get_authenticated_user.py b/tests/main/clients/client_forward_refs/expected_client/get_authenticated_user.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/get_authenticated_user.py rename to tests/main/clients/client_forward_refs/expected_client/get_authenticated_user.py diff --git a/tests/main/clients/no_global_imports/expected_client/get_complex_scalar.py b/tests/main/clients/client_forward_refs/expected_client/get_complex_scalar.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/get_complex_scalar.py rename to tests/main/clients/client_forward_refs/expected_client/get_complex_scalar.py diff --git a/tests/main/clients/no_global_imports/expected_client/get_simple_scalar.py b/tests/main/clients/client_forward_refs/expected_client/get_simple_scalar.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/get_simple_scalar.py rename to tests/main/clients/client_forward_refs/expected_client/get_simple_scalar.py diff --git a/tests/main/clients/no_global_imports/expected_client/input_types.py b/tests/main/clients/client_forward_refs/expected_client/input_types.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/input_types.py rename to tests/main/clients/client_forward_refs/expected_client/input_types.py diff --git a/tests/main/clients/no_global_imports/expected_client/list_animals.py b/tests/main/clients/client_forward_refs/expected_client/list_animals.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/list_animals.py rename to tests/main/clients/client_forward_refs/expected_client/list_animals.py diff --git a/tests/main/clients/no_global_imports/expected_client/list_strings_1.py b/tests/main/clients/client_forward_refs/expected_client/list_strings_1.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/list_strings_1.py rename to tests/main/clients/client_forward_refs/expected_client/list_strings_1.py diff --git a/tests/main/clients/no_global_imports/expected_client/list_strings_2.py b/tests/main/clients/client_forward_refs/expected_client/list_strings_2.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/list_strings_2.py rename to tests/main/clients/client_forward_refs/expected_client/list_strings_2.py diff --git a/tests/main/clients/no_global_imports/expected_client/list_strings_3.py b/tests/main/clients/client_forward_refs/expected_client/list_strings_3.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/list_strings_3.py rename to tests/main/clients/client_forward_refs/expected_client/list_strings_3.py diff --git a/tests/main/clients/no_global_imports/expected_client/list_strings_4.py b/tests/main/clients/client_forward_refs/expected_client/list_strings_4.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/list_strings_4.py rename to tests/main/clients/client_forward_refs/expected_client/list_strings_4.py diff --git a/tests/main/clients/no_global_imports/expected_client/list_type_a.py b/tests/main/clients/client_forward_refs/expected_client/list_type_a.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/list_type_a.py rename to tests/main/clients/client_forward_refs/expected_client/list_type_a.py diff --git a/tests/main/clients/no_global_imports/expected_client/subscribe_strings.py b/tests/main/clients/client_forward_refs/expected_client/subscribe_strings.py similarity index 100% rename from tests/main/clients/no_global_imports/expected_client/subscribe_strings.py rename to tests/main/clients/client_forward_refs/expected_client/subscribe_strings.py diff --git a/tests/main/clients/client_forward_refs/expected_client/unwrap_fragment.py b/tests/main/clients/client_forward_refs/expected_client/unwrap_fragment.py new file mode 100644 index 00000000..83423282 --- /dev/null +++ b/tests/main/clients/client_forward_refs/expected_client/unwrap_fragment.py @@ -0,0 +1,5 @@ +from .client_forward_refs_fragments import FragmentWithSingleField + + +class UnwrapFragment(FragmentWithSingleField): + pass diff --git a/tests/main/clients/no_global_imports/pyproject.toml b/tests/main/clients/client_forward_refs/pyproject.toml similarity index 69% rename from tests/main/clients/no_global_imports/pyproject.toml rename to tests/main/clients/client_forward_refs/pyproject.toml index f3dc5453..6f4570ba 100644 --- a/tests/main/clients/no_global_imports/pyproject.toml +++ b/tests/main/clients/client_forward_refs/pyproject.toml @@ -2,10 +2,10 @@ schema_path = "schema.graphql" queries_path = "queries.graphql" include_comments = "none" -target_package_name = "no_global_imports" +target_package_name = "client_forward_refs" files_to_include = ["custom_scalars.py"] -fragments_module_name = "no_global_imports_fragments" -plugins = ["ariadne_codegen.contrib.no_global_imports.NoGlobalImportsPlugin"] +fragments_module_name = "client_forward_refs_fragments" +plugins = ["ariadne_codegen.contrib.client_forward_refs.ClientForwardRefsPlugin"] [tool.ariadne-codegen.scalars.SimpleScalar] type = ".custom_scalars.SimpleScalar" diff --git a/tests/main/clients/no_global_imports/queries.graphql b/tests/main/clients/client_forward_refs/queries.graphql similarity index 100% rename from tests/main/clients/no_global_imports/queries.graphql rename to tests/main/clients/client_forward_refs/queries.graphql diff --git a/tests/main/clients/no_global_imports/schema.graphql b/tests/main/clients/client_forward_refs/schema.graphql similarity index 100% rename from tests/main/clients/no_global_imports/schema.graphql rename to tests/main/clients/client_forward_refs/schema.graphql diff --git a/tests/main/clients/no_global_imports_shorter_results/custom_scalars.py b/tests/main/clients/client_forward_refs_shorter_results/custom_scalars.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/custom_scalars.py rename to tests/main/clients/client_forward_refs_shorter_results/custom_scalars.py diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/__init__.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/__init__.py similarity index 97% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/__init__.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/__init__.py index 6c15fb9c..51ff9a6c 100644 --- a/tests/main/clients/no_global_imports_shorter_results/expected_client/__init__.py +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/__init__.py @@ -1,6 +1,12 @@ from .async_base_client import AsyncBaseClient from .base_model import BaseModel, Upload from .client import Client +from .client_forward_refs_shorter_resultsfragments import ( + FragmentWithSingleField, + FragmentWithSingleFieldQueryUnwrapFragment, + ListAnimalsFragment, + ListAnimalsFragmentListAnimals, +) from .exceptions import ( GraphQLClientError, GraphQLClientGraphQLError, @@ -29,12 +35,6 @@ from .list_strings_3 import ListStrings3 from .list_strings_4 import ListStrings4 from .list_type_a import ListTypeA, ListTypeAListOptionalTypeA -from .no_global_imports_shorter_resultsfragments import ( - FragmentWithSingleField, - FragmentWithSingleFieldQueryUnwrapFragment, - ListAnimalsFragment, - ListAnimalsFragmentListAnimals, -) from .subscribe_strings import SubscribeStrings from .unwrap_fragment import UnwrapFragment diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/async_base_client.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/async_base_client.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/async_base_client.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/async_base_client.py diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/base_model.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/base_model.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/base_model.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/base_model.py diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/client.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/client.py similarity index 99% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/client.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/client.py index 655dbb50..3b388490 100644 --- a/tests/main/clients/no_global_imports_shorter_results/expected_client/client.py +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/client.py @@ -3,6 +3,9 @@ from .async_base_client import AsyncBaseClient if TYPE_CHECKING: + from .client_forward_refs_shorter_resultsfragments import ( + FragmentWithSingleFieldQueryUnwrapFragment, + ) from .custom_scalars import ComplexScalar, SimpleScalar from .get_animal_by_name import ( GetAnimalByNameAnimalByNameAnimal, @@ -17,9 +20,6 @@ ListAnimalsListAnimalsDog, ) from .list_type_a import ListTypeAListOptionalTypeA - from .no_global_imports_shorter_resultsfragments import ( - FragmentWithSingleFieldQueryUnwrapFragment, - ) def gql(q: str) -> str: diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/no_global_imports_shorter_resultsfragments.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/client_forward_refs_shorter_resultsfragments.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/no_global_imports_shorter_resultsfragments.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/client_forward_refs_shorter_resultsfragments.py diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/custom_scalars.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/custom_scalars.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/custom_scalars.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/custom_scalars.py diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/enums.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/enums.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/enums.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/enums.py diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/exceptions.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/exceptions.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/exceptions.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/exceptions.py diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/get_animal_by_name.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_animal_by_name.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/get_animal_by_name.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/get_animal_by_name.py diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/get_animal_fragment_with_extra.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_animal_fragment_with_extra.py similarity index 67% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/get_animal_fragment_with_extra.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/get_animal_fragment_with_extra.py index 90ffbdd1..2c56da12 100644 --- a/tests/main/clients/no_global_imports_shorter_results/expected_client/get_animal_fragment_with_extra.py +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_animal_fragment_with_extra.py @@ -2,7 +2,7 @@ from pydantic import Field -from .no_global_imports_shorter_resultsfragments import ListAnimalsFragment +from .client_forward_refs_shorter_resultsfragments import ListAnimalsFragment class GetAnimalFragmentWithExtra(ListAnimalsFragment): diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/get_authenticated_user.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_authenticated_user.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/get_authenticated_user.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/get_authenticated_user.py diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/get_complex_scalar.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_complex_scalar.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/get_complex_scalar.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/get_complex_scalar.py diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/get_simple_scalar.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/get_simple_scalar.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/get_simple_scalar.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/get_simple_scalar.py diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/input_types.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/input_types.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/input_types.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/input_types.py diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/list_animals.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_animals.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/list_animals.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/list_animals.py diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_1.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_1.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_1.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_1.py diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_2.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_2.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_2.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_2.py diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_3.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_3.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_3.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_3.py diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_4.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_4.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/list_strings_4.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/list_strings_4.py diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/list_type_a.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/list_type_a.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/list_type_a.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/list_type_a.py diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/subscribe_strings.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/subscribe_strings.py similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/expected_client/subscribe_strings.py rename to tests/main/clients/client_forward_refs_shorter_results/expected_client/subscribe_strings.py diff --git a/tests/main/clients/client_forward_refs_shorter_results/expected_client/unwrap_fragment.py b/tests/main/clients/client_forward_refs_shorter_results/expected_client/unwrap_fragment.py new file mode 100644 index 00000000..e3512512 --- /dev/null +++ b/tests/main/clients/client_forward_refs_shorter_results/expected_client/unwrap_fragment.py @@ -0,0 +1,5 @@ +from .client_forward_refs_shorter_resultsfragments import FragmentWithSingleField + + +class UnwrapFragment(FragmentWithSingleField): + pass diff --git a/tests/main/clients/no_global_imports_shorter_results/pyproject.toml b/tests/main/clients/client_forward_refs_shorter_results/pyproject.toml similarity index 71% rename from tests/main/clients/no_global_imports_shorter_results/pyproject.toml rename to tests/main/clients/client_forward_refs_shorter_results/pyproject.toml index 9379191a..1bc91372 100644 --- a/tests/main/clients/no_global_imports_shorter_results/pyproject.toml +++ b/tests/main/clients/client_forward_refs_shorter_results/pyproject.toml @@ -2,12 +2,12 @@ schema_path = "schema.graphql" queries_path = "queries.graphql" include_comments = "none" -target_package_name = "no_global_imports_shorter_results" +target_package_name = "client_forward_refs_shorter_results" files_to_include = ["custom_scalars.py"] -fragments_module_name = "no_global_imports_shorter_resultsfragments" +fragments_module_name = "client_forward_refs_shorter_resultsfragments" plugins = [ "ariadne_codegen.contrib.shorter_results.ShorterResultsPlugin", - "ariadne_codegen.contrib.no_global_imports.NoGlobalImportsPlugin" + "ariadne_codegen.contrib.client_forward_refs.ClientForwardRefsPlugin" ] [tool.ariadne-codegen.scalars.SimpleScalar] diff --git a/tests/main/clients/no_global_imports_shorter_results/queries.graphql b/tests/main/clients/client_forward_refs_shorter_results/queries.graphql similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/queries.graphql rename to tests/main/clients/client_forward_refs_shorter_results/queries.graphql diff --git a/tests/main/clients/no_global_imports_shorter_results/schema.graphql b/tests/main/clients/client_forward_refs_shorter_results/schema.graphql similarity index 100% rename from tests/main/clients/no_global_imports_shorter_results/schema.graphql rename to tests/main/clients/client_forward_refs_shorter_results/schema.graphql diff --git a/tests/main/clients/no_global_imports/expected_client/unwrap_fragment.py b/tests/main/clients/no_global_imports/expected_client/unwrap_fragment.py deleted file mode 100644 index ad17ac32..00000000 --- a/tests/main/clients/no_global_imports/expected_client/unwrap_fragment.py +++ /dev/null @@ -1,5 +0,0 @@ -from .no_global_imports_fragments import FragmentWithSingleField - - -class UnwrapFragment(FragmentWithSingleField): - pass diff --git a/tests/main/clients/no_global_imports_shorter_results/expected_client/unwrap_fragment.py b/tests/main/clients/no_global_imports_shorter_results/expected_client/unwrap_fragment.py deleted file mode 100644 index 8189e91f..00000000 --- a/tests/main/clients/no_global_imports_shorter_results/expected_client/unwrap_fragment.py +++ /dev/null @@ -1,5 +0,0 @@ -from .no_global_imports_shorter_resultsfragments import FragmentWithSingleField - - -class UnwrapFragment(FragmentWithSingleField): - pass From 51c6a428ae418ddd1361edea06e9be8afeca7b15 Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Mon, 25 Mar 2024 21:57:19 +0100 Subject: [PATCH 07/11] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Rafał Pitoń --- README.md | 2 +- ariadne_codegen/contrib/client_forward_refs.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 79969b90..c7d4a916 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ Ariadne Codegen ships with optional plugins importable from the `ariadne_codegen - [`ariadne_codegen.contrib.extract_operations.ExtractOperationsPlugin`](ariadne_codegen/contrib/extract_operations.py) - This extracts query strings from generated client's methods into separate `operations.py` module. It also modifies the generated client to import these definitions. Generated module name can be customized by adding `operations_module_name="custom_name"` to the `[tool.ariadne-codegen.operations]` section in config. Eg.: -- [`ariadne_codegen.contrib.client_forward_refs.ClientForwardRefsPlugin`](ariadne_codegen/contrib/client_forward_refs.py) - This plugin changes generated client module moving all Pydantic models imports under the `TYPE_CHECKING` condition, making them forward references. This greatly improves the import performance for `client` module. +- [`ariadne_codegen.contrib.client_forward_refs.ClientForwardRefsPlugin`](ariadne_codegen/contrib/client_forward_refs.py) - This plugin changes generated client module moving all Pydantic models imports under the `TYPE_CHECKING` condition, making them forward references. This greatly improves the import performance of the `client` module. ```toml [tool.ariadne-codegen] diff --git a/ariadne_codegen/contrib/client_forward_refs.py b/ariadne_codegen/contrib/client_forward_refs.py index f7855f0b..c748a122 100644 --- a/ariadne_codegen/contrib/client_forward_refs.py +++ b/ariadne_codegen/contrib/client_forward_refs.py @@ -1,5 +1,15 @@ """ -Plugin to only import types for GraphQL responses when you call methods. +Plugin that delays imports of Pydantic models in client module. + +Puts all imports under the `typing.TYPE_CHECKING` flag, making +type annotations for generated client's methods forward references. + +This greatly improves import time of generated `client` module when +there are many Pydantic models. + +Because generated client's methods need type definitions for models +they are using, those models imports will be also inserted in their +bodies. This will massively reduce import times for larger projects since you only have to load the input types when loading the client. From 967e5247af03bb681e4e1e7cf21c835278d434cb Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Mon, 25 Mar 2024 21:58:14 +0100 Subject: [PATCH 08/11] Sort exports --- ariadne_codegen/contrib/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ariadne_codegen/contrib/__init__.py b/ariadne_codegen/contrib/__init__.py index 6b3d6b2a..c9e56fec 100644 --- a/ariadne_codegen/contrib/__init__.py +++ b/ariadne_codegen/contrib/__init__.py @@ -4,8 +4,8 @@ from .shorter_results import ShorterResultsPlugin __all__ = [ + "ClientForwardRefsPlugin", "ExtractOperationsPlugin", "NoReimportsPlugin", "ShorterResultsPlugin", - "ClientForwardRefsPlugin", ] From e6da1280aaa653d21f1a5459f1cac7d3355c7d59 Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Tue, 2 Apr 2024 13:14:06 +0200 Subject: [PATCH 09/11] Update ariadne_codegen/contrib/client_forward_refs.py Co-authored-by: DamianCzajkowski <43958031+DamianCzajkowski@users.noreply.github.com> --- ariadne_codegen/contrib/client_forward_refs.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ariadne_codegen/contrib/client_forward_refs.py b/ariadne_codegen/contrib/client_forward_refs.py index c748a122..522ef7d4 100644 --- a/ariadne_codegen/contrib/client_forward_refs.py +++ b/ariadne_codegen/contrib/client_forward_refs.py @@ -279,9 +279,7 @@ def _update_imports(self, module: ast.Module): # unless that's the type we return. This behaviour can differ if you use # a plugin such as `ShorterResultsPlugin` that will import a type that # is different from the type returned. - return_types_not_used_as_input.update( - {k for k in self.imported_in_method if k not in self.input_and_return_types} - ) + return_types_not_used_as_input |= self.imported_in_method - self.input_and_return_types if len(return_types_not_used_as_input) == 0: return None From 57916193f005256e9770da5b65c3423c5dc6ac51 Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Tue, 2 Apr 2024 16:56:37 +0200 Subject: [PATCH 10/11] Re-format after PR suggestion --- ariadne_codegen/contrib/client_forward_refs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ariadne_codegen/contrib/client_forward_refs.py b/ariadne_codegen/contrib/client_forward_refs.py index 522ef7d4..22780997 100644 --- a/ariadne_codegen/contrib/client_forward_refs.py +++ b/ariadne_codegen/contrib/client_forward_refs.py @@ -279,7 +279,9 @@ def _update_imports(self, module: ast.Module): # unless that's the type we return. This behaviour can differ if you use # a plugin such as `ShorterResultsPlugin` that will import a type that # is different from the type returned. - return_types_not_used_as_input |= self.imported_in_method - self.input_and_return_types + return_types_not_used_as_input |= ( + self.imported_in_method - self.input_and_return_types + ) if len(return_types_not_used_as_input) == 0: return None From 4e068458226267917fe649ee2f4a627aef9b7ae0 Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Tue, 2 Apr 2024 17:07:20 +0200 Subject: [PATCH 11/11] Split `_update_imports` --- .../contrib/client_forward_refs.py | 52 +++++++++++++------ 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/ariadne_codegen/contrib/client_forward_refs.py b/ariadne_codegen/contrib/client_forward_refs.py index 22780997..d3ac9833 100644 --- a/ariadne_codegen/contrib/client_forward_refs.py +++ b/ariadne_codegen/contrib/client_forward_refs.py @@ -251,18 +251,13 @@ def _get_class_from_call(self, call: ast.Call) -> Optional[ast.Name]: return call.func.value - def _update_imports(self, module: ast.Module): + def _update_imports(self, module: ast.Module) -> None: """Update all imports. Iterate over all imports and remove the aliases that we use as input or return value. These will be moved and added to an `if TYPE_CHECKING` block. - **NOTE** If an `ast.ImportFrom` ends up without any names we must remove - it completely otherwise formatting will not work (it would remove the - empty `import from` but not format the rest of the code without running - it twice). - We do this by storing all imports that we want to keep in an array, we then drop all from the body and re-insert the ones to keep. Lastly we import `TYPE_CHECKING` and add all our imports in the `if TYPE_CHECKING` @@ -286,12 +281,29 @@ def _update_imports(self, module: ast.Module): if len(return_types_not_used_as_input) == 0: return None - # We sadly have to iterate over all imports again and remove the imports - # we will do conditionally. - # It's very important that we get this right, if we keep any - # `ImportFrom` that ends up without any names, the formatting will not - # work! It will only remove the empty `import from` but not other unused - # imports. + non_empty_imports = self._update_existing_imports( + module, return_types_not_used_as_input + ) + self._add_forward_ref_imports(module, non_empty_imports) + + return None + + def _update_existing_imports( + self, module: ast.Module, return_types_not_used_as_input: set[str] + ) -> List[Union[ast.Import, ast.ImportFrom]]: + """Update existing imports. + + Remove all import or import from statements that would otherwise be + useless after moving them to forward refs. + + It's very important that we get this right, if we keep any `ImportFrom` + that ends up without any names, the formatting will not work! It will + only remove the empty `import from` but not other unused imports. + + :param module: The ast module to update + :param return_types_not_used_as_input: Set of return types not used as + input + """ non_empty_imports: List[Union[ast.Import, ast.ImportFrom]] = [] last_import_at = 0 for i, node in enumerate(module.body): @@ -316,8 +328,18 @@ def _update_imports(self, module: ast.Module): # We can now remove all imports and re-insert the ones that's not empty. module.body = non_empty_imports + module.body[last_import_at + 1 :] - # Create import to use for type checking. These will be put in an `if - # TYPE_CHECKING` block. + return non_empty_imports + + def _add_forward_ref_imports( + self, + module: ast.Module, + non_empty_imports: List[Union[ast.Import, ast.ImportFrom]], + ) -> None: + """Add forward ref imports. + + Add all the forward ref imports meaning all the types needed for type + checking under the `if TYPE_CHECKING` condition. + """ type_checking_imports = {} for cls in self.input_and_return_types: module_name = self.imported_classes[cls] @@ -345,8 +367,6 @@ def _update_imports(self, module: ast.Module): ), ) - return None - def _update_name_to_constant(self, node: ast.expr) -> ast.expr: """Update return types.