Skip to content

Commit

Permalink
Merge pull request #244 from mirumee/only_used_enums
Browse files Browse the repository at this point in the history
Only used enums
  • Loading branch information
mat-sop authored Nov 27, 2023
2 parents b305fe4 + ed81552 commit 46bd337
Show file tree
Hide file tree
Showing 32 changed files with 471 additions and 117 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- Removed `model_rebuild` calls for generated input, fragment and result models.
- Added `NoReimportsPlugin` that makes the `__init__.py` of generated client package empty.
- Added `include_all_inputs` config flag to generate only inputs used in supplied operations.
- Added `include_all_enums` config flag to generate only enums used in supplied operations.


## 0.10.0 (2023-11-15)
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ Optional settings:
- `include_comments` (defaults to `"stable"`) - option which sets content of comments included at the top of every generated file. Valid choices are: `"none"` (no comments), `"timestamp"` (comment with generation timestamp), `"stable"` (comment contains a message that this is a generated file)
- `convert_to_snake_case` (defaults to `true`) - a flag that specifies whether to convert fields and arguments names to snake case
- `include_all_inputs` (defaults to `true`) - a flag specifying whether to include all inputs defined in the schema, or only those used in supplied operations
- `include_all_enums` (defaults to `true`) - a flag specifying whether to include all enums defined in the schema, or only those used in supplied operations
- `async_client` (defaults to `true`) - default generated client is `async`, change this to option `false` to generate synchronous client instead
- `opentelemetry_client` (defaults to `false`) - default base clients don't support any performance tracing. Change this option to `true` to use the base client with Open Telemetry support.
- `files_to_include` (defaults to `[]`) - list of files which will be copied into generated package
Expand Down
25 changes: 21 additions & 4 deletions ariadne_codegen/client_generators/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,29 @@ def __init__(
self.schema = schema
self.plugin_manager = plugin_manager

self._generated_public_names: List[str] = []
self._imports: List[ast.ImportFrom] = [
generate_import_from([ENUM_CLASS], ENUM_MODULE)
]
self._class_defs: List[ast.ClassDef] = [
self._parse_enum_definition(d) for d in self._filter_enum_types()
]

def generate(self) -> ast.Module:
def generate(self, types_to_include: Optional[List[str]] = None) -> ast.Module:
class_defs = self._filter_class_defs(types_to_include)
self._generated_public_names = [class_def.name for class_def in class_defs]

module = generate_module(
body=cast(List[ast.stmt], self._imports)
+ cast(List[ast.stmt], self._class_defs)
body=cast(List[ast.stmt], self._imports) + cast(List[ast.stmt], class_defs)
)

if self.plugin_manager:
module = self.plugin_manager.generate_enums_module(module)

return module

def get_generated_public_names(self) -> List[str]:
return [c.name for c in self._class_defs]
return self._generated_public_names

def _filter_enum_types(self) -> List[GraphQLEnumType]:
return [
Expand All @@ -66,3 +71,15 @@ def _parse_enum_definition(self, definition: GraphQLEnumType) -> ast.ClassDef:
if self.plugin_manager:
class_def = self.plugin_manager.generate_enum(class_def, definition)
return class_def

def _filter_class_defs(
self, types_to_include: Optional[List[str]] = None
) -> List[ast.ClassDef]:
if types_to_include is None:
return self._class_defs

return [
class_def
for class_def in self._class_defs
if class_def.name in types_to_include
]
5 changes: 5 additions & 0 deletions ariadne_codegen/client_generators/fragments.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(

self._fragments_names = set(self.fragments_definitions.keys())
self._generated_public_names: List[str] = []
self._used_enums: List[str] = []

def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module:
class_defs_dict: Dict[str, List[ast.ClassDef]] = {}
Expand All @@ -55,6 +56,7 @@ def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module:
class_defs_dict[name] = generator.get_classes()
dependencies_dict[name] = generator.get_fragments_used_as_mixins()
self._generated_public_names.extend(generator.get_generated_public_names())
self._used_enums.extend(generator.get_used_enums())

sorted_class_defs = self._get_sorted_class_defs(
class_defs_dict=class_defs_dict, dependencies_dict=dependencies_dict
Expand All @@ -71,6 +73,9 @@ def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module:
def get_generated_public_names(self) -> List[str]:
return self._generated_public_names

def get_used_enums(self) -> List[str]:
return self._used_enums

def _get_sorted_class_defs(
self,
class_defs_dict: Dict[str, List[ast.ClassDef]],
Expand Down
17 changes: 12 additions & 5 deletions ariadne_codegen/client_generators/input_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,25 +65,26 @@ def __init__(
upload_import,
]
self._dependencies: Dict[str, List[str]] = defaultdict(list)
self._used_enums: List[str] = []
self._used_enums: Dict[str, List[str]] = defaultdict(list)
self._used_scalars: List[str] = []
self._class_defs: List[ast.ClassDef] = [
self._parse_input_definition(d) for d in self._filter_input_types()
]
self._generated_public_names: List[str] = []

def generate(self, types_to_include: Optional[List[str]] = None) -> ast.Module:
class_defs = self._filter_class_defs(types_to_include=types_to_include)
self._generated_public_names = [class_def.name for class_def in class_defs]

if self._used_enums:
self._imports.append(
generate_import_from(self._used_enums, self.enums_module, 1)
generate_import_from(self.get_used_enums(), self.enums_module, 1)
)

for scalar_name in self._used_scalars:
scalar_data = self.custom_scalars[scalar_name]
self._imports.extend(generate_scalar_imports(scalar_data))

class_defs = self._filter_class_defs(types_to_include=types_to_include)
self._generated_public_names = [class_def.name for class_def in class_defs]
module_body = cast(List[ast.stmt], self._imports) + cast(
List[ast.stmt], class_defs
)
Expand All @@ -96,6 +97,12 @@ def generate(self, types_to_include: Optional[List[str]] = None) -> ast.Module:
def get_generated_public_names(self) -> List[str]:
return self._generated_public_names

def get_used_enums(self) -> List[str]:
enums = []
for input_name in self._generated_public_names:
enums.extend(self._used_enums[input_name])
return enums

def _filter_input_types(self) -> List[GraphQLInputObjectType]:
return [
definition
Expand Down Expand Up @@ -206,6 +213,6 @@ def _save_dependencies(self, root_type: str, field_type: str = "") -> None:
if isinstance(self.schema.type_map[field_type], GraphQLInputObjectType):
self._dependencies[root_type].append(field_type)
elif isinstance(self.schema.type_map[field_type], GraphQLEnumType):
self._used_enums.append(field_type)
self._used_enums[root_type].append(field_type)
elif isinstance(self.schema.type_map[field_type], GraphQLScalarType):
self._used_scalars.append(field_type)
19 changes: 16 additions & 3 deletions ariadne_codegen/client_generators/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
schema_source: str = "",
convert_to_snake_case: bool = True,
include_all_inputs: bool = True,
include_all_enums: bool = True,
base_model_file_path: str = BASE_MODEL_FILE_PATH.as_posix(),
base_model_import: ast.ImportFrom = BASE_MODEL_IMPORT,
upload_import: ast.ImportFrom = UPLOAD_IMPORT,
Expand Down Expand Up @@ -96,6 +97,7 @@ def __init__(

self.convert_to_snake_case = convert_to_snake_case
self.include_all_inputs = include_all_inputs
self.include_all_enums = include_all_enums

self.base_model_file_path = Path(base_model_file_path)
self.base_model_import = base_model_import
Expand All @@ -111,19 +113,20 @@ def __init__(
self._result_types_files: Dict[str, ast.Module] = {}
self._generated_files: List[str] = []
self._unpacked_fragments: Set[str] = set()
self._used_enums: List[str] = []

def generate(self) -> List[str]:
"""Generate package with graphql client."""
self._include_exceptions()
self._validate_unique_file_names()
if not self.package_path.exists():
self.package_path.mkdir()
self._generate_enums()
self._generate_input_types()
self._generate_result_types()
self._generate_fragments()
self._copy_files()
self._generate_client()
self._generate_enums()
self._generate_init()

return sorted(self._generated_files)
Expand Down Expand Up @@ -157,6 +160,7 @@ def add_operation(self, definition: OperationDefinitionNode):
self._unpacked_fragments = self._unpacked_fragments.union(
query_types_generator.get_unpacked_fragments()
)
self._used_enums.extend(query_types_generator.get_used_enums())
self._result_types_files[file_name] = query_types_generator.generate()
operation_str = query_types_generator.get_operation_as_str()
self.init_generator.add_import(
Expand Down Expand Up @@ -215,7 +219,9 @@ def _generate_client(self):
code = self.plugin_manager.generate_client_code(code)
client_file_path.write_text(code)
self._generated_files.append(client_file_path.name)

self._used_enums.extend(
self.client_generator.arguments_generator.get_used_enums()
)
self.init_generator.add_import(
names=[self.client_generator.name], from_=self.client_file_name, level=1
)
Expand All @@ -232,7 +238,11 @@ def _add_comments_to_code(self, code: str, source: Optional[str] = None) -> str:
return code

def _generate_enums(self):
module = self.enums_generator.generate()
if self.include_all_enums:
module = self.enums_generator.generate()
else:
module = self.enums_generator.generate(types_to_include=self._used_enums)

code = self._add_comments_to_code(ast_to_str(module), self.schema_source)
if self.plugin_manager:
code = self.plugin_manager.generate_enums_code(code)
Expand All @@ -256,6 +266,7 @@ def _generate_input_types(self):
code = self.plugin_manager.generate_inputs_code(code)
input_types_file_path.write_text(code)
self._generated_files.append(input_types_file_path.name)
self._used_enums.extend(self.input_types_generator.get_used_enums())
self.init_generator.add_import(
self.input_types_generator.get_generated_public_names(),
self.input_types_module_name,
Expand Down Expand Up @@ -284,6 +295,7 @@ def _generate_fragments(self):
code = self._add_comments_to_code(ast_to_str(module), self.queries_source)
file_path.write_text(code)
self._generated_files.append(file_path.name)
self._used_enums.extend(self.fragments_generator.get_used_enums())
self.init_generator.add_import(
self.fragments_generator.get_generated_public_names(),
self.fragments_module_name,
Expand Down Expand Up @@ -396,6 +408,7 @@ def get_package_generator(
schema_source=settings.schema_source,
convert_to_snake_case=settings.convert_to_snake_case,
include_all_inputs=settings.include_all_inputs,
include_all_enums=settings.include_all_enums,
base_model_file_path=BASE_MODEL_FILE_PATH.as_posix(),
base_model_import=BASE_MODEL_IMPORT,
upload_import=UPLOAD_IMPORT,
Expand Down
3 changes: 3 additions & 0 deletions ariadne_codegen/client_generators/result_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ def get_unpacked_fragments(self) -> Set[str]:
def get_fragments_used_as_mixins(self) -> Set[str]:
return self._fragments_used_as_mixins

def get_used_enums(self) -> List[str]:
return self._used_enums

def _parse_type_definition(
self,
class_name: str,
Expand Down
1 change: 1 addition & 0 deletions ariadne_codegen/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class ClientSettings(BaseSettings):
include_comments: CommentsStrategy = field(default=CommentsStrategy.STABLE)
convert_to_snake_case: bool = True
include_all_inputs: bool = True
include_all_enums: bool = True
async_client: bool = True
opentelemetry_client: bool = False
files_to_include: List[str] = field(default_factory=list)
Expand Down
35 changes: 35 additions & 0 deletions tests/client_generators/test_enums_generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast

import pytest
from graphql import GraphQLEnumType, GraphQLSchema, build_ast_schema, parse

from ariadne_codegen.client_generators.constants import ENUM_CLASS, ENUM_MODULE
Expand Down Expand Up @@ -175,3 +176,37 @@ def test_generate_triggers_generate_enum_hook_for_every_definition(
assert call0_enum_type.name == "TestEnumAB"
assert isinstance(call1_enum_type, GraphQLEnumType)
assert call1_enum_type.name == "TestEnumCD"


@pytest.mark.parametrize(
"types_to_include, public_names",
[
(None, ["EnumA", "EnumB", "EnumC"]),
(["EnumA", "EnumC"], ["EnumA", "EnumC"]),
(["EnumA", "EnumA", "EnumA"], ["EnumA"]),
],
)
def test_generate_returns_module_with_filtered_classes(types_to_include, public_names):
schema_str = """
enum EnumA {
A1
A2
}
enum EnumB {
B1
B2
}
enum EnumC {
C1
C2
}
"""

generator = EnumsGenerator(schema=build_ast_schema(parse(schema_str)))
module = generator.generate(types_to_include=types_to_include)

class_names = [c.name for c in filter_class_defs(module)]
assert sorted(generator.get_generated_public_names()) == sorted(public_names)
assert sorted(class_names) == sorted(public_names)
52 changes: 0 additions & 52 deletions tests/main/clients/only_used_inputs/expected_client/client.py

This file was deleted.

Empty file.
11 changes: 0 additions & 11 deletions tests/main/clients/only_used_inputs/queries.graphql

This file was deleted.

Loading

0 comments on commit 46bd337

Please sign in to comment.