diff --git a/python_modules/libraries/dagster-components/dagster_components/cli/__init__.py b/python_modules/libraries/dagster-components/dagster_components/cli/__init__.py index 46d78fead72b3..6a2c7d9773b49 100644 --- a/python_modules/libraries/dagster-components/dagster_components/cli/__init__.py +++ b/python_modules/libraries/dagster-components/dagster_components/cli/__init__.py @@ -1,7 +1,6 @@ import click from dagster.version import __version__ -from dagster_components.cli.check import check_cli from dagster_components.cli.list import list_cli from dagster_components.cli.scaffold import scaffold_cli from dagster_components.core.component import BUILTIN_MAIN_COMPONENT_ENTRY_POINT @@ -12,7 +11,6 @@ def create_dagster_components_cli(): commands = { "scaffold": scaffold_cli, "list": list_cli, - "check": check_cli, } @click.group( diff --git a/python_modules/libraries/dagster-components/dagster_components/cli/list.py b/python_modules/libraries/dagster-components/dagster_components/cli/list.py index 078f0567e81fe..5b71b766eb039 100644 --- a/python_modules/libraries/dagster-components/dagster_components/cli/list.py +++ b/python_modules/libraries/dagster-components/dagster_components/cli/list.py @@ -62,11 +62,8 @@ def _add_component_type_to_output( @list_cli.command(name="local-component-types") -@click.pass_context @click.argument("component_directories", nargs=-1, type=click.Path(exists=True)) -def list_local_component_types_command( - ctx: click.Context, component_directories: Sequence[str] -) -> None: +def list_local_component_types_command(component_directories: Sequence[str]) -> None: """List local Dagster components found in the specified directories.""" output: list = [] for component_directory in component_directories: diff --git a/python_modules/libraries/dagster-components/dagster_components/test/__init__.py b/python_modules/libraries/dagster-components/dagster_components/test/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/validation_tests/basic_components.py b/python_modules/libraries/dagster-components/dagster_components/test/basic_components.py similarity index 99% rename from python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/validation_tests/basic_components.py rename to python_modules/libraries/dagster-components/dagster_components/test/basic_components.py index a9b2b87eabc9d..2245fd3b64c0d 100644 --- a/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/validation_tests/basic_components.py +++ b/python_modules/libraries/dagster-components/dagster_components/test/basic_components.py @@ -3,10 +3,11 @@ """ from dagster._core.definitions.definitions_class import Definitions +from typing_extensions import Self + from dagster_components import Component, component_type from dagster_components.core.component import ComponentLoadContext from dagster_components.core.schema.base import ComponentSchemaBaseModel -from typing_extensions import Self class MyComponentSchema(ComponentSchemaBaseModel): diff --git a/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/validation_tests/test_cases.py b/python_modules/libraries/dagster-components/dagster_components/test/test_cases.py similarity index 64% rename from python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/validation_tests/test_cases.py rename to python_modules/libraries/dagster-components/dagster_components/test/test_cases.py index 76b0a58cdc22c..daf793ae34c6f 100644 --- a/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/validation_tests/test_cases.py +++ b/python_modules/libraries/dagster-components/dagster_components/test/test_cases.py @@ -10,10 +10,10 @@ class ComponentValidationTestCase: """ component_path: str - component_type_filepath: Path + component_type_filepath: 
Optional[Path] should_error: bool validate_error_msg: Optional[Callable[[str], None]] = None - validate_error_msg_additional_cli: Optional[Callable[[str], None]] = None + check_error_msg: Optional[Callable[[str], None]] = None def msg_includes_all_of(*substrings: str) -> Callable[[str], None]: @@ -31,6 +31,11 @@ def _validate_error_msg(msg: str) -> None: validate_error_msg=msg_includes_all_of( "component.yaml:5", "params.an_int", "Input should be a valid integer" ), + check_error_msg=msg_includes_all_of( + "component.yaml:5", + "params.an_int", + "{} is not of type 'integer'", + ), ) BASIC_MISSING_VALUE = ComponentValidationTestCase( @@ -38,8 +43,10 @@ def _validate_error_msg(msg: str) -> None: component_type_filepath=Path(__file__).parent / "basic_components.py", should_error=True, validate_error_msg=msg_includes_all_of("component.yaml:3", "params.an_int", "required"), - validate_error_msg_additional_cli=msg_includes_all_of( - "Field `an_int` is required but not provided" + check_error_msg=msg_includes_all_of( + "component.yaml:3", + "params", + "'an_int' is a required property", ), ) @@ -51,6 +58,19 @@ def _validate_error_msg(msg: str) -> None: ), BASIC_INVALID_VALUE, BASIC_MISSING_VALUE, + ComponentValidationTestCase( + component_path="validation/simple_asset_invalid_value", + component_type_filepath=None, + should_error=True, + validate_error_msg=msg_includes_all_of( + "component.yaml:5", "params.value", "Input should be a valid string" + ), + check_error_msg=msg_includes_all_of( + "component.yaml:5", + "params.value", + "{} is not of type 'string'", + ), + ), ComponentValidationTestCase( component_path="validation/basic_component_extra_value", component_type_filepath=Path(__file__).parent / "basic_components.py", @@ -58,6 +78,10 @@ def _validate_error_msg(msg: str) -> None: validate_error_msg=msg_includes_all_of( "component.yaml:7", "params.a_bool", "Extra inputs are not permitted" ), + check_error_msg=msg_includes_all_of( + "component.yaml:3", + "'a_bool' was 
unexpected", + ), ), ComponentValidationTestCase( component_path="validation/nested_component_invalid_values", @@ -71,6 +95,14 @@ def _validate_error_msg(msg: str) -> None: "params.nested.baz.a_string", "Input should be a valid string", ), + check_error_msg=msg_includes_all_of( + "component.yaml:7", + "params.nested.foo.an_int", + "{} is not of type 'integer'", + "component.yaml:12", + "params.nested.baz.a_string", + "{} is not of type 'string'", + ), ), ComponentValidationTestCase( component_path="validation/nested_component_missing_values", @@ -79,8 +111,13 @@ def _validate_error_msg(msg: str) -> None: validate_error_msg=msg_includes_all_of( "component.yaml:5", "params.nested.foo.an_int", "required" ), - validate_error_msg_additional_cli=msg_includes_all_of( - "Field `a_string` is required but not provided" + check_error_msg=msg_includes_all_of( + "component.yaml:5", + "params.nested.foo", + "'an_int' is a required property", + "component.yaml:10", + "params.nested.baz", + "'a_string' is a required property", ), ), ComponentValidationTestCase( @@ -94,6 +131,14 @@ def _validate_error_msg(msg: str) -> None: "component.yaml:15", "params.nested.baz.another_bool", ), + check_error_msg=msg_includes_all_of( + "component.yaml:5", + "params.nested.foo", + "'a_bool' was unexpected", + "component.yaml:12", + "params.nested.baz", + "'another_bool' was unexpected", + ), ), ComponentValidationTestCase( component_path="validation/invalid_component_file_model", @@ -107,5 +152,13 @@ def _validate_error_msg(msg: str) -> None: "params", "Input should be an object", ), + check_error_msg=msg_includes_all_of( + "component.yaml:1", + "type", + "{} is not of type 'string'", + "component.yaml:3", + "params", + "'asdfasdf' is not of type 'object'", + ), ), ] diff --git a/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/component_loader.py b/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/component_loader.py index 
58424d093eaa0..ef5238eae861a 100644 --- a/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/component_loader.py +++ b/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/component_loader.py @@ -20,14 +20,19 @@ def load_test_component_defs(name: str) -> Definitions: ) -def load_test_component_project_context() -> CodeLocationProjectContext: +def load_test_component_project_context(include_test: bool = False) -> CodeLocationProjectContext: + components = {} package_name = "dagster_components.lib" - dc_module = importlib.import_module(package_name) - components = {} - for component in get_registered_component_types_in_module(dc_module): - key = f"dagster_components.{get_component_type_name(component)}" - components[key] = component + packages = ["dagster_components.lib"] + ( + ["dagster_components.lib.test"] if include_test else [] + ) + for package_name in packages: + dc_module = importlib.import_module(package_name) + + for component in get_registered_component_types_in_module(dc_module): + key = f"dagster_components.{'test.' 
if package_name.endswith('test') else ''}{get_component_type_name(component)}" + components[key] = component return CodeLocationProjectContext( root_path=str(Path(__file__).parent), diff --git a/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/components/validation/simple_asset_invalid_value/component.yaml b/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/components/validation/simple_asset_invalid_value/component.yaml new file mode 100644 index 0000000000000..59dfee90ded4a --- /dev/null +++ b/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/components/validation/simple_asset_invalid_value/component.yaml @@ -0,0 +1,5 @@ +type: dagster_components.test.simple_asset + +params: + asset_key: "test" + value: {} diff --git a/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/validation_tests/test_component_validation.py b/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/validation_tests/test_component_validation.py index c7a931a0fc5f6..cc6129250ac1c 100644 --- a/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/validation_tests/test_component_validation.py +++ b/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/validation_tests/test_component_validation.py @@ -1,22 +1,14 @@ -from pathlib import Path - import pytest -from click.testing import CliRunner -from dagster._core.test_utils import new_cwd -from dagster_components.cli import cli -from dagster_components.utils import ensure_dagster_components_tests_import -from pydantic import ValidationError - -from dagster_components_tests.integration_tests.validation_tests.test_cases import ( - BASIC_INVALID_VALUE, - BASIC_MISSING_VALUE, +from dagster_components.test.test_cases import ( COMPONENT_VALIDATION_TEST_CASES, ComponentValidationTestCase, ) +from 
dagster_components.utils import ensure_dagster_components_tests_import +from pydantic import ValidationError + from dagster_components_tests.integration_tests.validation_tests.utils import ( load_test_component_defs_inject_component, ) -from dagster_components_tests.utils import create_code_location_from_components ensure_dagster_components_tests_import() @@ -44,120 +36,3 @@ def test_validation_messages(test_case: ComponentValidationTestCase) -> None: str(test_case.component_path), test_case.component_type_filepath, ) - - -@pytest.mark.parametrize( - "test_case", - COMPONENT_VALIDATION_TEST_CASES, - ids=[str(case.component_path) for case in COMPONENT_VALIDATION_TEST_CASES], -) -def test_validation_cli(test_case: ComponentValidationTestCase) -> None: - """Tests that the check CLI prints rich error messages when attempting to - load components with errors. - """ - runner = CliRunner() - - with create_code_location_from_components( - test_case.component_path, local_component_defn_to_inject=test_case.component_type_filepath - ) as tmpdir: - with new_cwd(str(tmpdir)): - result = runner.invoke( - cli, - [ - "--builtin-component-lib", - "dagster_components.test", - "check", - "component", - ], - catch_exceptions=False, - ) - if test_case.should_error: - assert result.exit_code != 0, str(result.stdout) - - assert test_case.validate_error_msg - test_case.validate_error_msg(str(result.stdout)) - - if test_case.validate_error_msg_additional_cli: - test_case.validate_error_msg_additional_cli(str(result.stdout)) - else: - assert result.exit_code == 0, str(result.stdout) - - -@pytest.mark.parametrize( - "scope_check_run", - [True, False], -) -def test_validation_cli_multiple_components(scope_check_run: bool) -> None: - """Ensure that the check CLI can validate multiple components in a single code location, and - that error messages from all components are displayed. 
- - The parameter `scope_check_run` determines whether the check CLI is run pointing at both - components or none (defaulting to the entire workspace) - the output should be the same in - either case, this just tests that the CLI can handle multiple filters. - """ - runner = CliRunner() - - with create_code_location_from_components( - BASIC_MISSING_VALUE.component_path, - BASIC_INVALID_VALUE.component_path, - local_component_defn_to_inject=BASIC_MISSING_VALUE.component_type_filepath, - ) as tmpdir: - with new_cwd(str(tmpdir)): - result = runner.invoke( - cli, - [ - "--builtin-component-lib", - "dagster_components.test", - "check", - "component", - *( - [ - str( - Path("my_location") / "components" / "basic_component_missing_value" - ), - str( - Path("my_location") / "components" / "basic_component_invalid_value" - ), - ] - if scope_check_run - else [] - ), - ], - catch_exceptions=False, - ) - assert result.exit_code != 0, str(result.stdout) - - assert BASIC_INVALID_VALUE.validate_error_msg and BASIC_MISSING_VALUE.validate_error_msg - BASIC_INVALID_VALUE.validate_error_msg(str(result.stdout)) - BASIC_MISSING_VALUE.validate_error_msg(str(result.stdout)) - - -def test_validation_cli_multiple_components_filter() -> None: - """Ensure that the check CLI filters components to validate based on the provided paths.""" - runner = CliRunner() - - with create_code_location_from_components( - BASIC_MISSING_VALUE.component_path, - BASIC_INVALID_VALUE.component_path, - local_component_defn_to_inject=BASIC_MISSING_VALUE.component_type_filepath, - ) as tmpdir: - with new_cwd(str(tmpdir)): - result = runner.invoke( - cli, - [ - "--builtin-component-lib", - "dagster_components.test", - "check", - "component", - str(Path("my_location") / "components" / "basic_component_missing_value"), - ], - catch_exceptions=False, - ) - assert result.exit_code != 0, str(result.stdout) - - assert BASIC_INVALID_VALUE.validate_error_msg and BASIC_MISSING_VALUE.validate_error_msg - - 
BASIC_MISSING_VALUE.validate_error_msg(str(result.stdout)) - # We exclude the invalid value test case - with pytest.raises(AssertionError): - BASIC_INVALID_VALUE.validate_error_msg(str(result.stdout)) diff --git a/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/validation_tests/utils.py b/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/validation_tests/utils.py index 144cada1649c9..3175ee13e32e2 100644 --- a/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/validation_tests/utils.py +++ b/python_modules/libraries/dagster-components/dagster_components_tests/integration_tests/validation_tests/utils.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Optional from dagster._core.definitions.definitions_class import Definitions @@ -10,7 +11,7 @@ def load_test_component_defs_inject_component( - src_path: str, local_component_defn_to_inject: Path + src_path: str, local_component_defn_to_inject: Optional[Path] ) -> Definitions: """Loads a component from a test component project, making the provided local component defn available in that component's __init__.py. 
@@ -18,7 +19,8 @@ def load_test_component_defs_inject_component( with inject_component( src_path=src_path, local_component_defn_to_inject=local_component_defn_to_inject ) as tmpdir: - context = load_test_component_project_context() + context = load_test_component_project_context(include_test=True) + return build_defs_from_component_path( path=Path(tmpdir), registry=context.component_registry, diff --git a/python_modules/libraries/dagster-dg/dagster_dg/cli/check_utils.py b/python_modules/libraries/dagster-dg/dagster_dg/cli/check_utils.py new file mode 100644 index 0000000000000..c8fe9f4a459e3 --- /dev/null +++ b/python_modules/libraries/dagster-dg/dagster_dg/cli/check_utils.py @@ -0,0 +1,113 @@ +from collections.abc import Sequence +from typing import Optional + +import click +import typer +from jsonschema import ValidationError + +from dagster_dg.yaml_utils.source_position import SourcePositionTree + + +@click.group(name="check") +def check_cli(): + """Commands for checking components.""" + + +def prepend_lines_with_line_numbers( + lines_with_numbers: Sequence[tuple[Optional[int], str]], +) -> Sequence[str]: + """Prepend each line with a line number, right-justified to the maximum line number length. + + Args: + lines_with_numbers: A sequence of tuples, where the first element is the line number and the + second element is the line content. Some lines may have a `None` line number, which + will be rendered as an empty string, used for e.g. inserted error message lines. 
+ """ + max_line_number_length = max([len(str(n)) for n, _ in lines_with_numbers]) + return [ + f"{(str(n) if n else '').rjust(max_line_number_length)} | {line.rstrip()}" + for n, line in lines_with_numbers + ] + + +def augment_inline_error_message(location: str, msg: str): + """Improves a subset of Pydantic error messages by including location information.""" + last_location_part = location.split(".")[-1] + if msg == "Field required": + return f"Field `{last_location_part}` is required but not provided" + return msg + + +def format_indented_error_msg(col: int, msg: str) -> str: + """Format an error message with a caret pointing to the column where the error occurred.""" + return typer.style(" " * (col - 1) + f"^ {msg}", fg=typer.colors.YELLOW) + + +OFFSET_LINES_BEFORE = 2 +OFFSET_LINES_AFTER = 3 + + +def error_dict_to_formatted_error( + component_name: Optional[str], + error_details: ValidationError, + source_position_tree: SourcePositionTree, + prefix: Sequence[str] = (), +) -> str: + source_position, source_position_path = source_position_tree.lookup_closest_and_path( + [*prefix, *error_details.absolute_path], trace=None + ) + + # Retrieves dotted path representation of the location of the error in the YAML file, e.g. + # params.nested.foo.an_int + location = ".".join([*prefix, *[str(part) for part in error_details.absolute_path]]).split( + " at " + )[0] + + # Find the first source position that has a different start line than the current source position + # This is e.g. 
the parent json key of the current source position + preceding_source_position = next( + iter( + [ + value + for value in reversed(list(source_position_path)) + if value.start.line < source_position.start.line + ] + ), + source_position, + ) + with open(source_position.filename) as f: + lines = f.readlines() + lines_with_line_numbers = list(zip(range(1, len(lines) + 1), lines)) + + filtered_lines_with_line_numbers = ( + lines_with_line_numbers[ + max( + 0, preceding_source_position.start.line - OFFSET_LINES_BEFORE + ) : source_position.start.line + ] + + [ + ( + None, + format_indented_error_msg( + source_position.start.col, + augment_inline_error_message(location, error_details.message), + ), + ) + ] + + lines_with_line_numbers[ + source_position.start.line : source_position.end.line + OFFSET_LINES_AFTER + ] + ) + # Combine the filtered lines with the line numbers, and add empty lines before and after + lines_with_line_numbers = prepend_lines_with_line_numbers( + [(None, ""), *filtered_lines_with_line_numbers, (None, "")] + ) + code_snippet = "\n".join(lines_with_line_numbers) + + fmt_filename = ( + f"{source_position.filename}" + f":{typer.style(source_position.start.line, fg=typer.colors.GREEN)}" + ) + fmt_location = typer.style(location, fg=typer.colors.BRIGHT_WHITE) + fmt_name = typer.style(f"{component_name} " if component_name else "", fg=typer.colors.RED) + return f"{fmt_filename} - {fmt_name}{fmt_location} {error_details.message}\n{code_snippet}\n" diff --git a/python_modules/libraries/dagster-dg/dagster_dg/cli/component.py b/python_modules/libraries/dagster-dg/dagster_dg/cli/component.py index 1ceae6078127d..6300fb01bcf8a 100644 --- a/python_modules/libraries/dagster-dg/dagster_dg/cli/component.py +++ b/python_modules/libraries/dagster-dg/dagster_dg/cli/component.py @@ -1,12 +1,12 @@ -import sys from collections.abc import Mapping, Sequence from pathlib import Path -from subprocess import CalledProcessError -from typing import Any, Optional +from typing 
import TYPE_CHECKING, Any, Optional import click from click.core import ParameterSource +from jsonschema import Draft202012Validator, ValidationError +from dagster_dg.cli.check_utils import error_dict_to_formatted_error from dagster_dg.cli.global_options import dg_global_options from dagster_dg.component import RemoteComponentRegistry, RemoteComponentType from dagster_dg.config import ( @@ -25,6 +25,10 @@ not_none, parse_json_option, ) +from dagster_dg.yaml_utils import parse_yaml_with_source_positions + +if TYPE_CHECKING: + from dagster_dg.yaml_utils.source_position import ValueAndSourcePositionTree @click.group(name="component", cls=DgClickGroup) @@ -71,7 +75,7 @@ def _define_commands(self, cli_context: click.Context) -> None: exit_with_error("This command must be run inside a Dagster code location directory.") registry = RemoteComponentRegistry.from_dg_context(dg_context) - for key, component_type in registry.items(): + for key, component_type in registry.global_items(): command = _create_component_scaffold_subcommand(key, component_type) self.add_command(command) @@ -180,7 +184,7 @@ def scaffold_component_command( exit_with_error("This command must be run inside a Dagster code location directory.") registry = RemoteComponentRegistry.from_dg_context(dg_context) - if not registry.has(component_key): + if not registry.has_global(component_key): exit_with_error(f"No component type `{component_key}` could be resolved.") elif dg_context.has_component(component_name): exit_with_error(f"A component instance named `{component_name}` already exists.") @@ -248,6 +252,18 @@ def component_list_command(context: click.Context, **global_options: object) -> # ##### CHECK # ######################## +COMPONENT_FILE_SCHEMA = { + "type": "object", + "properties": { + "type": {"type": "string"}, + "params": {"type": "object"}, + }, +} + + +def _is_local_component(component_name: str) -> bool: + return component_name.startswith(".") + @component_group.command(name="check", 
cls=DgClickCommand) @click.argument("paths", nargs=-1, type=click.Path(exists=True)) @@ -259,10 +275,72 @@ def component_check_command( **global_options: object, ) -> None: """Check component files against their schemas, showing validation errors.""" + resolved_paths = [Path(path).absolute() for path in paths] + top_level_component_validator = Draft202012Validator(schema=COMPONENT_FILE_SCHEMA) + cli_config = normalize_cli_config(global_options, context) dg_context = DgContext.from_config_file_discovery_and_cli_config(Path.cwd(), cli_config) - try: - dg_context.external_components_command(["check", "component", *paths]) - except CalledProcessError: - sys.exit(1) + validation_errors: list[tuple[Optional[str], ValidationError, ValueAndSourcePositionTree]] = [] + + component_contents_by_dir = {} + local_component_dirs = set() + for component_dir in ( + dg_context.root_path / dg_context.root_package_name / "components" + ).iterdir(): + if resolved_paths and not any( + path == component_dir or path in component_dir.parents for path in resolved_paths + ): + continue + + component_path = component_dir / "component.yaml" + + if component_path.exists(): + text = component_path.read_text() + component_doc_tree = parse_yaml_with_source_positions( + text, filename=str(component_path) + ) + + # First, validate the top-level structure of the component file + # (type and params keys) before we try to validate the params themselves. 
+ top_level_errs = list( + top_level_component_validator.iter_errors(component_doc_tree.value) + ) + for err in top_level_errs: + validation_errors.append((None, err, component_doc_tree)) + if top_level_errs: + continue + + component_contents_by_dir[component_dir] = component_doc_tree + component_name = component_doc_tree.value.get("type") + if _is_local_component(component_name): + local_component_dirs.add(component_dir) + + # Fetch the local component types, if we need any local components + component_registry = RemoteComponentRegistry.from_dg_context( + dg_context, local_component_type_dirs=list(local_component_dirs) + ) + + for component_dir, component_doc_tree in component_contents_by_dir.items(): + component_name = component_doc_tree.value.get("type") + json_schema = ( + component_registry.get(component_dir, component_name).component_params_schema or {} + ) + + v = Draft202012Validator(json_schema) + for err in v.iter_errors(component_doc_tree.value["params"]): + validation_errors.append((component_name, err, component_doc_tree)) + + if validation_errors: + for component_name, error, component_doc_tree in validation_errors: + click.echo( + error_dict_to_formatted_error( + component_name, + error, + source_position_tree=component_doc_tree.source_position_tree, + prefix=["params"] if component_name else [], + ) + ) + context.exit(1) + else: + click.echo("All components validated successfully.") diff --git a/python_modules/libraries/dagster-dg/dagster_dg/cli/component_type.py b/python_modules/libraries/dagster-dg/dagster_dg/cli/component_type.py index 3475be7565441..280cc7591f905 100644 --- a/python_modules/libraries/dagster-dg/dagster_dg/cli/component_type.py +++ b/python_modules/libraries/dagster-dg/dagster_dg/cli/component_type.py @@ -42,7 +42,7 @@ def component_type_scaffold_command( exit_with_error("This command must be run inside a Dagster code location directory.") registry = RemoteComponentRegistry.from_dg_context(dg_context) full_component_name = 
f"{dg_context.root_package_name}.{name}" - if registry.has(full_component_name): + if registry.has_global(full_component_name): exit_with_error(f"A component type named `{name}` already exists.") scaffold_component_type(dg_context, name) @@ -66,10 +66,10 @@ def component_type_docs_command( cli_config = normalize_cli_config(global_options, context) dg_context = DgContext.from_config_file_discovery_and_cli_config(Path.cwd(), cli_config) registry = RemoteComponentRegistry.from_dg_context(dg_context) - if not registry.has(component_type): + if not registry.has_global(component_type): exit_with_error(f"No component type `{component_type}` could be resolved.") - render_markdown_in_browser(markdown_for_component_type(registry.get(component_type))) + render_markdown_in_browser(markdown_for_component_type(registry.get_global(component_type))) # ######################## @@ -96,14 +96,14 @@ def component_type_info_command( cli_config = normalize_cli_config(global_options, context) dg_context = DgContext.from_config_file_discovery_and_cli_config(Path.cwd(), cli_config) registry = RemoteComponentRegistry.from_dg_context(dg_context) - if not registry.has(component_type): + if not registry.has_global(component_type): exit_with_error(f"No component type `{component_type}` could be resolved.") elif sum([description, scaffold_params_schema, component_params_schema]) > 1: exit_with_error( "Only one of --description, --scaffold-params-schema, and --component-params-schema can be specified." 
) - component_type_metadata = registry.get(component_type) + component_type_metadata = registry.get_global(component_type) if description: if component_type_metadata.description: @@ -152,8 +152,8 @@ def component_type_list(context: click.Context, **global_options: object) -> Non cli_config = normalize_cli_config(global_options, context) dg_context = DgContext.from_config_file_discovery_and_cli_config(Path.cwd(), cli_config) registry = RemoteComponentRegistry.from_dg_context(dg_context) - for key in sorted(registry.keys()): + for key in sorted(registry.global_keys()): click.echo(key) - component_type = registry.get(key) + component_type = registry.get_global(key) if component_type.summary: click.echo(f" {component_type.summary}") diff --git a/python_modules/libraries/dagster-dg/dagster_dg/component.py b/python_modules/libraries/dagster-dg/dagster_dg/component.py index ff8212252a8d9..f19219aa9c924 100644 --- a/python_modules/libraries/dagster-dg/dagster_dg/component.py +++ b/python_modules/libraries/dagster-dg/dagster_dg/component.py @@ -1,7 +1,9 @@ import copy import json -from collections.abc import Iterable, Mapping +from collections import defaultdict +from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass +from pathlib import Path from typing import TYPE_CHECKING, Any, Optional from dagster_dg.utils import is_valid_json @@ -26,7 +28,12 @@ def key(self) -> str: class RemoteComponentRegistry: @classmethod - def from_dg_context(cls, dg_context: "DgContext") -> "RemoteComponentRegistry": + def from_dg_context( + cls, dg_context: "DgContext", local_component_type_dirs: Optional[Sequence[Path]] = None + ) -> "RemoteComponentRegistry": + """Fetches the set of available component types, including local component types for the + specified directories. Caches the result if possible. 
+ """ if dg_context.use_dg_managed_environment: dg_context.ensure_uv_lock() @@ -43,33 +50,74 @@ def from_dg_context(cls, dg_context: "DgContext") -> "RemoteComponentRegistry": dg_context.cache.set(cache_key, raw_registry_data) registry_data = json.loads(raw_registry_data) - return cls.from_dict(registry_data) + + local_component_data = [] + if local_component_type_dirs: + # TODO: cache + + raw_local_component_data = dg_context.external_components_command( + [ + "list", + "local-component-types", + *[str(path) for path in local_component_type_dirs], + ] + ) + local_component_data = json.loads(raw_local_component_data) + + return cls.from_dict( + global_component_types=registry_data, local_component_types=local_component_data + ) @classmethod - def from_dict(cls, components: dict[str, Mapping[str, Any]]) -> "RemoteComponentRegistry": + def from_dict( + cls, + global_component_types: dict[str, Mapping[str, Any]], + local_component_types: list[dict[str, Any]], + ) -> "RemoteComponentRegistry": + components_by_path = defaultdict(dict) + for entry in local_component_types: + components_by_path[entry["directory"]][entry["key"]] = RemoteComponentType( + **entry["metadata"] + ) + return RemoteComponentRegistry( - {key: RemoteComponentType(**value) for key, value in components.items()} + {key: RemoteComponentType(**value) for key, value in global_component_types.items()}, + local_components=components_by_path, ) - def __init__(self, components: dict[str, RemoteComponentType]): + def __init__( + self, + components: dict[str, RemoteComponentType], + local_components: dict[str, dict[str, RemoteComponentType]], + ): self._components: dict[str, RemoteComponentType] = copy.copy(components) + self._local_components: dict[str, dict[str, RemoteComponentType]] = copy.copy( + local_components + ) @staticmethod def empty() -> "RemoteComponentRegistry": - return RemoteComponentRegistry({}) + return RemoteComponentRegistry({}, {}) - def has(self, name: str) -> bool: + def has_global(self, 
name: str) -> bool: return name in self._components - def get(self, name: str) -> RemoteComponentType: + def get(self, path: Path, key: str) -> RemoteComponentType: + """Resolves a component type within the scope of a given component directory.""" + if key in self._components: + return self._components[key] + + return self._local_components[str(path)][key] + + def get_global(self, name: str) -> RemoteComponentType: return self._components[name] - def keys(self) -> Iterable[str]: + def global_keys(self) -> Iterable[str]: return self._components.keys() - def items(self) -> Iterable[tuple[str, RemoteComponentType]]: - for key in sorted(self.keys()): - yield key, self.get(key) + def global_items(self) -> Iterable[tuple[str, RemoteComponentType]]: + for key in sorted(self.global_keys()): + yield key, self.get_global(key) def __repr__(self) -> str: return f"<RemoteComponentRegistry {list(self._components.keys())}>" diff --git a/python_modules/libraries/dagster-dg/dagster_dg_tests/cli_tests/test_check.py b/python_modules/libraries/dagster-dg/dagster_dg_tests/cli_tests/test_check.py new file mode 100644 index 0000000000000..8314cb3c614f2 --- /dev/null +++ b/python_modules/libraries/dagster-dg/dagster_dg_tests/cli_tests/test_check.py @@ -0,0 +1,173 @@ +import contextlib +import os +import shutil +import tempfile +from collections.abc import Iterator +from contextlib import contextmanager +from pathlib import Path +from typing import Optional + +import pytest +from dagster_components.test.test_cases import ( + BASIC_INVALID_VALUE, + BASIC_MISSING_VALUE, + COMPONENT_VALIDATION_TEST_CASES, + ComponentValidationTestCase, +) +from dagster_dg.utils import discover_git_root, ensure_dagster_dg_tests_import + +ensure_dagster_dg_tests_import() +from dagster_dg_tests.utils import ProxyRunner, clear_module_from_cache + +COMPONENT_INTEGRATION_TEST_DIR = ( + Path(__file__).parent.parent.parent.parent + / "dagster-components" + / "dagster_components_tests" + / "integration_tests" + / 
"components" +) + + +@contextmanager +def new_cwd(path: str) -> Iterator[None]: + old = os.getcwd() + try: + os.chdir(path) + yield + finally: + os.chdir(old) + + +@contextlib.contextmanager +def create_code_location_from_components( + runner: ProxyRunner, *src_paths: str, local_component_defn_to_inject: Optional[Path] = None +) -> Iterator[Path]: + """Scaffolds a code location with the given components in a temporary directory, + injecting the provided local component defn into each component's __init__.py. + """ + dagster_git_repo_dir = str(discover_git_root(Path(__file__))) + with tempfile.TemporaryDirectory() as tmpdir, new_cwd(tmpdir): + runner.invoke( + "code-location", + "scaffold", + "--use-editable-dagster", + dagster_git_repo_dir, + "my_location", + ) + + code_location_dir = Path(tmpdir) / "my_location" + assert code_location_dir.exists() + + (code_location_dir / "lib").mkdir(parents=True, exist_ok=True) + (code_location_dir / "lib" / "__init__.py").touch() + for src_path in src_paths: + component_name = src_path.split("/")[-1] + + components_dir = code_location_dir / "my_location" / "components" / component_name + components_dir.mkdir(parents=True, exist_ok=True) + + origin_path = COMPONENT_INTEGRATION_TEST_DIR / src_path + + shutil.copytree(origin_path, components_dir, dirs_exist_ok=True) + if local_component_defn_to_inject: + shutil.copy(local_component_defn_to_inject, components_dir / "__init__.py") + + with clear_module_from_cache("my_location"): + yield code_location_dir + + +@pytest.mark.parametrize( + "test_case", + COMPONENT_VALIDATION_TEST_CASES, + ids=[str(case.component_path) for case in COMPONENT_VALIDATION_TEST_CASES], +) +def test_validation_cli(test_case: ComponentValidationTestCase) -> None: + """Tests that the check CLI prints rich error messages when attempting to + load components with errors. 
+ """ + with ( + ProxyRunner.test() as runner, + create_code_location_from_components( + runner, + test_case.component_path, + local_component_defn_to_inject=test_case.component_type_filepath, + ) as tmpdir, + ): + with new_cwd(str(tmpdir)): + result = runner.invoke("component", "check") + if test_case.should_error: + assert result.exit_code != 0, str(result.stdout) + + assert test_case.check_error_msg + test_case.check_error_msg(str(result.stdout)) + + else: + assert result.exit_code == 0 + + +@pytest.mark.parametrize( + "scope_check_run", + [True, False], +) +def test_validation_cli_multiple_components(scope_check_run: bool) -> None: + """Ensure that the check CLI can validate multiple components in a single code location, and + that error messages from all components are displayed. + + The parameter `scope_check_run` determines whether the check CLI is run pointing at both + components or none (defaulting to the entire workspace) - the output should be the same in + either case, this just tests that the CLI can handle multiple filters. 
+ """ + with ( + ProxyRunner.test() as runner, + create_code_location_from_components( + runner, + BASIC_MISSING_VALUE.component_path, + BASIC_INVALID_VALUE.component_path, + local_component_defn_to_inject=BASIC_MISSING_VALUE.component_type_filepath, + ) as tmpdir, + ): + with new_cwd(str(tmpdir)): + result = runner.invoke( + "component", + "check", + *( + [ + str(Path("my_location") / "components" / "basic_component_missing_value"), + str(Path("my_location") / "components" / "basic_component_invalid_value"), + ] + if scope_check_run + else [] + ), + ) + assert result.exit_code != 0, str(result.stdout) + + assert BASIC_INVALID_VALUE.check_error_msg and BASIC_MISSING_VALUE.check_error_msg + BASIC_INVALID_VALUE.check_error_msg(str(result.stdout)) + BASIC_MISSING_VALUE.check_error_msg(str(result.stdout)) + + +def test_validation_cli_multiple_components_filter() -> None: + """Ensure that the check CLI filters components to validate based on the provided paths.""" + with ( + ProxyRunner.test() as runner, + create_code_location_from_components( + runner, + BASIC_MISSING_VALUE.component_path, + BASIC_INVALID_VALUE.component_path, + local_component_defn_to_inject=BASIC_MISSING_VALUE.component_type_filepath, + ) as tmpdir, + ): + with new_cwd(str(tmpdir)): + result = runner.invoke( + "component", + "check", + str(Path("my_location") / "components" / "basic_component_missing_value"), + ) + assert result.exit_code != 0, str(result.stdout) + + assert BASIC_INVALID_VALUE.check_error_msg and BASIC_MISSING_VALUE.check_error_msg + + BASIC_MISSING_VALUE.check_error_msg(str(result.stdout)) + # We exclude the invalid value test case + with pytest.raises(AssertionError): + BASIC_INVALID_VALUE.check_error_msg(str(result.stdout)) diff --git a/python_modules/libraries/dagster-dg/dagster_dg_tests/cli_tests/test_component_type_commands.py b/python_modules/libraries/dagster-dg/dagster_dg_tests/cli_tests/test_component_type_commands.py index 955e08f73f58e..4abef709d4c93 100644 --- 
a/python_modules/libraries/dagster-dg/dagster_dg_tests/cli_tests/test_component_type_commands.py +++ b/python_modules/libraries/dagster-dg/dagster_dg_tests/cli_tests/test_component_type_commands.py @@ -32,7 +32,7 @@ def test_component_type_scaffold_success(in_deployment: bool) -> None: assert Path("foo_bar/lib/baz.py").exists() dg_context = DgContext.from_config_file_discovery_and_cli_config(Path.cwd(), {}) registry = RemoteComponentRegistry.from_dg_context(dg_context) - assert registry.has("foo_bar.baz") + assert registry.has_global("foo_bar.baz") def test_component_type_scaffold_outside_code_location_fails() -> None: @@ -71,7 +71,7 @@ def test_component_type_scaffold_succeeds_non_default_component_lib_package() -> assert Path("foo_bar/_lib/baz.py").exists() dg_context = DgContext.from_config_file_discovery_and_cli_config(Path.cwd(), {}) registry = RemoteComponentRegistry.from_dg_context(dg_context) - assert registry.has("bar.baz") + assert registry.has_global("bar.baz") def test_component_type_scaffold_fails_components_lib_package_does_not_exist() -> None: diff --git a/python_modules/libraries/dagster-dg/setup.py b/python_modules/libraries/dagster-dg/setup.py index aa2a825251696..ac7f229285053 100644 --- a/python_modules/libraries/dagster-dg/setup.py +++ b/python_modules/libraries/dagster-dg/setup.py @@ -37,6 +37,7 @@ def get_version() -> str: "click>=8", "typing_extensions>=4.4.0,<5", "markdown", + "jsonschema", "PyYAML>=5.1", ], include_package_data=True, diff --git a/python_modules/libraries/dagster-dg/tox.ini b/python_modules/libraries/dagster-dg/tox.ini index 83c7d67719446..f6d7724c5208f 100644 --- a/python_modules/libraries/dagster-dg/tox.ini +++ b/python_modules/libraries/dagster-dg/tox.ini @@ -18,6 +18,7 @@ deps = allowlist_externals = /bin/bash uv + jsonschema commands = !windows: /bin/bash -c '! pip list --exclude-editable | grep -e dagster' pytest ./dagster_dg_tests -vv {posargs}