diff --git a/src/sirocco/core/graph_items.py b/src/sirocco/core/graph_items.py index dd42fb7c..2453d2b3 100644 --- a/src/sirocco/core/graph_items.py +++ b/src/sirocco/core/graph_items.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Self, TypeAlias from sirocco.parsing._yaml_data_models import ( - CanonicalAvailableData, + ConfigAvailableData, ConfigBaseDataSpecs, ConfigBaseTaskSpecs, ) @@ -47,7 +47,7 @@ def from_config(cls, config: ConfigBaseData, coordinates: dict) -> Self: name=config.name, type=config.type, src=config.src, - available=isinstance(config, CanonicalAvailableData), + available=isinstance(config, ConfigAvailableData), coordinates=coordinates, ) diff --git a/src/sirocco/parsing/_yaml_data_models.py b/src/sirocco/parsing/_yaml_data_models.py index 8e6bf485..1523e195 100644 --- a/src/sirocco/parsing/_yaml_data_models.py +++ b/src/sirocco/parsing/_yaml_data_models.py @@ -26,46 +26,64 @@ class _NamedBaseModel(BaseModel): - """Base class for all classes with a key that specifies their name. - - For example: - - .. yaml + """ + Noninvasive base model for reading names from yaml keys *or* attributes. - - property_name: - property: true + Examples: - When parsing with this as parent class it is converted to - `{"name": "propery_name", "property": True}`. + >>> NeoNameBase(name="foo") + NeoNameBase(name='foo') + + >>> NeoNameBase(foo={}) + NeoNameBase(name='foo') + + >>> import pydantic_yaml, textwrap + >>> pydantic_yaml.parse_yaml_raw_as( + ... NeoNameBase, + ... textwrap.dedent(''' + ... foo: + ... '''), + ... ) + NeoNameBase(name='foo') + + >>> pydantic_yaml.parse_yaml_raw_as( + ... NeoNameBase, + ... textwrap.dedent(''' + ... name: foo + ... '''), + ... ) + NeoNameBase(name='foo') """ name: str - def __init__(self, /, **data): - super().__init__(**self.merge_name_and_specs(data)) - - @staticmethod - def merge_name_and_specs(data: dict) -> dict: - """ - Converts dict of form - - `{my_name: {'spec_0': ..., ..., 'spec_n': ...}` - - to - - `{'name': my_name, 'spec_0': ..., ..., 'spec_n': ...}` + @model_validator(mode="before") + @classmethod + def reformat_named_object(cls, data: Any) -> Any: + return cls.extract_merge_name(data) - by copy. - """ - name_and_spec = {} - if len(data) != 1: - msg = f"Expected dict with one element of the form {{'name': specification}} but got {data}." - raise ValueError(msg) - name_and_spec["name"] = next(iter(data.keys())) - # if no specification specified e.g. "- my_name:" - if (spec := next(iter(data.values()))) is not None: - name_and_spec.update(spec) - return name_and_spec + @classmethod + def extract_merge_name(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + if len(data) == 1: + key, value = next(iter(data.items())) + match key: + case str(): + match value: + case str() if key == "name": + pass + case dict() if "name" not in value: + data = value | {"name": key} + case None: + data = {"name": key} + case _: + msg = f"{cls.__name__} may only be used for named objects, not values (got {data})." + raise TypeError(msg) + case _: + msg = f"{cls.__name__} requires name to be a str (got {key})." + raise TypeError(msg) + return data class _WhenBaseModel(BaseModel): @@ -471,7 +489,7 @@ class ConfigBaseData(_NamedBaseModel, ConfigBaseDataSpecs): from python: - >>> ConfigBaseData(foo={"type": "file", "src": "foo.txt"}) + >>> ConfigBaseData(name="foo", type=DataType.FILE, src="foo.txt") ConfigBaseData(type=, src='foo.txt', format=None, computer=None, name='foo', parameters=[]) """ @@ -502,48 +520,6 @@ def invalid_field(cls, value: str | None) -> str | None: return value -class CanonicalBaseData(BaseModel, ConfigBaseDataSpecs): - """ - Canonical data model. - - Examples: - - >>> CanonicalBaseData(name="foo", type=DataType.FILE, src="foo.txt") - CanonicalBaseData(type=, src='foo.txt', format=None, computer=None, name='foo', parameters=[]) - """ - - name: str - parameters: list[str] = [] - - -class CanonicalAvailableData(CanonicalBaseData): - pass - - -class CanonicalGeneratedData(CanonicalBaseData): - pass - - -def canonicalize_available_data(value: ConfigAvailableData) -> CanonicalAvailableData: - return CanonicalAvailableData( - name=value.name, - type=value.type, - src=value.src, - format=value.format, - parameters=value.parameters, - ) - - -def canonicalize_generated_data(value: ConfigGeneratedData) -> CanonicalGeneratedData: - return CanonicalGeneratedData( - name=value.name, - type=value.type, - src=value.src, - format=value.format, - parameters=value.parameters, - ) - - class ConfigData(BaseModel): """ To create the container of available and generated data @@ -561,9 +537,9 @@ class ConfigData(BaseModel): ... type: "file" ... src: "foo.txt" ... generated: - ... - bar: - ... type: "file" - ... src: "bar.txt" + ... - name: "bar" + ... type: "file" + ... src: "bar.txt" ... ''' ... ) >>> data = pydantic_yaml.parse_yaml_raw_as(ConfigData, snippet) @@ -580,24 +556,12 @@ class ConfigData(BaseModel): generated: list[ConfigGeneratedData] = [] -class CanonicalData(BaseModel): - available: list[CanonicalAvailableData] = [] - generated: list[CanonicalGeneratedData] = [] - - -def canonicalize_data(value: ConfigData) -> CanonicalData: - return CanonicalData( - available=[canonicalize_available_data(i) for i in value.available], - generated=[canonicalize_generated_data(i) for i in value.generated], - ) - - def get_plugin_from_named_base_model( data: dict | ConfigRootTask | ConfigShellTask | ConfigIconTask, ) -> str: if isinstance(data, (ConfigRootTask, ConfigShellTask, ConfigIconTask)): return data.plugin - name_and_specs = _NamedBaseModel.merge_name_and_specs(data) + name_and_specs = _NamedBaseModel.extract_merge_name(data) if name_and_specs.get("name", None) == "ROOT": return ConfigRootTask.plugin plugin = name_and_specs.get("plugin", None) @@ -698,11 +662,11 @@ class CanonicalWorkflow(BaseModel): rootdir: Path cycles: Annotated[list[ConfigCycle], AfterValidator(list_not_empty)] tasks: Annotated[list[ConfigTask], AfterValidator(list_not_empty)] - data: CanonicalData + data: ConfigData parameters: dict[str, list[Any]] @property - def data_dict(self) -> dict[str, CanonicalAvailableData | CanonicalGeneratedData]: + def data_dict(self) -> dict[str, ConfigAvailableData | ConfigGeneratedData]: return {data.name: data for data in itertools.chain(self.data.available, self.data.generated)} @property @@ -714,13 +678,12 @@ def canonicalize_workflow(value: ConfigWorkflow, rootdir: Path) -> CanonicalWork if not value.name: msg = "Workflow name required for canonicalization." raise ValueError(msg) - canon_data = canonicalize_data(value.data) return CanonicalWorkflow( name=value.name, rootdir=rootdir, cycles=value.cycles, tasks=value.tasks, - data=canon_data, + data=value.data, parameters=value.parameters, ) diff --git a/tests/unit_tests/core/test_workflow.py b/tests/unit_tests/core/test_workflow.py index 9a25f1d1..40b45d06 100644 --- a/tests/unit_tests/core/test_workflow.py +++ b/tests/unit_tests/core/test_workflow.py @@ -9,11 +9,11 @@ def test_minimal_workflow(): minimal_config = models.CanonicalWorkflow( name="minimal", rootdir=pathlib.Path("minimal"), - cycles=[models.ConfigCycle(some_cycle={"tasks": []})], - tasks=[models.ConfigShellTask(some_task={"plugin": "shell"})], - data=models.CanonicalData( - available=[models.CanonicalAvailableData(name="foo", type=models.DataType.FILE, src="foo.txt")], - generated=[models.CanonicalGeneratedData(name="bar", type=models.DataType.DIR, src="bar")], + cycles=[models.ConfigCycle(name="some_cycle", tasks=[])], + tasks=[models.ConfigShellTask(name="some_task", plugin="shell")], + data=models.ConfigData( + available=[models.ConfigAvailableData(name="foo", type=models.DataType.FILE, src="foo.txt")], + generated=[models.ConfigGeneratedData(name="bar", type=models.DataType.DIR, src="bar")], ), parameters={}, ) diff --git a/tests/unit_tests/parsing/test_yaml_data_models.py b/tests/unit_tests/parsing/test_yaml_data_models.py index 87461ca1..7bc695a1 100644 --- a/tests/unit_tests/parsing/test_yaml_data_models.py +++ b/tests/unit_tests/parsing/test_yaml_data_models.py @@ -9,7 +9,7 @@ @pytest.mark.parametrize("data_type", ["file", "dir"]) def test_base_data(data_type): - testee = models.ConfigBaseData(name={"type": data_type, "src": "foo.txt", "format": None}) + testee = models.ConfigBaseData(name="name", type=data_type, src="foo.txt", format=None) assert testee.type == data_type @@ -17,20 +17,20 @@ def test_base_data(data_type): @pytest.mark.parametrize("data_type", [None, "invalid", 1.42]) def test_base_data_invalid_type(data_type): with pytest.raises(pydantic.ValidationError): - _ = models.ConfigBaseData(name={"src": "foo", "format": "nml"}) + _ = models.ConfigBaseData(name="name", src="foo", format="nml") with pytest.raises(pydantic.ValidationError): - _ = models.ConfigBaseData(name={"type": data_type, "src": "foo", "format": "nml"}) + _ = models.ConfigBaseData(name="name", type=data_type, src="foo", format="nml") def test_workflow_canonicalization(): config = models.ConfigWorkflow( name="testee", - cycles=[models.ConfigCycle(minimal={"tasks": [models.ConfigCycleTask(a={})]})], - tasks=[{"some_task": {"plugin": "shell"}}], + cycles=[models.ConfigCycle(name="minimal", tasks=[models.ConfigCycleTask(name="a")])], + tasks=[models.ConfigShellTask(name="some_task")], data=models.ConfigData( - available=[models.ConfigAvailableData(foo={"type": "file", "src": "foo.txt"})], - generated=[models.ConfigGeneratedData(bar={"type": "dir", "src": "bar"})], + available=[models.ConfigAvailableData(name="foo", type=models.DataType.FILE, src="foo.txt")], + generated=[models.ConfigGeneratedData(name="bar", type=models.DataType.DIR, src="bar")], ), )