Skip to content

Commit

Permalink
spin
Browse files Browse the repository at this point in the history
  • Loading branch information
benpankow committed Jan 23, 2025
1 parent 7bf5ec2 commit a873ee3
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 59 deletions.
26 changes: 11 additions & 15 deletions python_modules/libraries/dagster-dg/dagster_dg/cli/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from dagster_dg.cli.check_utils import error_dict_to_formatted_error
from dagster_dg.cli.global_options import dg_global_options
from dagster_dg.component import LocalComponentTypes, RemoteComponentRegistry, RemoteComponentType
from dagster_dg.component import RemoteComponentRegistry, RemoteComponentType
from dagster_dg.config import (
get_config_from_cli_context,
has_config_on_cli_context,
Expand Down Expand Up @@ -75,7 +75,7 @@ def _define_commands(self, cli_context: click.Context) -> None:
exit_with_error("This command must be run inside a Dagster code location directory.")

registry = RemoteComponentRegistry.from_dg_context(dg_context)
for key, component_type in registry.items():
for key, component_type in registry.global_items():
command = _create_component_scaffold_subcommand(key, component_type)
self.add_command(command)

Expand Down Expand Up @@ -184,7 +184,7 @@ def scaffold_component_command(
exit_with_error("This command must be run inside a Dagster code location directory.")

registry = RemoteComponentRegistry.from_dg_context(dg_context)
if not registry.has(component_key):
if not registry.has_global(component_key):
exit_with_error(f"No component type `{component_key}` could be resolved.")
elif dg_context.has_component(component_name):
exit_with_error(f"A component instance named `{component_name}` already exists.")
Expand Down Expand Up @@ -281,12 +281,10 @@ def component_check_command(
cli_config = normalize_cli_config(global_options, context)
dg_context = DgContext.from_config_file_discovery_and_cli_config(Path.cwd(), cli_config)

component_registry = RemoteComponentRegistry.from_dg_context(dg_context)

validation_errors: list[tuple[Optional[str], ValidationError, ValueAndSourcePositionTree]] = []

component_contents_by_dir = {}
local_components = set()
local_component_dirs = set()
for component_dir in (
dg_context.root_path / dg_context.root_package_name / "components"
).iterdir():
Expand Down Expand Up @@ -316,20 +314,18 @@ def component_check_command(
component_contents_by_dir[component_dir] = component_doc_tree
component_name = component_doc_tree.value.get("type")
if _is_local_component(component_name):
local_components.add(component_dir)
local_component_dirs.add(component_dir)

# Fetch the local component types, if we need any local components
local_component_types = LocalComponentTypes.from_dg_context(dg_context, list(local_components))
component_registry = RemoteComponentRegistry.from_dg_context(
dg_context, local_component_type_dirs=list(local_component_dirs)
)

for component_dir, component_doc_tree in component_contents_by_dir.items():
component_name = component_doc_tree.value.get("type")
if _is_local_component(component_name):
json_schema = (
local_component_types.get(component_dir, component_name).component_params_schema
or {}
)
else:
json_schema = component_registry.get(component_name).component_params_schema or {}
json_schema = (
component_registry.get(component_dir, component_name).component_params_schema or {}
)

v = Draft202012Validator(json_schema) # type: ignore
for err in v.iter_errors(component_doc_tree.value["params"]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def component_type_scaffold_command(
exit_with_error("This command must be run inside a Dagster code location directory.")
registry = RemoteComponentRegistry.from_dg_context(dg_context)
full_component_name = f"{dg_context.root_package_name}.{name}"
if registry.has(full_component_name):
if registry.has_global(full_component_name):
exit_with_error(f"A component type named `{name}` already exists.")

scaffold_component_type(dg_context, name)
Expand All @@ -66,7 +66,7 @@ def component_type_docs_command(
cli_config = normalize_cli_config(global_options, context)
dg_context = DgContext.from_config_file_discovery_and_cli_config(Path.cwd(), cli_config)
registry = RemoteComponentRegistry.from_dg_context(dg_context)
if not registry.has(component_type):
if not registry.has_global(component_type):
exit_with_error(f"No component type `{component_type}` could be resolved.")

render_markdown_in_browser(markdown_for_component_type(registry.get(component_type)))
Expand Down Expand Up @@ -96,14 +96,14 @@ def component_type_info_command(
cli_config = normalize_cli_config(global_options, context)
dg_context = DgContext.from_config_file_discovery_and_cli_config(Path.cwd(), cli_config)
registry = RemoteComponentRegistry.from_dg_context(dg_context)
if not registry.has(component_type):
if not registry.has_global(component_type):
exit_with_error(f"No component type `{component_type}` could be resolved.")
elif sum([description, scaffold_params_schema, component_params_schema]) > 1:
exit_with_error(
"Only one of --description, --scaffold-params-schema, and --component-params-schema can be specified."
)

component_type_metadata = registry.get(component_type)
component_type_metadata = registry.get_global(component_type)

if description:
if component_type_metadata.description:
Expand Down Expand Up @@ -152,8 +152,8 @@ def component_type_list(context: click.Context, **global_options: object) -> Non
cli_config = normalize_cli_config(global_options, context)
dg_context = DgContext.from_config_file_discovery_and_cli_config(Path.cwd(), cli_config)
registry = RemoteComponentRegistry.from_dg_context(dg_context)
for key in sorted(registry.keys()):
for key in sorted(registry.global_keys()):
click.echo(key)
component_type = registry.get(key)
component_type = registry.get_global(key)
if component_type.summary:
click.echo(f" {component_type.summary}")
95 changes: 59 additions & 36 deletions python_modules/libraries/dagster-dg/dagster_dg/component.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import json
from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from pathlib import Path
Expand All @@ -25,33 +26,14 @@ def key(self) -> str:
return self.name


class LocalComponentTypes:
@classmethod
def from_dg_context(cls, dg_context: "DgContext", component_dirs: Sequence[Path]):
# TODO: cache

raw_local_component_data = dg_context.external_components_command(
["list", "local-component-types", *component_dirs]
)
local_component_data = json.loads(raw_local_component_data)

components_by_path = {str(path): {} for path in component_dirs}
for entry in local_component_data:
components_by_path[entry["directory"]][entry["key"]] = RemoteComponentType(
**entry["metadata"]
)
return cls(components_by_path)

def __init__(self, components_by_path: dict[str, dict[str, RemoteComponentType]]):
self._components_by_path = copy.copy(components_by_path)

def get(self, path: Path, key: str) -> RemoteComponentType:
return self._components_by_path[str(path)][key]


class RemoteComponentRegistry:
@classmethod
def from_dg_context(cls, dg_context: "DgContext") -> "RemoteComponentRegistry":
def from_dg_context(
cls, dg_context: "DgContext", local_component_type_dirs: Optional[Sequence[Path]] = None
) -> "RemoteComponentRegistry":
"""Fetches the set of available component types, including local component types for the
specified directories. Caches the result if possible.
"""
if dg_context.use_dg_managed_environment:
dg_context.ensure_uv_lock()

Expand All @@ -68,33 +50,74 @@ def from_dg_context(cls, dg_context: "DgContext") -> "RemoteComponentRegistry":
dg_context.cache.set(cache_key, raw_registry_data)

registry_data = json.loads(raw_registry_data)
return cls.from_dict(registry_data)

local_component_data = []
if local_component_type_dirs:
# TODO: cache

raw_local_component_data = dg_context.external_components_command(
[
"list",
"local-component-types",
*[str(path) for path in local_component_type_dirs],
]
)
local_component_data = json.loads(raw_local_component_data)

return cls.from_dict(
global_component_types=registry_data, local_component_types=local_component_data
)

@classmethod
def from_dict(cls, components: dict[str, Mapping[str, Any]]) -> "RemoteComponentRegistry":
def from_dict(
cls,
global_component_types: dict[str, Mapping[str, Any]],
local_component_types: list[dict[str, Any]],
) -> "RemoteComponentRegistry":
components_by_path = defaultdict(dict)
for entry in local_component_types:
components_by_path[entry["directory"]][entry["key"]] = RemoteComponentType(
**entry["metadata"]
)

return RemoteComponentRegistry(
{key: RemoteComponentType(**value) for key, value in components.items()}
{key: RemoteComponentType(**value) for key, value in global_component_types.items()},
local_components=components_by_path,
)

def __init__(self, components: dict[str, RemoteComponentType]):
def __init__(
self,
components: dict[str, RemoteComponentType],
local_components: dict[str, dict[str, RemoteComponentType]],
):
self._components: dict[str, RemoteComponentType] = copy.copy(components)
self._local_components: dict[str, dict[str, RemoteComponentType]] = copy.copy(
local_components
)

@staticmethod
def empty() -> "RemoteComponentRegistry":
return RemoteComponentRegistry({})
return RemoteComponentRegistry({}, {})

def has(self, name: str) -> bool:
def has_global(self, name: str) -> bool:
return name in self._components

def get(self, name: str) -> RemoteComponentType:
def get(self, path: Path, key: str) -> RemoteComponentType:
"""Resolves a component type within the scope of a given component directory."""
if key in self._components:
return self._components[key]

return self._local_components[str(path)][key]

def get_global(self, name: str) -> RemoteComponentType:
return self._components[name]

def keys(self) -> Iterable[str]:
def global_keys(self) -> Iterable[str]:
return self._components.keys()

def items(self) -> Iterable[tuple[str, RemoteComponentType]]:
for key in sorted(self.keys()):
yield key, self.get(key)
def global_items(self) -> Iterable[tuple[str, RemoteComponentType]]:
for key in sorted(self.global_keys()):
yield key, self.get_global(key)

def __repr__(self) -> str:
return f"<RemoteComponentRegistry {list(self._components.keys())}>"
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_component_type_scaffold_success(in_deployment: bool) -> None:
assert Path("foo_bar/lib/baz.py").exists()
dg_context = DgContext.from_config_file_discovery_and_cli_config(Path.cwd(), {})
registry = RemoteComponentRegistry.from_dg_context(dg_context)
assert registry.has("foo_bar.baz")
assert registry.has_global("foo_bar.baz")


def test_component_type_scaffold_outside_code_location_fails() -> None:
Expand Down Expand Up @@ -71,7 +71,7 @@ def test_component_type_scaffold_succeeds_non_default_component_lib_package() ->
assert Path("foo_bar/_lib/baz.py").exists()
dg_context = DgContext.from_config_file_discovery_and_cli_config(Path.cwd(), {})
registry = RemoteComponentRegistry.from_dg_context(dg_context)
assert registry.has("bar.baz")
assert registry.has_global("bar.baz")


def test_component_type_scaffold_fails_components_lib_package_does_not_exist() -> None:
Expand Down

0 comments on commit a873ee3

Please sign in to comment.