Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

305 update ast usage in code for python 312 changes #306

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# CHANGELOG

## 0.14.1 (UNRELEASED)

- Changed code typing to satisfy MyPy 1.11.0 version


## 0.14.0 (2024-07-17)

- Added `ClientForwardRefsPlugin` to standard plugins.
Expand Down
2 changes: 1 addition & 1 deletion ariadne_codegen/client_generators/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ def _generate_variables_assign(
self, variable_names: Dict[str, str], arguments_dict: ast.Dict, lineno: int = 1
) -> ast.AnnAssign:
return generate_ann_assign(
target=variable_names[self._variables_dict_variable],
target=generate_name(variable_names[self._variables_dict_variable]),
annotation=generate_subscript(
generate_name(DICT),
generate_tuple([generate_name("str"), generate_name("object")]),
Expand Down
6 changes: 3 additions & 3 deletions ariadne_codegen/client_generators/custom_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def generate_clear_arguments_section(
) -> Tuple[List[ast.stmt], List[ast.keyword]]:
arguments_body = [
generate_ann_assign(
"arguments",
generate_name("arguments"),
generate_subscript(
generate_name(DICT),
generate_tuple(
Expand All @@ -240,8 +240,8 @@ def generate_clear_arguments_section(
),
),
generate_dict(
return_arguments_keys,
return_arguments_values, # type: ignore
return_arguments_keys, # type: ignore
return_arguments_values,
),
),
generate_assign(
Expand Down
2 changes: 1 addition & 1 deletion ariadne_codegen/client_generators/custom_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _generate_class_field(
name, field_name, getattr(field, "args")
)
return generate_ann_assign(
target=name,
target=generate_name(name),
annotation=generate_name(f'"{field_name}"'),
value=generate_call(
func=generate_name(field_name), args=[generate_constant(org_name)]
Expand Down
44 changes: 25 additions & 19 deletions ariadne_codegen/client_generators/input_fields.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import ast
from typing import Dict, Optional, Tuple
from typing import Dict, List, Optional, Tuple, cast

from graphql import (
BooleanValueNode,
Expand Down Expand Up @@ -142,15 +142,18 @@ def parse_input_const_value_node(

if isinstance(node, ListValueNode):
list_ = generate_list(
[
parse_input_const_value_node(
node=v,
field_type=field_type,
nested_object=nested_object,
nested_list=True,
)
for v in node.values
]
cast(
List[ast.expr],
[
parse_input_const_value_node(
node=v,
field_type=field_type,
nested_object=nested_object,
nested_list=True,
)
for v in node.values
],
)
)
if not nested_list:
return generate_call(
Expand All @@ -166,15 +169,18 @@ def parse_input_const_value_node(
if isinstance(node, ObjectValueNode):
dict_ = generate_dict(
keys=[generate_constant(f.name.value) for f in node.fields],
values=[
parse_input_const_value_node(
node=f.value,
field_type=field_type,
nested_object=True,
nested_list=True,
)
for f in node.fields
],
values=cast(
List[ast.expr],
[
parse_input_const_value_node(
node=f.value,
field_type=field_type,
nested_object=True,
nested_list=True,
)
for f in node.fields
],
),
)
if not nested_object:
return generate_call(
Expand Down
3 changes: 2 additions & 1 deletion ariadne_codegen/client_generators/input_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
generate_keyword,
generate_method_call,
generate_module,
generate_name,
generate_pydantic_field,
model_has_forward_refs,
)
Expand Down Expand Up @@ -172,7 +173,7 @@ def _parse_input_definition(
field.type, custom_scalars=self.custom_scalars
)
field_implementation = generate_ann_assign(
target=name,
target=generate_name(name),
annotation=annotation,
value=parse_input_field_default_value(
node=field.ast_node, annotation=annotation, field_type=field_type
Expand Down
6 changes: 4 additions & 2 deletions ariadne_codegen/client_generators/result_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,9 @@ def parse_interface_type(
)
context.abstract_type = True
if inline_fragments or fragments_on_subtypes:
types = [generate_annotation_name('"' + class_name + type_.name + '"', False)]
types: List[ast.expr] = [
generate_annotation_name('"' + class_name + type_.name + '"', False)
]
context.related_classes.append(
RelatedClassData(class_name=class_name + type_.name, type_name=type_.name)
)
Expand Down Expand Up @@ -275,7 +277,7 @@ def parse_union_type(
class_name: str,
) -> Annotation:
context.abstract_type = True
sub_annotations = [
sub_annotations: List[ast.expr] = [
parse_operation_field_type(
type_=subtype,
context=context,
Expand Down
3 changes: 2 additions & 1 deletion ariadne_codegen/client_generators/result_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
generate_import_from,
generate_method_call,
generate_module,
generate_name,
generate_pass,
generate_pydantic_field,
model_has_forward_refs,
Expand Down Expand Up @@ -264,7 +265,7 @@ def _parse_type_definition(
)

field_implementation = generate_ann_assign(
target=name,
target=generate_name(name),
annotation=annotation,
lineno=lineno,
value=default_value,
Expand Down
101 changes: 59 additions & 42 deletions ariadne_codegen/codegen.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
from typing import Any, Dict, List, Optional, Union
import sys
from typing import Any, Dict, List, Optional, Union, cast

from graphql import (
GraphQLEnumType,
Expand All @@ -24,11 +25,6 @@
from .exceptions import ParsingError


def generate_import(names: List[str], level: int = 0) -> ast.Import:
"""Generate import statement."""
return ast.Import(names=[ast.alias(n) for n in names], level=level)


def generate_import_from(
names: List[str], from_: Optional[str] = None, level: int = 0
) -> ast.ImportFrom:
Expand Down Expand Up @@ -94,17 +90,21 @@ def generate_async_method_definition(
return_type: Union[ast.Name, ast.Subscript],
body: Optional[List[ast.stmt]] = None,
lineno: int = 1,
decorator_list: Optional[List[ast.Name]] = None,
decorator_list: Optional[List[ast.expr]] = None,
) -> ast.AsyncFunctionDef:
"""Generate async function."""
return ast.AsyncFunctionDef(
name=name,
args=arguments,
body=body if body else [ast.Pass()],
decorator_list=decorator_list if decorator_list else [],
returns=return_type,
lineno=lineno,
)
params: Dict[str, Any] = {
"name": name,
"args": arguments,
"body": body if body else [ast.Pass()],
"decorator_list": decorator_list if decorator_list else [],
"returns": return_type,
"lineno": lineno,
}
if sys.version_info >= (3, 12):
params["type_params"] = []

return ast.AsyncFunctionDef(**params)


def generate_class_def(
Expand All @@ -113,14 +113,20 @@ def generate_class_def(
body: Optional[List[ast.stmt]] = None,
) -> ast.ClassDef:
"""Generate class definition."""
bases = [ast.Name(id=name) for name in base_names] if base_names else []
return ast.ClassDef(
name=name,
bases=bases,
keywords=[],
body=body if body else [],
decorator_list=[],
bases = cast(
List[ast.expr], [ast.Name(id=name) for name in base_names] if base_names else []
)
params: Dict[str, Any] = {
"name": name,
"bases": bases,
"keywords": [],
"body": body if body else [],
"decorator_list": [],
}
if sys.version_info >= (3, 12):
params["type_params"] = []

return ast.ClassDef(**params)


def generate_name(name: str) -> ast.Name:
Expand Down Expand Up @@ -153,37 +159,39 @@ def generate_assign(
) -> ast.Assign:
"""Generate assign object."""
return ast.Assign(
targets=[ast.Name(t) for t in targets], value=value, lineno=lineno
targets=[ast.Name(t) for t in targets],
value=value, # type:ignore
lineno=lineno,
)


def generate_ann_assign(
target: Union[str, ast.expr],
target: Union[ast.Name, ast.Attribute, ast.Subscript],
annotation: Annotation,
value: Optional[ast.expr] = None,
lineno: int = 1,
) -> ast.AnnAssign:
"""Generate ann assign object."""
return ast.AnnAssign(
target=target if isinstance(target, ast.expr) else ast.Name(id=target),
target=target,
annotation=annotation,
simple=1,
value=value,
simple=1,
lineno=lineno,
)


def generate_union_annotation(
types: List[Union[ast.Name, ast.Subscript]], nullable: bool = True
types: List[ast.expr], nullable: bool = True
) -> ast.Subscript:
"""Generate union annotation."""
result = ast.Subscript(value=ast.Name(id=UNION), slice=ast.Tuple(elts=types))
return result if not nullable else generate_nullable_annotation(result)


def generate_dict(
keys: Optional[List[ast.expr]] = None,
values: Optional[List[Optional[ast.expr]]] = None,
keys: Optional[List[Optional[ast.expr]]] = None,
values: Optional[List[ast.expr]] = None,
) -> ast.Dict:
"""Generate dict object."""
return ast.Dict(keys=keys if keys else [], values=values if values else [])
Expand All @@ -201,7 +209,9 @@ def generate_call(
) -> ast.Call:
"""Generate call object."""
return ast.Call(
func=func, args=args if args else [], keywords=keywords if keywords else []
func=func,
args=args if args else [], # type:ignore
keywords=keywords if keywords else [],
)


Expand Down Expand Up @@ -240,7 +250,10 @@ def parse_field_type(
return generate_annotation_name('"' + type_.name + '"', nullable)

if isinstance(type_, GraphQLUnionType):
subtypes = [parse_field_type(subtype, False) for subtype in type_.types]
subtypes = cast(
List[ast.expr],
[parse_field_type(subtype, False) for subtype in type_.types],
)
return generate_union_annotation(subtypes, nullable)

if isinstance(type_, GraphQLList):
Expand All @@ -255,7 +268,7 @@ def parse_field_type(


def generate_method_call(
object_name: str, method_name: str, args: Optional[List[Optional[ast.expr]]] = None
object_name: str, method_name: str, args: Optional[List[ast.expr]] = None
) -> ast.Call:
"""Generate object`s method call."""
return ast.Call(
Expand Down Expand Up @@ -287,7 +300,7 @@ def generate_trivial_lambda(name: str, argument_name: str) -> ast.Assign:
)


def generate_list(elements: List[Optional[ast.expr]]) -> ast.List:
def generate_list(elements: List[ast.expr]) -> ast.List:
"""Generate list object."""
return ast.List(elts=elements)

Expand Down Expand Up @@ -343,16 +356,20 @@ def generate_method_definition(
return_type: Union[ast.Name, ast.Subscript],
body: Optional[List[ast.stmt]] = None,
lineno: int = 1,
decorator_list: Optional[List[ast.Name]] = None,
decorator_list: Optional[List[ast.expr]] = None,
) -> ast.FunctionDef:
return ast.FunctionDef(
name=name,
args=arguments,
body=body if body else [ast.Pass()],
decorator_list=decorator_list if decorator_list else [],
returns=return_type,
lineno=lineno,
)
params: Dict[str, Any] = {
"name": name,
"args": arguments,
"body": body if body else [ast.Pass()],
"decorator_list": decorator_list if decorator_list else [],
"returns": return_type,
"lineno": lineno,
}
if sys.version_info >= (3, 12):
params["type_params"] = []

return ast.FunctionDef(**params)


def generate_async_for(
Expand Down
Loading
Loading