diff --git a/python_modules/libraries/dagster-components/dagster_components/cli/check.py b/python_modules/libraries/dagster-components/dagster_components/cli/check.py index a7ba12668062b..ed3881c2265d2 100644 --- a/python_modules/libraries/dagster-components/dagster_components/cli/check.py +++ b/python_modules/libraries/dagster-components/dagster_components/cli/check.py @@ -16,8 +16,6 @@ get_component_type_name, ) from dagster_components.core.component_defs_builder import ( - component_type_from_yaml_decl, - load_components_from_context, path_to_decl_node, resolve_decl_node_to_yaml_decls, ) @@ -186,11 +184,9 @@ def check_component_command(ctx: click.Context, paths: Sequence[str]) -> None: templated_value_resolver=TemplatedValueResolver.default(), ) try: - load_components_from_context(clc) + decl_node.load(clc) except ValidationError as e: - component_type = component_type_from_yaml_decl( - context.component_registry, yaml_decl - ) + component_type = yaml_decl.get_component_type(context.component_registry) validation_errors.append((component_type, e)) if validation_errors: diff --git a/python_modules/libraries/dagster-components/dagster_components/core/component.py b/python_modules/libraries/dagster-components/dagster_components/core/component.py index d2352cecb4a9a..ae3a0ff11f7d2 100644 --- a/python_modules/libraries/dagster-components/dagster_components/core/component.py +++ b/python_modules/libraries/dagster-components/dagster_components/core/component.py @@ -28,7 +28,9 @@ from dagster_components.core.schema.resolver import TemplatedValueResolver -class ComponentDeclNode: ... +class ComponentDeclNode(ABC): + @abstractmethod + def load(self, context: "ComponentLoadContext") -> Sequence["Component"]: ... class Component(ABC): diff --git a/python_modules/libraries/dagster-components/dagster_components/core/component_decl_builder.py b/python_modules/libraries/dagster-components/dagster_components/core/component_decl_builder.py index 96c048c4fdd6e..9eadd37fd42b0 100644 --- a/python_modules/libraries/dagster-components/dagster_components/core/component_decl_builder.py +++ b/python_modules/libraries/dagster-components/dagster_components/core/component_decl_builder.py @@ -1,3 +1,4 @@ +import inspect from collections.abc import Mapping, Sequence from pathlib import Path from typing import Any, Optional, TypeVar, Union @@ -12,7 +13,15 @@ from dagster._utils.yaml_utils import parse_yaml_with_source_positions from pydantic import BaseModel, TypeAdapter -from dagster_components.core.component import ComponentDeclNode, ComponentLoadContext +from dagster_components.cli.check import get_component_type_name +from dagster_components.core.component import ( + Component, + ComponentDeclNode, + ComponentLoadContext, + ComponentTypeRegistry, + is_registered_component_type, +) +from dagster_components.core.component_defs_builder import load_module_from_path class ComponentFileModel(BaseModel): @@ -30,7 +39,16 @@ class YamlComponentDecl(ComponentDeclNode): source_position_tree: Optional[SourcePositionTree] = None @staticmethod - def from_path(component_file_path: Path) -> "YamlComponentDecl": + def component_file_path(path: Path) -> Path: + return path / "component.yaml" + + @staticmethod + def exists_at(path: Path) -> bool: + return YamlComponentDecl.component_file_path(path).exists() + + @staticmethod + def from_path(path: Path) -> "YamlComponentDecl": + component_file_path = YamlComponentDecl.component_file_path(path) parsed = parse_yaml_with_source_positions( component_file_path.read_text(), str(component_file_path) ) @@ -39,11 +57,36 @@ def from_path(component_file_path: Path) -> "YamlComponentDecl": ) return YamlComponentDecl( - path=component_file_path.parent, + path=path, component_file_model=obj, source_position_tree=parsed.source_position_tree, ) + def get_component_type(self, registry: ComponentTypeRegistry) -> type[Component]: + parsed_defs = self.component_file_model + if parsed_defs.type.startswith("."): + component_registry_key = parsed_defs.type[1:] + + # Iterate over Python files in the folder + for py_file in self.path.glob("*.py"): + module_name = py_file.stem + + module = load_module_from_path(module_name, self.path / f"{module_name}.py") + + for _name, obj in inspect.getmembers(module, inspect.isclass): + assert isinstance(obj, type) + if ( + is_registered_component_type(obj) + and get_component_type_name(obj) == component_registry_key + ): + return obj + + raise Exception( + f"Could not find component type {component_registry_key} in {self.path}" + ) + + return registry.get(parsed_defs.type) + def get_params(self, context: ComponentLoadContext, params_schema: type[T]) -> T: with pushd(str(self.path)): raw_params = self.component_file_model.params @@ -60,12 +103,26 @@ def get_params(self, context: ComponentLoadContext, params_schema: type[T]) -> T else: return TypeAdapter(params_schema).validate_python(preprocessed_params) + def load(self, context: ComponentLoadContext) -> Sequence[Component]: + component_type = self.get_component_type(context.registry) + component_schema = component_type.get_schema() + context = context.with_rendering_scope(component_type.get_additional_scope()) + loaded_params = self.get_params(context, component_schema) if component_schema else None + return [component_type.load(loaded_params, context)] + @record class ComponentFolder(ComponentDeclNode): path: Path sub_decls: Sequence[Union[YamlComponentDecl, "ComponentFolder"]] + def load(self, context: ComponentLoadContext) -> Sequence[Component]: + components = [] + for sub_decl in self.sub_decls: + sub_context = context.for_decl_node(sub_decl) + components.extend(sub_decl.load(sub_context)) + return components + def path_to_decl_node(path: Path) -> Optional[ComponentDeclNode]: # right now, we only support two types of components, both of which are folders @@ -75,10 +132,8 @@ def path_to_decl_node(path: Path) -> Optional[ComponentDeclNode]: if not path.is_dir(): return None - component_path = path / "component.yaml" - - if component_path.exists(): - return YamlComponentDecl.from_path(component_path) + if YamlComponentDecl.exists_at(path): + return YamlComponentDecl.from_path(path) subs = [] for subpath in path.iterdir(): diff --git a/python_modules/libraries/dagster-components/dagster_components/core/component_defs_builder.py b/python_modules/libraries/dagster-components/dagster_components/core/component_defs_builder.py index 4ae5e868f15bf..153d046cc1e3a 100644 --- a/python_modules/libraries/dagster-components/dagster_components/core/component_defs_builder.py +++ b/python_modules/libraries/dagster-components/dagster_components/core/component_defs_builder.py @@ -1,6 +1,5 @@ import importlib import importlib.util -import inspect from collections.abc import Mapping, Sequence from pathlib import Path from types import ModuleType @@ -13,8 +12,6 @@ ComponentLoadContext, ComponentTypeRegistry, TemplatedValueResolver, - get_component_type_name, - is_registered_component_type, ) from dagster_components.core.component_decl_builder import ( ComponentDeclNode, @@ -53,57 +50,12 @@ def resolve_decl_node_to_yaml_decls(decl: ComponentDeclNode) -> list[YamlCompone raise NotImplementedError(f"Unknown component type {decl}") -def load_components_from_context(context: ComponentLoadContext) -> Sequence[Component]: - node = context.decl_node - if isinstance(node, YamlComponentDecl): - component_type = component_type_from_yaml_decl(context.registry, node) - component_schema = component_type.get_schema() - context = context.with_rendering_scope(component_type.get_additional_scope()) - loaded_params = node.get_params(context, component_schema) if component_schema else None - return [component_type.load(loaded_params, context)] - elif isinstance(node, ComponentFolder): - components = [] - for sub_decl in node.sub_decls: - components.extend(load_components_from_context(context.for_decl_node(sub_decl))) - return components - - raise NotImplementedError(f"Unknown component type {node}") - - -def component_type_from_yaml_decl( - registry: ComponentTypeRegistry, decl_node: YamlComponentDecl -) -> type[Component]: - parsed_defs = decl_node.component_file_model - if parsed_defs.type.startswith("."): - component_registry_key = parsed_defs.type[1:] - - # Iterate over Python files in the folder - for py_file in decl_node.path.glob("*.py"): - module_name = py_file.stem - - module = load_module_from_path(module_name, decl_node.path / f"{module_name}.py") - - for _name, obj in inspect.getmembers(module, inspect.isclass): - assert isinstance(obj, type) - if ( - is_registered_component_type(obj) - and get_component_type_name(obj) == component_registry_key - ): - return obj - - raise Exception( - f"Could not find component type {component_registry_key} in {decl_node.path}" - ) - - return registry.get(parsed_defs.type) - - def build_components_from_component_folder( context: ComponentLoadContext, path: Path ) -> Sequence[Component]: component_folder = path_to_decl_node(path) assert isinstance(component_folder, ComponentFolder) - return load_components_from_context(context.for_decl_node(component_folder)) + return component_folder.load(context.for_decl_node(component_folder)) def build_defs_from_component_path( @@ -122,7 +74,7 @@ def build_defs_from_component_path( decl_node=decl_node, templated_value_resolver=TemplatedValueResolver.default(), ) - components = load_components_from_context(context) + components = decl_node.load(context) return defs_from_components(resources=resources, context=context, components=components)