Skip to content

Commit

Permalink
[components] Refactor component load behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
OwenKephart committed Jan 16, 2025
1 parent b402c39 commit 977ee0c
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
get_component_type_name,
)
from dagster_components.core.component_defs_builder import (
component_type_from_yaml_decl,
load_components_from_context,
path_to_decl_node,
resolve_decl_node_to_yaml_decls,
)
Expand Down Expand Up @@ -186,11 +184,9 @@ def check_component_command(ctx: click.Context, paths: Sequence[str]) -> None:
templated_value_resolver=TemplatedValueResolver.default(),
)
try:
load_components_from_context(clc)
decl_node.load(clc)
except ValidationError as e:
component_type = component_type_from_yaml_decl(
context.component_registry, yaml_decl
)
component_type = yaml_decl.get_component_type(context.component_registry)
validation_errors.append((component_type, e))

if validation_errors:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
from dagster_components.core.schema.resolver import TemplatedValueResolver


class ComponentDeclNode: ...
class ComponentDeclNode(ABC):
@abstractmethod
def load(self, context: "ComponentLoadContext") -> Sequence["Component"]: ...


class Component(ABC):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from collections.abc import Mapping, Sequence
from pathlib import Path
from typing import Any, Optional, TypeVar, Union
Expand All @@ -12,7 +13,15 @@
from dagster._utils.yaml_utils import parse_yaml_with_source_positions
from pydantic import BaseModel, TypeAdapter

from dagster_components.core.component import ComponentDeclNode, ComponentLoadContext
from dagster_components.cli.check import get_component_type_name
from dagster_components.core.component import (
Component,
ComponentDeclNode,
ComponentLoadContext,
ComponentTypeRegistry,
is_registered_component_type,
)
from dagster_components.core.component_defs_builder import load_module_from_path


class ComponentFileModel(BaseModel):
Expand All @@ -30,7 +39,16 @@ class YamlComponentDecl(ComponentDeclNode):
source_position_tree: Optional[SourcePositionTree] = None

@staticmethod
def from_path(component_file_path: Path) -> "YamlComponentDecl":
def component_file_path(path: Path) -> Path:
return path / "component.yaml"

@staticmethod
def exists_at(path: Path) -> bool:
return YamlComponentDecl.component_file_path(path).exists()

@staticmethod
def from_path(path: Path) -> "YamlComponentDecl":
component_file_path = YamlComponentDecl.component_file_path(path)
parsed = parse_yaml_with_source_positions(
component_file_path.read_text(), str(component_file_path)
)
Expand All @@ -39,11 +57,36 @@ def from_path(component_file_path: Path) -> "YamlComponentDecl":
)

return YamlComponentDecl(
path=component_file_path.parent,
path=path,
component_file_model=obj,
source_position_tree=parsed.source_position_tree,
)

def get_component_type(self, registry: ComponentTypeRegistry) -> type[Component]:
parsed_defs = self.component_file_model
if parsed_defs.type.startswith("."):
component_registry_key = parsed_defs.type[1:]

# Iterate over Python files in the folder
for py_file in self.path.glob("*.py"):
module_name = py_file.stem

module = load_module_from_path(module_name, self.path / f"{module_name}.py")

for _name, obj in inspect.getmembers(module, inspect.isclass):
assert isinstance(obj, type)
if (
is_registered_component_type(obj)
and get_component_type_name(obj) == component_registry_key
):
return obj

raise Exception(
f"Could not find component type {component_registry_key} in {self.path}"
)

return registry.get(parsed_defs.type)

def get_params(self, context: ComponentLoadContext, params_schema: type[T]) -> T:
with pushd(str(self.path)):
raw_params = self.component_file_model.params
Expand All @@ -60,12 +103,26 @@ def get_params(self, context: ComponentLoadContext, params_schema: type[T]) -> T
else:
return TypeAdapter(params_schema).validate_python(preprocessed_params)

def load(self, context: ComponentLoadContext) -> Sequence[Component]:
component_type = self.get_component_type(context.registry)
component_schema = component_type.get_schema()
context = context.with_rendering_scope(component_type.get_additional_scope())
loaded_params = self.get_params(context, component_schema) if component_schema else None
return [component_type.load(loaded_params, context)]


@record
class ComponentFolder(ComponentDeclNode):
path: Path
sub_decls: Sequence[Union[YamlComponentDecl, "ComponentFolder"]]

def load(self, context: ComponentLoadContext) -> Sequence[Component]:
components = []
for sub_decl in self.sub_decls:
sub_context = context.for_decl_node(sub_decl)
components.extend(sub_decl.load(sub_context))
return components


def path_to_decl_node(path: Path) -> Optional[ComponentDeclNode]:
# right now, we only support two types of components, both of which are folders
Expand All @@ -75,10 +132,8 @@ def path_to_decl_node(path: Path) -> Optional[ComponentDeclNode]:
if not path.is_dir():
return None

component_path = path / "component.yaml"

if component_path.exists():
return YamlComponentDecl.from_path(component_path)
if YamlComponentDecl.exists_at(path):
return YamlComponentDecl.from_path(path)

subs = []
for subpath in path.iterdir():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import importlib
import importlib.util
import inspect
from collections.abc import Mapping, Sequence
from pathlib import Path
from types import ModuleType
Expand All @@ -13,8 +12,6 @@
ComponentLoadContext,
ComponentTypeRegistry,
TemplatedValueResolver,
get_component_type_name,
is_registered_component_type,
)
from dagster_components.core.component_decl_builder import (
ComponentDeclNode,
Expand Down Expand Up @@ -53,57 +50,12 @@ def resolve_decl_node_to_yaml_decls(decl: ComponentDeclNode) -> list[YamlCompone
raise NotImplementedError(f"Unknown component type {decl}")


def load_components_from_context(context: ComponentLoadContext) -> Sequence[Component]:
node = context.decl_node
if isinstance(node, YamlComponentDecl):
component_type = component_type_from_yaml_decl(context.registry, node)
component_schema = component_type.get_schema()
context = context.with_rendering_scope(component_type.get_additional_scope())
loaded_params = node.get_params(context, component_schema) if component_schema else None
return [component_type.load(loaded_params, context)]
elif isinstance(node, ComponentFolder):
components = []
for sub_decl in node.sub_decls:
components.extend(load_components_from_context(context.for_decl_node(sub_decl)))
return components

raise NotImplementedError(f"Unknown component type {node}")


def component_type_from_yaml_decl(
registry: ComponentTypeRegistry, decl_node: YamlComponentDecl
) -> type[Component]:
parsed_defs = decl_node.component_file_model
if parsed_defs.type.startswith("."):
component_registry_key = parsed_defs.type[1:]

# Iterate over Python files in the folder
for py_file in decl_node.path.glob("*.py"):
module_name = py_file.stem

module = load_module_from_path(module_name, decl_node.path / f"{module_name}.py")

for _name, obj in inspect.getmembers(module, inspect.isclass):
assert isinstance(obj, type)
if (
is_registered_component_type(obj)
and get_component_type_name(obj) == component_registry_key
):
return obj

raise Exception(
f"Could not find component type {component_registry_key} in {decl_node.path}"
)

return registry.get(parsed_defs.type)


def build_components_from_component_folder(
context: ComponentLoadContext, path: Path
) -> Sequence[Component]:
component_folder = path_to_decl_node(path)
assert isinstance(component_folder, ComponentFolder)
return load_components_from_context(context.for_decl_node(component_folder))
return component_folder.load(context.for_decl_node(component_folder))


def build_defs_from_component_path(
Expand All @@ -122,7 +74,7 @@ def build_defs_from_component_path(
decl_node=decl_node,
templated_value_resolver=TemplatedValueResolver.default(),
)
components = load_components_from_context(context)
components = decl_node.load(context)
return defs_from_components(resources=resources, context=context, components=components)


Expand Down

0 comments on commit 977ee0c

Please sign in to comment.