Skip to content

Commit

Permalink
Tests and canonicalization for ConfigWorkflow (#89)
Browse files Browse the repository at this point in the history
* add doctest infra
* add tests for `ConfigWorkflow` (creation from yaml / API)
* add `CanonicalWorkflow` with 
   - additional validation and information from outside yaml file
   - conversion function from `ConfigWorkflow`
   - unit tests
* update `sirocco.workflow` to use `CanonicalWorkflow`
* basic `workflow.Workflow` creation from `CanonicalWorkflow` test
  • Loading branch information
DropD authored Jan 21, 2025
1 parent caced37 commit 2c0ee10
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 31 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Changelog = "https://github.com/C2SM/Sirocco/blob/main/CHANGELOG.md"
[tool.pytest.ini_options]
# Configuration for [pytest](https://docs.pytest.org)
addopts = "--pdbcls=IPython.terminal.debugger:TerminalPdb"
norecursedirs = "tests/cases"

[tool.coverage.run]
# Configuration of [coverage.py](https://coverage.readthedocs.io)
Expand Down Expand Up @@ -77,6 +78,8 @@ path = "src/sirocco/__init__.py"
extra-dependencies = [
"ipdb"
]
default-args = []
extra-args = ["--doctest-modules"]

[[tool.hatch.envs.hatch-test.matrix]]
python = ["3.12"]
Expand Down
15 changes: 9 additions & 6 deletions src/sirocco/core/workflow.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from __future__ import annotations

from itertools import product
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Self

from sirocco.core import _tasks # noqa [F401]
from sirocco.core.graph_items import Cycle, Data, Store, Task
from sirocco.parsing._yaml_data_models import (
ConfigWorkflow,
CanonicalWorkflow,
load_workflow_config,
)

Expand All @@ -21,7 +20,7 @@
class Workflow:
"""Internal representation of a workflow"""

def __init__(self, workflow_config: ConfigWorkflow) -> None:
def __init__(self, workflow_config: CanonicalWorkflow) -> None:
self.name: str = workflow_config.name
self.config_rootdir: Path = workflow_config.rootdir
self.tasks: Store = Store()
Expand Down Expand Up @@ -70,7 +69,11 @@ def iter_coordinates(param_refs: list, date: datetime | None = None) -> Iterator
self.tasks.add(task)
cycle_tasks.append(task)
self.cycles.add(
Cycle(name=cycle_name, tasks=cycle_tasks, coordinates={} if date is None else {"date": date})
Cycle(
name=cycle_name,
tasks=cycle_tasks,
coordinates={} if date is None else {"date": date},
)
)

# 4 - Link wait on tasks
Expand All @@ -85,5 +88,5 @@ def cycle_dates(cycle_config: ConfigCycle) -> Iterator[datetime]:
yield date

@classmethod
def from_yaml(cls, config_path: str):
def from_yaml(cls: type[Self], config_path: str) -> Self:
return cls(load_workflow_config(config_path))
115 changes: 96 additions & 19 deletions src/sirocco/parsing/_yaml_data_models.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
from __future__ import annotations

import itertools
import time
import typing
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Annotated, Any, ClassVar, Literal

from isoduration import parse_duration
from isoduration.types import Duration # pydantic needs type # noqa: TCH002
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, field_validator, model_validator
from pydantic import (
AfterValidator,
BaseModel,
ConfigDict,
Discriminator,
Field,
Tag,
field_validator,
model_validator,
)

from sirocco.parsing._utils import TimeUtils

Expand Down Expand Up @@ -473,7 +484,11 @@ class ConfigData(BaseModel):
generated: list[ConfigGeneratedData] = []


def get_plugin_from_named_base_model(data: dict) -> str:
def get_plugin_from_named_base_model(
data: dict | ConfigRootTask | ConfigShellTask | ConfigIconTask,
) -> str:
if isinstance(data, (ConfigRootTask, ConfigShellTask, ConfigIconTask)):
return data.plugin
name_and_specs = _NamedBaseModel.merge_name_and_specs(data)
if name_and_specs.get("name", None) == "ROOT":
return ConfigRootTask.plugin
Expand All @@ -493,14 +508,44 @@ def get_plugin_from_named_base_model(data: dict) -> str:


class ConfigWorkflow(BaseModel):
"""
The root of the configuration tree.
Examples:
minimal yaml to generate:
>>> import textwrap
>>> import pydantic_yaml
>>> config = textwrap.dedent(
... '''
... cycles:
... - minimal_cycle:
... tasks:
... - task_a:
... tasks:
... - task_b:
... plugin: shell
... data:
... available:
... - foo:
... generated:
... - bar:
... '''
... )
>>> wf = pydantic_yaml.parse_yaml_raw_as(ConfigWorkflow, config)
minimum programmatically created instance
>>> empty_wf = ConfigWorkflow(cycles=[], tasks=[], data={})
"""

name: str | None = None
rootdir: Path | None = None
cycles: list[ConfigCycle]
tasks: list[ConfigTask]
data: ConfigData
parameters: dict[str, list] = {}
data_dict: dict = {}
task_dict: dict = {}

@field_validator("parameters", mode="before")
@classmethod
Expand All @@ -515,19 +560,9 @@ def check_parameters_lists(cls, data) -> dict[str, list]:
raise TypeError(msg)
return data

@model_validator(mode="after")
def build_internal_dicts(self) -> ConfigWorkflow:
self.data_dict = {data.name: data for data in self.data.available} | {
data.name: data for data in self.data.generated
}
self.task_dict = {task.name: task for task in self.tasks}
return self

@model_validator(mode="after")
def check_parameters(self) -> ConfigWorkflow:
task_data_list = self.tasks + self.data.generated
if self.data.available:
task_data_list.extend(self.data.available)
task_data_list = itertools.chain(self.tasks, self.data.generated, self.data.available)
for item in task_data_list:
for param_name in item.parameters:
if param_name not in self.parameters:
Expand All @@ -536,7 +571,48 @@ def check_parameters(self) -> ConfigWorkflow:
return self


def load_workflow_config(workflow_config: str) -> ConfigWorkflow:
ITEM_T = typing.TypeVar("ITEM_T")


def list_not_empty(value: list[ITEM_T]) -> list[ITEM_T]:
if len(value) < 1:
msg = "At least one element is required."
raise ValueError(msg)
return value


class CanonicalWorkflow(BaseModel):
name: str
rootdir: Path
cycles: Annotated[list[ConfigCycle], AfterValidator(list_not_empty)]
tasks: Annotated[list[ConfigTask], AfterValidator(list_not_empty)]
data: ConfigData
parameters: dict[str, list[Any]]

@property
def data_dict(self) -> dict[str, ConfigAvailableData | ConfigGeneratedData]:
return {data.name: data for data in itertools.chain(self.data.available, self.data.generated)}

@property
def task_dict(self) -> dict[str, ConfigTask]:
return {task.name: task for task in self.tasks}


def canonicalize_workflow(config_workflow: ConfigWorkflow, rootdir: Path) -> CanonicalWorkflow:
if not config_workflow.name:
msg = "Workflow name required for canonicalization."
raise ValueError(msg)
return CanonicalWorkflow(
name=config_workflow.name,
rootdir=rootdir,
cycles=config_workflow.cycles,
tasks=config_workflow.tasks,
data=config_workflow.data,
parameters=config_workflow.parameters,
)


def load_workflow_config(workflow_config: str) -> CanonicalWorkflow:
"""
Loads a python representation of a workflow config file.
Expand All @@ -554,6 +630,7 @@ def load_workflow_config(workflow_config: str) -> ConfigWorkflow:
if parsed_workflow.name is None:
parsed_workflow.name = config_path.stem

parsed_workflow.rootdir = config_path.resolve().parent
rootdir = config_path.resolve().parent

return parsed_workflow
return canonicalize_workflow(config_workflow=parsed_workflow, rootdir=rootdir)
# return parsed_workflow
11 changes: 5 additions & 6 deletions src/sirocco/pretty_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def as_block(self, header: str, body: str) -> str:
Example:
>>> print(PrettyPrinter().as_block("header", "foo\nbar"))
>>> print(PrettyPrinter().as_block("header", "foo\\nbar"))
header:
foo
bar
Expand All @@ -50,7 +50,7 @@ def as_item(self, content: str) -> str:
- foo
>>> pp = PrettyPrinter()
>>> print(pp.as_item(pp.as_block("header", "multiple\nlines\nof text")))
>>> print(pp.as_item(pp.as_block("header", "multiple\\nlines\\nof text")))
- header:
multiple
lines
Expand Down Expand Up @@ -87,14 +87,13 @@ def format_basic(self, obj: core.GraphItem) -> str:
>>> from datetime import datetime
>>> print(
... PrettyPrinter().format_basic(
... Task(
... name=foo,
... core.Task(
... name="foo",
... coordinates={"date": datetime(1000, 1, 1).date()},
... workflow=None,
... )
... )
... )
foo [1000-01-01]
foo [date: 1000-01-01]
"""
name = obj.name
if obj.coordinates:
Expand Down
Empty file added tests/unit_tests/__init__.py
Empty file.
Empty file.
27 changes: 27 additions & 0 deletions tests/unit_tests/core/test_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pathlib

from sirocco import pretty_print
from sirocco.core import workflow
from sirocco.parsing import _yaml_data_models as models


def test_minimal_workflow():
minimal_config = models.CanonicalWorkflow(
name="minimal",
rootdir=pathlib.Path("minimal"),
cycles=[models.ConfigCycle(some_cycle={"tasks": []})],
tasks=[models.ConfigShellTask(some_task={"plugin": "shell"})],
data=models.ConfigData(
available=[models.ConfigAvailableData(foo={})],
generated=[models.ConfigGeneratedData(bar={})],
),
parameters={},
)

testee = workflow.Workflow(minimal_config)

pretty_print.PrettyPrinter().format(testee)

assert len(list(testee.tasks)) == 0
assert len(list(testee.cycles)) == 1
assert testee.data[("foo", {})].available
Empty file.
45 changes: 45 additions & 0 deletions tests/unit_tests/parsing/test_yaml_data_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pathlib
import textwrap

from sirocco.parsing import _yaml_data_models as models


def test_workflow_canonicalization():
config = models.ConfigWorkflow(
name="testee",
cycles=[models.ConfigCycle(minimal={"tasks": [models.ConfigCycleTask(a={})]})],
tasks=[{"some_task": {"plugin": "shell"}}],
data=models.ConfigData(
available=[models.ConfigAvailableData(foo={})],
generated=[models.ConfigGeneratedData(bar={})],
),
)

testee = models.canonicalize_workflow(config, rootdir=pathlib.Path("foo"))
assert testee.data_dict["foo"].name == "foo"
assert testee.data_dict["bar"].name == "bar"
assert testee.task_dict["some_task"].name == "some_task"


def test_load_workflow_config(tmp_path):
minimal_config = textwrap.dedent(
"""
cycles:
- minimal:
tasks:
- a:
tasks:
- b:
plugin: shell
data:
available:
- c:
generated:
- d:
"""
)
minimal = tmp_path / "minimal.yml"
minimal.write_text(minimal_config)
testee = models.load_workflow_config(str(minimal))
assert testee.name == "minimal"
assert testee.rootdir == tmp_path

0 comments on commit 2c0ee10

Please sign in to comment.