Skip to content

Commit

Permalink
replace canonical data models with improved named base
Browse files Browse the repository at this point in the history
  • Loading branch information
DropD committed Jan 21, 2025
1 parent ad992d1 commit 1931b5a
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 110 deletions.
4 changes: 2 additions & 2 deletions src/sirocco/core/graph_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import TYPE_CHECKING, Any, ClassVar, Self, TypeAlias

from sirocco.parsing._yaml_data_models import (
CanonicalAvailableData,
ConfigAvailableData,
ConfigBaseDataSpecs,
ConfigBaseTaskSpecs,
)
Expand Down Expand Up @@ -47,7 +47,7 @@ def from_config(cls, config: ConfigBaseData, coordinates: dict) -> Self:
name=config.name,
type=config.type,
src=config.src,
available=isinstance(config, CanonicalAvailableData),
available=isinstance(config, ConfigAvailableData),
coordinates=coordinates,
)

Expand Down
155 changes: 59 additions & 96 deletions src/sirocco/parsing/_yaml_data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,46 +26,64 @@


class _NamedBaseModel(BaseModel):
"""Base class for all classes with a key that specifies their name.
For example:
.. yaml
"""
Noninvasive base model for reading names from yaml keys *or* attributes.
- property_name:
property: true
Examples:
When parsing with this as parent class it is converted to
`{"name": "property_name", "property": True}`.
>>> _NamedBaseModel(name="foo")
_NamedBaseModel(name='foo')
>>> _NamedBaseModel(foo={})
_NamedBaseModel(name='foo')
>>> import pydantic_yaml, textwrap
>>> pydantic_yaml.parse_yaml_raw_as(
... _NamedBaseModel,
... textwrap.dedent('''
... foo:
... '''),
... )
_NamedBaseModel(name='foo')
>>> pydantic_yaml.parse_yaml_raw_as(
... _NamedBaseModel,
... textwrap.dedent('''
... name: foo
... '''),
... )
_NamedBaseModel(name='foo')
"""

name: str

def __init__(self, /, **data):
super().__init__(**self.merge_name_and_specs(data))

@staticmethod
def merge_name_and_specs(data: dict) -> dict:
"""
Converts dict of form
`{my_name: {'spec_0': ..., ..., 'spec_n': ...}}`
to
`{'name': my_name, 'spec_0': ..., ..., 'spec_n': ...}`
@model_validator(mode="before")
@classmethod
def reformat_named_object(cls, data: Any) -> Any:
return cls.extract_merge_name(data)

by copy.
"""
name_and_spec = {}
if len(data) != 1:
msg = f"Expected dict with one element of the form {{'name': specification}} but got {data}."
raise ValueError(msg)
name_and_spec["name"] = next(iter(data.keys()))
# if no specification specified e.g. "- my_name:"
if (spec := next(iter(data.values()))) is not None:
name_and_spec.update(spec)
return name_and_spec
@classmethod
def extract_merge_name(cls, data: Any) -> Any:
if not isinstance(data, dict):
return data
if len(data) == 1:
key, value = next(iter(data.items()))
match key:
case str():
match value:
case str() if key == "name":
pass
case dict() if "name" not in value:
data = value | {"name": key}
case None:
data = {"name": key}
case _:
msg = f"{cls.__name__} may only be used for named objects, not values (got {data})."
raise TypeError(msg)
case _:
msg = f"{cls.__name__} requires name to be a str (got {key})."
raise TypeError(msg)
return data


class _WhenBaseModel(BaseModel):
Expand Down Expand Up @@ -471,7 +489,7 @@ class ConfigBaseData(_NamedBaseModel, ConfigBaseDataSpecs):
from python:
>>> ConfigBaseData(foo={"type": "file", "src": "foo.txt"})
>>> ConfigBaseData(name="foo", type=DataType.FILE, src="foo.txt")
ConfigBaseData(type=<DataType.FILE: 'file'>, src='foo.txt', format=None, computer=None, name='foo', parameters=[])
"""

Expand Down Expand Up @@ -502,48 +520,6 @@ def invalid_field(cls, value: str | None) -> str | None:
return value


class CanonicalBaseData(BaseModel, ConfigBaseDataSpecs):
"""
Canonical data model.
Examples:
>>> CanonicalBaseData(name="foo", type=DataType.FILE, src="foo.txt")
CanonicalBaseData(type=<DataType.FILE: 'file'>, src='foo.txt', format=None, computer=None, name='foo', parameters=[])
"""

name: str
parameters: list[str] = []


class CanonicalAvailableData(CanonicalBaseData):
pass


class CanonicalGeneratedData(CanonicalBaseData):
pass


def canonicalize_available_data(value: ConfigAvailableData) -> CanonicalAvailableData:
return CanonicalAvailableData(
name=value.name,
type=value.type,
src=value.src,
format=value.format,
parameters=value.parameters,
)


def canonicalize_generated_data(value: ConfigGeneratedData) -> CanonicalGeneratedData:
return CanonicalGeneratedData(
name=value.name,
type=value.type,
src=value.src,
format=value.format,
parameters=value.parameters,
)


class ConfigData(BaseModel):
"""
To create the container of available and generated data
Expand All @@ -561,9 +537,9 @@ class ConfigData(BaseModel):
... type: "file"
... src: "foo.txt"
... generated:
... - bar:
... type: "file"
... src: "bar.txt"
... - name: "bar"
... type: "file"
... src: "bar.txt"
... '''
... )
>>> data = pydantic_yaml.parse_yaml_raw_as(ConfigData, snippet)
Expand All @@ -580,24 +556,12 @@ class ConfigData(BaseModel):
generated: list[ConfigGeneratedData] = []


class CanonicalData(BaseModel):
available: list[CanonicalAvailableData] = []
generated: list[CanonicalGeneratedData] = []


def canonicalize_data(value: ConfigData) -> CanonicalData:
return CanonicalData(
available=[canonicalize_available_data(i) for i in value.available],
generated=[canonicalize_generated_data(i) for i in value.generated],
)


def get_plugin_from_named_base_model(
data: dict | ConfigRootTask | ConfigShellTask | ConfigIconTask,
) -> str:
if isinstance(data, (ConfigRootTask, ConfigShellTask, ConfigIconTask)):
return data.plugin
name_and_specs = _NamedBaseModel.merge_name_and_specs(data)
name_and_specs = _NamedBaseModel.extract_merge_name(data)
if name_and_specs.get("name", None) == "ROOT":
return ConfigRootTask.plugin
plugin = name_and_specs.get("plugin", None)
Expand Down Expand Up @@ -698,11 +662,11 @@ class CanonicalWorkflow(BaseModel):
rootdir: Path
cycles: Annotated[list[ConfigCycle], AfterValidator(list_not_empty)]
tasks: Annotated[list[ConfigTask], AfterValidator(list_not_empty)]
data: CanonicalData
data: ConfigData
parameters: dict[str, list[Any]]

@property
def data_dict(self) -> dict[str, CanonicalAvailableData | CanonicalGeneratedData]:
def data_dict(self) -> dict[str, ConfigAvailableData | ConfigGeneratedData]:
return {data.name: data for data in itertools.chain(self.data.available, self.data.generated)}

@property
Expand All @@ -714,13 +678,12 @@ def canonicalize_workflow(value: ConfigWorkflow, rootdir: Path) -> CanonicalWork
if not value.name:
msg = "Workflow name required for canonicalization."
raise ValueError(msg)
canon_data = canonicalize_data(value.data)
return CanonicalWorkflow(
name=value.name,
rootdir=rootdir,
cycles=value.cycles,
tasks=value.tasks,
data=canon_data,
data=value.data,
parameters=value.parameters,
)

Expand Down
10 changes: 5 additions & 5 deletions tests/unit_tests/core/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ def test_minimal_workflow():
minimal_config = models.CanonicalWorkflow(
name="minimal",
rootdir=pathlib.Path("minimal"),
cycles=[models.ConfigCycle(some_cycle={"tasks": []})],
tasks=[models.ConfigShellTask(some_task={"plugin": "shell"})],
data=models.CanonicalData(
available=[models.CanonicalAvailableData(name="foo", type=models.DataType.FILE, src="foo.txt")],
generated=[models.CanonicalGeneratedData(name="bar", type=models.DataType.DIR, src="bar")],
cycles=[models.ConfigCycle(name="some_cycle", tasks=[])],
tasks=[models.ConfigShellTask(name="some_task", plugin="shell")],
data=models.ConfigData(
available=[models.ConfigAvailableData(name="foo", type=models.DataType.FILE, src="foo.txt")],
generated=[models.ConfigGeneratedData(name="bar", type=models.DataType.DIR, src="bar")],
),
parameters={},
)
Expand Down
14 changes: 7 additions & 7 deletions tests/unit_tests/parsing/test_yaml_data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,28 @@

@pytest.mark.parametrize("data_type", ["file", "dir"])
def test_base_data(data_type):
testee = models.ConfigBaseData(name={"type": data_type, "src": "foo.txt", "format": None})
testee = models.ConfigBaseData(name="name", type=data_type, src="foo.txt", format=None)

assert testee.type == data_type


@pytest.mark.parametrize("data_type", [None, "invalid", 1.42])
def test_base_data_invalid_type(data_type):
with pytest.raises(pydantic.ValidationError):
_ = models.ConfigBaseData(name={"src": "foo", "format": "nml"})
_ = models.ConfigBaseData(name="name", src="foo", format="nml")

with pytest.raises(pydantic.ValidationError):
_ = models.ConfigBaseData(name={"type": data_type, "src": "foo", "format": "nml"})
_ = models.ConfigBaseData(name="name", type=data_type, src="foo", format="nml")


def test_workflow_canonicalization():
config = models.ConfigWorkflow(
name="testee",
cycles=[models.ConfigCycle(minimal={"tasks": [models.ConfigCycleTask(a={})]})],
tasks=[{"some_task": {"plugin": "shell"}}],
cycles=[models.ConfigCycle(name="minimal", tasks=[models.ConfigCycleTask(name="a")])],
tasks=[models.ConfigShellTask(name="some_task")],
data=models.ConfigData(
available=[models.ConfigAvailableData(foo={"type": "file", "src": "foo.txt"})],
generated=[models.ConfigGeneratedData(bar={"type": "dir", "src": "bar"})],
available=[models.ConfigAvailableData(name="foo", type=models.DataType.FILE, src="foo.txt")],
generated=[models.ConfigGeneratedData(name="bar", type=models.DataType.DIR, src="bar")],
),
)

Expand Down

0 comments on commit 1931b5a

Please sign in to comment.