Skip to content

Commit

Permalink
Add custom operation generation (#296)
Browse files Browse the repository at this point in the history
Add custom operation generation
  • Loading branch information
DamianCzajkowski authored Jun 11, 2024
1 parent cb576f8 commit a82ecb8
Show file tree
Hide file tree
Showing 40 changed files with 3,948 additions and 14 deletions.
203 changes: 201 additions & 2 deletions ariadne_codegen/client_generators/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
generate_await,
generate_call,
generate_class_def,
generate_comp,
generate_constant,
generate_expr,
generate_import_from,
generate_keyword,
generate_list,
generate_list_comp,
generate_method_definition,
generate_module,
generate_name,
Expand All @@ -32,11 +35,20 @@
from .constants import (
ANY,
ASYNC_ITERATOR,
BASE_GRAPHQL_FIELD_CLASS_NAME,
BASE_OPERATION_FILE_PATH,
DICT,
DOCUMENT_NODE,
GRAPHQL_MODULE,
KWARGS_NAMES,
LIST,
MODEL_VALIDATE_METHOD,
NAME_NODE,
OPERATION_DEFINITION_NODE,
OPERATION_TYPE,
OPTIONAL,
PRINT_AST,
SELECTION_SET_NODE,
TYPING_MODULE,
UNION,
UNSET_IMPORT,
Expand Down Expand Up @@ -66,10 +78,18 @@ def __init__(
self.custom_scalars = custom_scalars if custom_scalars else {}
self.arguments_generator = arguments_generator

self._imports: List[ast.ImportFrom] = []
self._imports: List[Union[ast.ImportFrom, ast.Import]] = []
self._add_import(
generate_import_from(
[OPTIONAL, LIST, DICT, ANY, UNION, ASYNC_ITERATOR], TYPING_MODULE
[
OPTIONAL,
LIST,
DICT,
ANY,
UNION,
ASYNC_ITERATOR,
],
TYPING_MODULE,
)
)
self._add_import(base_client_import)
Expand Down Expand Up @@ -187,6 +207,185 @@ def add_method(
generate_import_from(names=[return_type], from_=return_type_module, level=1)
)

def add_execute_custom_operation_method(self):
self._add_import(
generate_import_from(
[
DOCUMENT_NODE,
OPERATION_DEFINITION_NODE,
NAME_NODE,
SELECTION_SET_NODE,
PRINT_AST,
],
GRAPHQL_MODULE,
)
)
self._add_import(
generate_import_from(
[BASE_GRAPHQL_FIELD_CLASS_NAME], BASE_OPERATION_FILE_PATH.stem, level=1
)
)
execute_await = generate_await(
value=generate_call(
func=generate_attribute(value=generate_name("self"), attr="execute"),
args=[
generate_call(
func=generate_name("print_ast"),
args=[generate_name("operation_ast")],
)
],
keywords=[
generate_keyword(
arg="operation_name", value=generate_name("operation_name")
)
],
)
)

operation_definition_node = generate_call(
func=generate_name("OperationDefinitionNode"),
keywords=[
generate_keyword(
arg="operation", value=generate_name("operation_type")
),
generate_keyword(
arg="name",
value=generate_call(
func=generate_name("NameNode"),
keywords=[
generate_keyword(
arg="value", value=generate_name("operation_name")
)
],
),
),
generate_keyword(
arg="selection_set",
value=generate_call(
func=generate_name("SelectionSetNode"),
keywords=[
generate_keyword(
arg="selections",
value=generate_list_comp(
elt=generate_call(
func=generate_attribute(
value=generate_name("field"),
attr="to_ast",
),
),
generators=[
generate_comp(
target="field",
iter_="fields",
)
],
),
)
],
),
),
],
)
operation_ast = generate_call(
func=generate_name("DocumentNode"),
keywords=[
generate_keyword(
arg="definitions",
value=generate_list(elements=[operation_definition_node]),
)
],
)
body_return = generate_return(
value=generate_call(
func=generate_attribute(value=generate_name("self"), attr="get_data"),
args=[generate_name("response")],
)
)
async_def_node = generate_async_method_definition(
name="execute_custom_operation",
arguments=generate_arguments(
args=[
generate_arg("self"),
generate_arg(
"*fields",
annotation=generate_name("GraphQLField"),
),
generate_arg(
"operation_type",
annotation=generate_name("OperationType"),
),
generate_arg("operation_name", annotation=generate_name("str")),
],
),
body=[
generate_assign(
targets=["operation_ast"],
value=operation_ast,
),
generate_assign(
targets=["response"],
value=execute_await,
),
body_return,
],
return_type=generate_subscript(
generate_name(DICT),
generate_tuple([generate_name("str"), generate_name("Any")]),
),
)
self._class_def.body.append(async_def_node)

def create_custom_operation_method(self, name, operation_type):
self._add_import(
generate_import_from(
[
OPERATION_TYPE,
],
GRAPHQL_MODULE,
)
)
body_return = generate_return(
value=generate_await(
value=generate_call(
func=generate_attribute(
value=generate_name("self"),
attr="execute_custom_operation",
),
args=[
generate_name("*fields"),
],
keywords=[
generate_keyword(
arg="operation_type",
value=generate_attribute(
value=generate_name("OperationType"),
attr=operation_type,
),
),
generate_keyword(
arg="operation_name", value=generate_name("operation_name")
),
],
)
)
)
async_def_query = generate_async_method_definition(
name=name,
arguments=generate_arguments(
args=[
generate_arg("self"),
generate_arg("*fields", annotation=generate_name("GraphQLField")),
generate_arg("operation_name", annotation=generate_name("str")),
],
),
body=[body_return],
return_type=generate_subscript(
generate_name(DICT),
generate_tuple([generate_name("str"), generate_name("Any")]),
),
)
self._class_def.body.append(async_def_query)

def get_variable_names(self, arguments: ast.arguments) -> Dict[str, str]:
mapped_variable_names = [
self._operation_str_variable,
Expand Down
20 changes: 20 additions & 0 deletions ariadne_codegen/client_generators/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,34 @@
LIST = "List"
UNION = "Union"
ANY = "Any"
TYPE = "Type"
TYPE_CHECKING = "TYPE_CHECKING"
DICT = "Dict"
CALLABLE = "Callable"
ANNOTATED = "Annotated"
LITERAL = "Literal"
ASYNC_ITERATOR = "AsyncIterator"
DOCUMENT_NODE = "DocumentNode"
OPERATION_DEFINITION_NODE = "OperationDefinitionNode"
NAME_NODE = "NameNode"
SELECTION_SET_NODE = "SelectionSetNode"
PRINT_AST = "print_ast"
OPERATION_TYPE = "OperationType"

HTTPX = "httpx"
HTTPX_RESPONSE = "httpx.Response"

TIMESTAMP_COMMENT = "# Generated by ariadne-codegen on {}"
STABLE_COMMENT = "# Generated by ariadne-codegen"
SOURCE_COMMENT = "# Source: {}"
COMMENT_DATETIME_FORMAT = "%Y-%m-%d %H:%M"

BASE_OPERATION_FILE_PATH = Path(__file__).parent / "dependencies" / "base_operation.py"
BASE_GRAPHQL_OPERATION_CLASS_NAME = "BaseGraphQLOperation"
BASE_GRAPHQL_FIELD_CLASS_NAME = "GraphQLField"
CUSTOM_FIELDS_FILE_PATH = Path(__file__).parent / "custom_fields.py"
CUSTOM_FIELDS_TYPING_FILE_PATH = Path(__file__).parent / "custom_typing_fields.py"

BASE_MODEL_FILE_PATH = Path(__file__).parent / "dependencies" / "base_model.py"
BASE_MODEL_CLASS_NAME = "BaseModel"
BASE_MODEL_IMPORT = ast.ImportFrom(
Expand All @@ -49,6 +66,7 @@
TYPENAME_ALIAS = "typename__"

TYPING_MODULE = "typing"
GRAPHQL_MODULE = "graphql"
PYDANTIC_MODULE = "pydantic"
FIELD_CLASS = "Field"
ALIAS_KEYWORD = "alias"
Expand Down Expand Up @@ -100,3 +118,5 @@

SCALARS_PARSE_DICT_NAME = "SCALARS_PARSE_FUNCTIONS"
SCALARS_SERIALIZE_DICT_NAME = "SCALARS_SERIALIZE_FUNCTIONS"

OPERATION_TYPES = ("Query", "Mutation", "Subscription")
Loading

0 comments on commit a82ecb8

Please sign in to comment.