Skip to content

Commit

Permalink
import fix
Browse files Browse the repository at this point in the history
  • Loading branch information
benpankow committed Jan 22, 2025
1 parent c651b5b commit c5477bc
Show file tree
Hide file tree
Showing 9 changed files with 123 additions and 28 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
type: dagster_components.test.simple_asset

params:
asset_key: "test"
value: {}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@ class MyComponentSchema(ComponentSchemaBaseModel):
@component_type
class MyComponent(Component):
name = "my_component"
params_schema = MyComponentSchema

@classmethod
def get_schema(cls):
return MyComponentSchema

@classmethod
def load(cls, context: ComponentLoadContext) -> Self:
context.load_params(cls.params_schema)
context.load_params(MyComponentSchema)
return cls()

def build_defs(self, context: ComponentLoadContext) -> Definitions:
Expand All @@ -40,11 +43,14 @@ class MyNestedComponentSchema(ComponentSchemaBaseModel):
@component_type
class MyNestedComponent(Component):
name = "my_nested_component"
params_schema = MyNestedComponentSchema

@classmethod
def get_schema(cls):
return MyNestedComponentSchema

@classmethod
def load(cls, context: ComponentLoadContext) -> Self:
context.load_params(cls.params_schema)
context.load_params(MyNestedComponentSchema)
return cls()

def build_defs(self, context: ComponentLoadContext) -> Definitions:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ class ComponentValidationTestCase:
"""

component_path: str
component_type_filepath: Path
component_type_filepath: Optional[Path]
should_error: bool
validate_error_msg: Optional[Callable[[str], None]] = None
validate_error_msg_additional_cli: Optional[Callable[[str], None]] = None
check_error_msg: Optional[Callable[[str], None]] = None


def msg_includes_all_of(*substrings: str) -> Callable[[str], None]:
Expand All @@ -31,15 +31,22 @@ def _validate_error_msg(msg: str) -> None:
validate_error_msg=msg_includes_all_of(
"component.yaml:5", "params.an_int", "Input should be a valid integer"
),
check_error_msg=msg_includes_all_of(
"component.yaml:5",
"params.an_int",
"{} is not of type 'integer'",
),
)

BASIC_MISSING_VALUE = ComponentValidationTestCase(
component_path="validation/basic_component_missing_value",
component_type_filepath=Path(__file__).parent / "basic_components.py",
should_error=True,
validate_error_msg=msg_includes_all_of("component.yaml:3", "params.an_int", "required"),
validate_error_msg_additional_cli=msg_includes_all_of(
"Field `an_int` is required but not provided"
check_error_msg=msg_includes_all_of(
"component.yaml:3",
"params",
"'an_int' is a required property",
),
)

Expand All @@ -51,13 +58,30 @@ def _validate_error_msg(msg: str) -> None:
),
BASIC_INVALID_VALUE,
BASIC_MISSING_VALUE,
ComponentValidationTestCase(
component_path="validation/simple_asset_invalid_value",
component_type_filepath=None,
should_error=True,
validate_error_msg=msg_includes_all_of(
"component.yaml:5", "params.value", "Input should be a valid string"
),
check_error_msg=msg_includes_all_of(
"component.yaml:5",
"params.value",
"{} is not of type 'string'",
),
),
ComponentValidationTestCase(
component_path="validation/basic_component_extra_value",
component_type_filepath=Path(__file__).parent / "basic_components.py",
should_error=True,
validate_error_msg=msg_includes_all_of(
"component.yaml:7", "params.a_bool", "Extra inputs are not permitted"
),
check_error_msg=msg_includes_all_of(
"component.yaml:3",
"'a_bool' was unexpected",
),
),
ComponentValidationTestCase(
component_path="validation/nested_component_invalid_values",
Expand All @@ -71,6 +95,14 @@ def _validate_error_msg(msg: str) -> None:
"params.nested.baz.a_string",
"Input should be a valid string",
),
check_error_msg=msg_includes_all_of(
"component.yaml:7",
"params.nested.foo.an_int",
"{} is not of type 'integer'",
"component.yaml:12",
"params.nested.baz.a_string",
"{} is not of type 'string'",
),
),
ComponentValidationTestCase(
component_path="validation/nested_component_missing_values",
Expand All @@ -79,8 +111,13 @@ def _validate_error_msg(msg: str) -> None:
validate_error_msg=msg_includes_all_of(
"component.yaml:5", "params.nested.foo.an_int", "required"
),
validate_error_msg_additional_cli=msg_includes_all_of(
"Field `a_string` is required but not provided"
check_error_msg=msg_includes_all_of(
"component.yaml:5",
"params.nested.foo",
"'an_int' is a required property",
"component.yaml:10",
"params.nested.baz",
"'a_string' is a required property",
),
),
ComponentValidationTestCase(
Expand All @@ -94,6 +131,14 @@ def _validate_error_msg(msg: str) -> None:
"component.yaml:15",
"params.nested.baz.another_bool",
),
check_error_msg=msg_includes_all_of(
"component.yaml:5",
"params.nested.foo",
"'a_bool' was unexpected",
"component.yaml:12",
"params.nested.baz",
"'another_bool' was unexpected",
),
),
ComponentValidationTestCase(
component_path="validation/invalid_component_file_model",
Expand All @@ -107,5 +152,13 @@ def _validate_error_msg(msg: str) -> None:
"params",
"Input should be an object",
),
check_error_msg=msg_includes_all_of(
"component.yaml:1",
"type",
"{} is not of type 'string'",
"component.yaml:3",
"params",
"'asdfasdf' is not of type 'object'",
),
),
]
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def test_validation_cli(test_case: ComponentValidationTestCase) -> None:
assert test_case.validate_error_msg
test_case.validate_error_msg(str(result.stdout))

if test_case.validate_error_msg_additional_cli:
test_case.validate_error_msg_additional_cli(str(result.stdout))
if test_case.check_error_msg:
test_case.check_error_msg(str(result.stdout))
else:
assert result.exit_code == 0, str(result.stdout)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Optional

from dagster._core.definitions.definitions_class import Definitions

Expand All @@ -10,7 +11,7 @@


def load_test_component_defs_inject_component(
src_path: str, local_component_defn_to_inject: Path
src_path: str, local_component_defn_to_inject: Optional[Path]
) -> Definitions:
"""Loads a component from a test component project, making the provided local component defn
available in that component's __init__.py.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,20 @@ def format_indented_error_msg(col: int, msg: str) -> str:


def error_dict_to_formatted_error(
component_name: str, error_details: ValidationError, source_position_tree: SourcePositionTree
component_name: Optional[str],
error_details: ValidationError,
source_position_tree: SourcePositionTree,
prefix: Sequence[str] = (),
) -> str:
source_position, source_position_path = source_position_tree.lookup_closest_and_path(
["params", *error_details.absolute_path], trace=None
[*prefix, *error_details.absolute_path], trace=None
)

# Retrieves dotted path representation of the location of the error in the YAML file, e.g.
# params.nested.foo.an_int
location = ".".join(str(part) for part in error_details.absolute_path).split(" at ")[0]
location = ".".join([*prefix, *[str(part) for part in error_details.absolute_path]]).split(
" at "
)[0]

# Find the first source position that has a different start line than the current source position
# This is e.g. the parent json key of the current source position
Expand Down
34 changes: 30 additions & 4 deletions python_modules/libraries/dagster-dg/dagster_dg/cli/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,18 @@ def component_list_command(context: click.Context, **global_options: object) ->
# ########################
from dagster._utils.yaml_utils import parse_yaml_with_source_positions

COMPONENT_FILE_SCHEMA = {
"type": "object",
"properties": {
"type": {"type": "string"},
"params": {"type": "object"},
},
}


def _is_local_component(component_name: str) -> bool:
return component_name.startswith(".")


@component_group.command(name="check", cls=DgClickCommand)
@click.argument("paths", nargs=-1, type=click.Path(exists=True))
Expand All @@ -260,12 +272,14 @@ def component_check_command(
**global_options: object,
) -> None:
"""Check component files against their schemas, showing validation errors."""
top_level_component_validator = Draft202012Validator(schema=COMPONENT_FILE_SCHEMA)

cli_config = normalize_cli_config(global_options, context)
dg_context = DgContext.from_config_file_discovery_and_cli_config(Path.cwd(), cli_config)

component_registry = RemoteComponentRegistry.from_dg_context(dg_context)

validation_errors: list[tuple[str, ValidationError]] = []
validation_errors: list[tuple[Optional[str], ValidationError]] = []

component_contents_by_dir = {}
local_components = set()
Expand All @@ -279,17 +293,28 @@ def component_check_command(
component_doc_tree = parse_yaml_with_source_positions(
text, filename=str(component_path)
)
component_contents_by_dir[component_dir] = component_doc_tree

# First, validate the top-level structure of the component file
# (type and params keys) before we try to validate the params themselves.
top_level_errs = list(
top_level_component_validator.iter_errors(component_doc_tree.value)
)
for err in top_level_errs:
validation_errors.append((None, err))
if top_level_errs:
continue

component_contents_by_dir[component_dir] = component_doc_tree
component_name = component_doc_tree.value.get("type")
if component_name.startswith("."):
if _is_local_component(component_name):
local_components.add(component_dir)

# Fetch the local component types, if we need any local components
local_component_types = LocalComponentTypes.from_dg_context(dg_context, list(local_components))

for component_dir, component_doc_tree in component_contents_by_dir.items():
component_name = component_doc_tree.value.get("type")
if component_name.startswith("."):
if _is_local_component(component_name):
json_schema = (
local_component_types.get(component_dir, component_name).component_params_schema
or {}
Expand All @@ -308,6 +333,7 @@ def component_check_command(
component_name,
error,
source_position_tree=component_doc_tree.source_position_tree,
prefix=["params"] if component_name else [],
)
)
context.exit(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
COMPONENT_VALIDATION_TEST_CASES,
ComponentValidationTestCase,
)
from dagster_components_tests.integration_tests.validation_tests.utils import (
create_code_location_from_components,
)
from dagster_components_tests.utils import create_code_location_from_components
from dagster_dg.utils import ensure_dagster_dg_tests_import
from dagster_dg_tests.utils import ProxyRunner

Expand Down Expand Up @@ -36,11 +34,9 @@ def test_validation_cli(test_case: ComponentValidationTestCase) -> None:
if test_case.should_error:
assert result.exit_code != 0, str(result.stdout)

assert test_case.validate_error_msg
test_case.validate_error_msg(str(result.stdout))
assert test_case.check_error_msg
test_case.check_error_msg(str(result.stdout))

if test_case.validate_error_msg_additional_cli:
test_case.validate_error_msg_additional_cli(str(result.stdout))
else:
assert result.exit_code == 0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,10 @@ def invoke(self, *args: str):

# For some reason the context setting `max_content_width` is not respected when using the
# CliRunner, so we have to set it manually.
return self.original.invoke(dg_cli, all_args, terminal_width=DG_CLI_MAX_OUTPUT_WIDTH)
result = self.original.invoke(dg_cli, all_args, terminal_width=DG_CLI_MAX_OUTPUT_WIDTH)
if result.exception:
traceback.print_exception(*result.exc_info)
return result

@contextmanager
def isolated_filesystem(self) -> Iterator[None]:
Expand Down

0 comments on commit c5477bc

Please sign in to comment.