diff --git a/pyproject.toml b/pyproject.toml index aa085338..c4c2436f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ Changelog = "https://github.com/C2SM/Sirocco/blob/main/CHANGELOG.md" [tool.pytest.ini_options] # Configuration for [pytest](https://docs.pytest.org) addopts = "--pdbcls=IPython.terminal.debugger:TerminalPdb" +norecursedirs = "tests/cases" [tool.coverage.run] # Configuration of [coverage.py](https://coverage.readthedocs.io) @@ -77,6 +78,8 @@ path = "src/sirocco/__init__.py" extra-dependencies = [ "ipdb" ] +default-args = [] +extra-args = ["--doctest-modules"] [[tool.hatch.envs.hatch-test.matrix]] python = ["3.12"] diff --git a/src/sirocco/core/workflow.py b/src/sirocco/core/workflow.py index 82c9445c..537f826f 100644 --- a/src/sirocco/core/workflow.py +++ b/src/sirocco/core/workflow.py @@ -1,12 +1,11 @@ from __future__ import annotations from itertools import product -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Self -from sirocco.core import _tasks # noqa [F401] from sirocco.core.graph_items import Cycle, Data, Store, Task from sirocco.parsing._yaml_data_models import ( - ConfigWorkflow, + CanonicalWorkflow, load_workflow_config, ) @@ -21,7 +20,7 @@ class Workflow: """Internal representation of a workflow""" - def __init__(self, workflow_config: ConfigWorkflow) -> None: + def __init__(self, workflow_config: CanonicalWorkflow) -> None: self.name: str = workflow_config.name self.config_rootdir: Path = workflow_config.rootdir self.tasks: Store = Store() @@ -70,7 +69,11 @@ def iter_coordinates(param_refs: list, date: datetime | None = None) -> Iterator self.tasks.add(task) cycle_tasks.append(task) self.cycles.add( - Cycle(name=cycle_name, tasks=cycle_tasks, coordinates={} if date is None else {"date": date}) + Cycle( + name=cycle_name, + tasks=cycle_tasks, + coordinates={} if date is None else {"date": date}, + ) ) # 4 - Link wait on tasks @@ -85,5 +88,5 @@ def cycle_dates(cycle_config: ConfigCycle) -> Iterator[datetime]: yield date @classmethod - def from_yaml(cls, config_path: str): + def from_yaml(cls: type[Self], config_path: str) -> Self: return cls(load_workflow_config(config_path)) diff --git a/src/sirocco/parsing/_yaml_data_models.py b/src/sirocco/parsing/_yaml_data_models.py index 1984b811..5ec12f40 100644 --- a/src/sirocco/parsing/_yaml_data_models.py +++ b/src/sirocco/parsing/_yaml_data_models.py @@ -1,6 +1,8 @@ from __future__ import annotations +import itertools import time +import typing from dataclasses import dataclass, field from datetime import datetime from pathlib import Path @@ -8,7 +10,16 @@ from isoduration import parse_duration from isoduration.types import Duration # pydantic needs type # noqa: TCH002 -from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, field_validator, model_validator +from pydantic import ( + AfterValidator, + BaseModel, + ConfigDict, + Discriminator, + Field, + Tag, + field_validator, + model_validator, +) from sirocco.parsing._utils import TimeUtils @@ -473,7 +484,11 @@ class ConfigData(BaseModel): generated: list[ConfigGeneratedData] = [] -def get_plugin_from_named_base_model(data: dict) -> str: +def get_plugin_from_named_base_model( + data: dict | ConfigRootTask | ConfigShellTask | ConfigIconTask, +) -> str: + if isinstance(data, (ConfigRootTask, ConfigShellTask, ConfigIconTask)): + return data.plugin name_and_specs = _NamedBaseModel.merge_name_and_specs(data) if name_and_specs.get("name", None) == "ROOT": return ConfigRootTask.plugin @@ -493,14 +508,44 @@ def get_plugin_from_named_base_model(data: dict) -> str: class ConfigWorkflow(BaseModel): + """ + The root of the configuration tree. + + Examples: + + minimal yaml to generate: + + >>> import textwrap + >>> import pydantic_yaml + >>> config = textwrap.dedent( + ... ''' + ... cycles: + ... - minimal_cycle: + ... tasks: + ... - task_a: + ... tasks: + ... - task_b: + ... plugin: shell + ... data: + ... available: + ... - foo: + ... generated: + ... - bar: + ... ''' + ... ) + >>> wf = pydantic_yaml.parse_yaml_raw_as(ConfigWorkflow, config) + + minimum programmatically created instance + + >>> empty_wf = ConfigWorkflow(cycles=[], tasks=[], data={}) + + """ + name: str | None = None - rootdir: Path | None = None cycles: list[ConfigCycle] tasks: list[ConfigTask] data: ConfigData parameters: dict[str, list] = {} - data_dict: dict = {} - task_dict: dict = {} @field_validator("parameters", mode="before") @classmethod @@ -515,19 +560,9 @@ def check_parameters_lists(cls, data) -> dict[str, list]: raise TypeError(msg) return data - @model_validator(mode="after") - def build_internal_dicts(self) -> ConfigWorkflow: - self.data_dict = {data.name: data for data in self.data.available} | { - data.name: data for data in self.data.generated - } - self.task_dict = {task.name: task for task in self.tasks} - return self - @model_validator(mode="after") def check_parameters(self) -> ConfigWorkflow: - task_data_list = self.tasks + self.data.generated - if self.data.available: - task_data_list.extend(self.data.available) + task_data_list = itertools.chain(self.tasks, self.data.generated, self.data.available) for item in task_data_list: for param_name in item.parameters: if param_name not in self.parameters: @@ -536,7 +571,48 @@ def check_parameters(self) -> ConfigWorkflow: return self -def load_workflow_config(workflow_config: str) -> ConfigWorkflow: +ITEM_T = typing.TypeVar("ITEM_T") + + +def list_not_empty(value: list[ITEM_T]) -> list[ITEM_T]: + if len(value) < 1: + msg = "At least one element is required." + raise ValueError(msg) + return value + + +class CanonicalWorkflow(BaseModel): + name: str + rootdir: Path + cycles: Annotated[list[ConfigCycle], AfterValidator(list_not_empty)] + tasks: Annotated[list[ConfigTask], AfterValidator(list_not_empty)] + data: ConfigData + parameters: dict[str, list[Any]] + + @property + def data_dict(self) -> dict[str, ConfigAvailableData | ConfigGeneratedData]: + return {data.name: data for data in itertools.chain(self.data.available, self.data.generated)} + + @property + def task_dict(self) -> dict[str, ConfigTask]: + return {task.name: task for task in self.tasks} + + +def canonicalize_workflow(config_workflow: ConfigWorkflow, rootdir: Path) -> CanonicalWorkflow: + if not config_workflow.name: + msg = "Workflow name required for canonicalization." + raise ValueError(msg) + return CanonicalWorkflow( + name=config_workflow.name, + rootdir=rootdir, + cycles=config_workflow.cycles, + tasks=config_workflow.tasks, + data=config_workflow.data, + parameters=config_workflow.parameters, + ) + + +def load_workflow_config(workflow_config: str) -> CanonicalWorkflow: """ Loads a python representation of a workflow config file. @@ -554,6 +630,7 @@ def load_workflow_config(workflow_config: str) -> ConfigWorkflow: if parsed_workflow.name is None: parsed_workflow.name = config_path.stem - parsed_workflow.rootdir = config_path.resolve().parent + rootdir = config_path.resolve().parent - return parsed_workflow + return canonicalize_workflow(config_workflow=parsed_workflow, rootdir=rootdir) + # return parsed_workflow diff --git a/src/sirocco/pretty_print.py b/src/sirocco/pretty_print.py index 57a04114..baaf7ac5 100644 --- a/src/sirocco/pretty_print.py +++ b/src/sirocco/pretty_print.py @@ -31,7 +31,7 @@ def as_block(self, header: str, body: str) -> str: Example: - >>> print(PrettyPrinter().as_block("header", "foo\nbar")) + >>> print(PrettyPrinter().as_block("header", "foo\\nbar")) header: foo bar @@ -50,7 +50,7 @@ def as_item(self, content: str) -> str: - foo >>> pp = PrettyPrinter() - >>> print(pp.as_item(pp.as_block("header", "multiple\nlines\nof text"))) + >>> print(pp.as_item(pp.as_block("header", "multiple\\nlines\\nof text"))) - header: multiple lines @@ -87,14 +87,13 @@ def format_basic(self, obj: core.GraphItem) -> str: >>> from datetime import datetime >>> print( ... PrettyPrinter().format_basic( - ... Task( - ... name=foo, + ... core.Task( + ... name="foo", ... coordinates={"date": datetime(1000, 1, 1).date()}, - ... workflow=None, ... ) ... ) ... ) - foo [1000-01-01] + foo [date: 1000-01-01] """ name = obj.name if obj.coordinates: diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/core/__init__.py b/tests/unit_tests/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/core/test_workflow.py b/tests/unit_tests/core/test_workflow.py new file mode 100644 index 00000000..188bb85e --- /dev/null +++ b/tests/unit_tests/core/test_workflow.py @@ -0,0 +1,27 @@ +import pathlib + +from sirocco import pretty_print +from sirocco.core import workflow +from sirocco.parsing import _yaml_data_models as models + + +def test_minimal_workflow(): + minimal_config = models.CanonicalWorkflow( + name="minimal", + rootdir=pathlib.Path("minimal"), + cycles=[models.ConfigCycle(some_cycle={"tasks": []})], + tasks=[models.ConfigShellTask(some_task={"plugin": "shell"})], + data=models.ConfigData( + available=[models.ConfigAvailableData(foo={})], + generated=[models.ConfigGeneratedData(bar={})], + ), + parameters={}, + ) + + testee = workflow.Workflow(minimal_config) + + pretty_print.PrettyPrinter().format(testee) + + assert len(list(testee.tasks)) == 0 + assert len(list(testee.cycles)) == 1 + assert testee.data[("foo", {})].available diff --git a/tests/unit_tests/parsing/__init__.py b/tests/unit_tests/parsing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/parsing/test_yaml_data_models.py b/tests/unit_tests/parsing/test_yaml_data_models.py new file mode 100644 index 00000000..6d827eba --- /dev/null +++ b/tests/unit_tests/parsing/test_yaml_data_models.py @@ -0,0 +1,45 @@ +import pathlib +import textwrap + +from sirocco.parsing import _yaml_data_models as models + + +def test_workflow_canonicalization(): + config = models.ConfigWorkflow( + name="testee", + cycles=[models.ConfigCycle(minimal={"tasks": [models.ConfigCycleTask(a={})]})], + tasks=[{"some_task": {"plugin": "shell"}}], + data=models.ConfigData( + available=[models.ConfigAvailableData(foo={})], + generated=[models.ConfigGeneratedData(bar={})], + ), + ) + + testee = models.canonicalize_workflow(config, rootdir=pathlib.Path("foo")) + assert testee.data_dict["foo"].name == "foo" + assert testee.data_dict["bar"].name == "bar" + assert testee.task_dict["some_task"].name == "some_task" + + +def test_load_workflow_config(tmp_path): + minimal_config = textwrap.dedent( + """ + cycles: + - minimal: + tasks: + - a: + tasks: + - b: + plugin: shell + data: + available: + - c: + generated: + - d: + """ + ) + minimal = tmp_path / "minimal.yml" + minimal.write_text(minimal_config) + testee = models.load_workflow_config(str(minimal)) + assert testee.name == "minimal" + assert testee.rootdir == tmp_path