diff --git a/docs/_docs/tech/python-package.md b/docs/_docs/tech/python-package.md index 496870cc33..09f34c158a 100644 --- a/docs/_docs/tech/python-package.md +++ b/docs/_docs/tech/python-package.md @@ -62,16 +62,14 @@ steps -> 1 -> name ### Pipeline combinations: references -Pipelines can reference other pipelines in certain steps. -These references are accepted only in the model `PipelineWithRefs`. -This model provides a method to find and replace recursively all references. +Pipelines can reference other pipelines in certain steps: domain, append, and join. +A method `resolve_references` is provided to find and replace recursively all references. It must be called before trying to execute or translate a pipeline. - ### Pipeline : variables Some fields can contain variables instead of the actual value. -They are accepted only in the model `PipelineWithVariable`. +They are accepted only in the model `PipelineWithVariables`. This model provides a method to replace all variables by their value. It must be called before trying to execute or translate a pipeline. @@ -138,15 +136,7 @@ As of today, no translator backend exists for python. We plan to implement one f ### Summary ``` - ┌────────────────────────────┐ - │ │ - │ pipeline with references │ - │ │ - └──────────────┬─────────────┘ - │ - │ PipelineWithRefs.resolve_references - │ - ▼ + ┌────────────────────────────┐ │ │ │ pipeline with variables │ @@ -162,7 +152,7 @@ As of today, no translator backend exists for python. We plan to implement one f │ │ └──────────────┬─────────────┘ │ - │ iinput dataframes + │ input dataframes │ │ │ │ OR────────────────────────────────────────┐ │ diff --git a/server/CHANGELOG.md b/server/CHANGELOG.md index 01882ca0d4..bea20f7c2d 100644 --- a/server/CHANGELOG.md +++ b/server/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog (weaverbird python package) +## Unreleased + +### Changed + +- Remove `PipelineWithRefs`. 
Instead, support method `resolve_references` on types `Pipeline` and `PipelineWithVariables`. + ## [0.41.3] - 2024-01-26 ### Fixed diff --git a/server/playground.py b/server/playground.py index 9ac0cbd370..ab453ce990 100644 --- a/server/playground.py +++ b/server/playground.py @@ -62,7 +62,7 @@ from weaverbird.backends.pypika_translator.translate import ( translate_pipeline as pypika_translate_pipeline, ) -from weaverbird.pipeline.pipeline import Pipeline, PipelineWithRefs +from weaverbird.pipeline.pipeline import Pipeline, PipelineWithVariables from weaverbird.pipeline.steps import DomainStep from weaverbird.pipeline.steps.utils.combination import Reference @@ -148,7 +148,7 @@ async def prepare_pipeline(req: Request) -> Pipeline: Validate the pipeline sent in the body of the request, and prepare it for translation (resolve references and interpolate variables). """ - pipeline_with_refs = PipelineWithRefs(steps=await req.get_json()) # Validation + pipeline_with_refs = PipelineWithVariables(steps=await req.get_json()) # Validation pipeline_with_vars = await pipeline_with_refs.resolve_references(dummy_reference_resolver) pipeline = pipeline_with_vars.render(VARIABLES, nosql_apply_parameters_to_query) return pipeline @@ -347,7 +347,7 @@ async def handle_mongo_backend_request(): elif request.method == "POST": try: req_params = await parse_request_json(request) - pipeline_with_refs = PipelineWithRefs(steps=req_params["pipeline"]) # Validation + pipeline_with_refs = PipelineWithVariables(steps=req_params["pipeline"]) # Validation pipeline = await pipeline_with_refs.resolve_references(dummy_reference_resolver) mongo_query = mongo_translate_pipeline(pipeline) diff --git a/server/src/weaverbird/pipeline/pipeline.py b/server/src/weaverbird/pipeline/pipeline.py index 4b64366255..fcd4a0220a 100644 --- a/server/src/weaverbird/pipeline/pipeline.py +++ b/server/src/weaverbird/pipeline/pipeline.py @@ -1,5 +1,11 @@ from collections.abc import Iterable -from typing import Annotated, 
Any +from sys import version_info +from typing import Annotated, Any, TypeVar + +if version_info < (3, 11): # noqa: UP036 + from typing_extensions import Self # noqa: UP035 +else: + from typing import Self from pydantic import BaseModel, Field @@ -12,11 +18,8 @@ InclusionCondition, MatchCondition, ) -from weaverbird.pipeline.steps.append import AppendStepWithRefs -from weaverbird.pipeline.steps.domain import DomainStepWithRef from weaverbird.pipeline.steps.hierarchy import HierarchyStep -from weaverbird.pipeline.steps.join import JoinStepWithRef -from weaverbird.pipeline.steps.utils.combination import PipelineOrDomainName, ReferenceResolver +from weaverbird.pipeline.steps.utils.combination import PipelineOrDomainName, Reference, ReferenceResolver from .steps import ( AbsoluteValueStep, @@ -165,6 +168,23 @@ def model_dump(self, *, exclude_none: bool = True, **kwargs) -> dict: def dict(self, *, exclude_none: bool = True, **kwargs) -> dict: return self.model_dump(exclude_none=exclude_none, **kwargs) + async def resolve_references(self, reference_resolver: ReferenceResolver) -> Self | None: + """ + Walk the pipeline steps and replace any reference by its corresponding pipeline. + The sub-pipelines added should also be handled, so that there will be no references anymore in the result. 
+ """ + resolved_steps: list[PipelineStep | PipelineStepWithVariables] = [] + for step in self.steps: + resolved_step = ( + await step.resolve_references(reference_resolver, self) if hasattr(step, "resolve_references") else step + ) + if isinstance(resolved_step, self.__class__): + resolved_steps.extend(resolved_step.steps) + elif resolved_step is not None: # None means the step should be skipped + resolved_steps.append(resolved_step) + + return self.__class__(steps=resolved_steps) + PipelineStepWithVariables = Annotated[ AbsoluteValueStepWithVariable @@ -245,32 +265,45 @@ def _remove_void_from_condition(condition: Condition) -> Condition | None: return condition +JoinStepMaybeWithVariables = TypeVar("JoinStepMaybeWithVariables", bound=JoinStep | JoinStepWithVariable) + + def _remove_void_condition_from_join_step( - step: JoinStepWithVariable | JoinStep, -) -> JoinStep | JoinStepWithVariable | None: - if isinstance(step.right_pipeline, str): + step: JoinStepMaybeWithVariables, +) -> JoinStepMaybeWithVariables | None: + if isinstance(step.right_pipeline, str | Reference): return step - cleaned_steps = remove_void_conditions_from_filter_steps(step.right_pipeline) - return step.__class__(**{**step.model_dump(), "right_pipeline": cleaned_steps}) if cleaned_steps else None + elif isinstance(step.right_pipeline, list): + cleaned_steps = remove_void_conditions_from_filter_steps(step.right_pipeline) + return step.__class__(**{**step.model_dump(), "right_pipeline": cleaned_steps}) if cleaned_steps else None + return None + + +AppendStepMaybeWithVariables = TypeVar("AppendStepMaybeWithVariables", bound=AppendStep | AppendStepWithVariable) def _remove_void_condition_from_append_step( - step: AppendStep | AppendStepWithVariable, -) -> AppendStep | AppendStepWithVariable | None: + step: AppendStepMaybeWithVariables, +) -> AppendStepMaybeWithVariables | None: cleaned_pipelines: list[PipelineOrDomainName] = [] for pipeline in step.pipelines: - if isinstance(pipeline, str): + if 
isinstance(pipeline, str | Reference): cleaned_pipelines.append(pipeline) - else: + elif isinstance(pipeline, list): if cleaned_pipeline := remove_void_conditions_from_filter_steps(pipeline): cleaned_pipelines.append(cleaned_pipeline) return step.__class__(pipelines=cleaned_pipelines) if cleaned_pipelines else None +PipelineStepMaybeWithVariables = TypeVar( + "PipelineStepMaybeWithVariables", bound=PipelineStep | PipelineStepWithVariables +) + + def remove_void_conditions_from_filter_steps( - steps: list[PipelineStepWithVariables | PipelineStep], -) -> list[PipelineStepWithVariables | PipelineStep]: + steps: list[PipelineStepMaybeWithVariables], +) -> list[PipelineStepMaybeWithVariables]: """ This method will remove all FilterStep with conditions having "__VOID__" in them. either the "value" key or the "column" key. @@ -278,9 +311,9 @@ def remove_void_conditions_from_filter_steps( final_steps = [] for step in steps: - if isinstance(step, FilterStep): + if isinstance(step, FilterStep | FilterStepWithVariables): if (condition := _remove_void_from_condition(step.condition)) is not None: - final_steps.append(FilterStep(condition=condition)) + final_steps.append(step.__class__(condition=condition)) elif isinstance(step, JoinStep | JoinStepWithVariable): if (clean_step := _remove_void_condition_from_join_step(step)) is not None: final_steps.append(clean_step) @@ -417,7 +450,7 @@ def remove_void_conditions_from_mongo_steps( # TODO move to a dedicated variables module -class PipelineWithVariables(BaseModel): +class PipelineWithVariables(Pipeline): steps: list[PipelineStepWithVariables | PipelineStep] def render(self, variables: dict[str, Any], renderer) -> Pipeline: @@ -429,41 +462,6 @@ def render(self, variables: dict[str, Any], renderer) -> Pipeline: return Pipeline(steps=steps_rendered) -PipelineStepWithRefs = Annotated[ - AppendStepWithRefs | DomainStepWithRef | JoinStepWithRef, - Field(discriminator="name"), -] - - -class PipelineWithRefs(BaseModel): - """ - 
Represents a pipeline in which some steps can reference some other pipelines using the syntax - `{"type": "ref", "uid": "..."}` - """ - - steps: list[PipelineStepWithRefs | PipelineStep | PipelineStepWithVariables] - - async def resolve_references(self, reference_resolver: ReferenceResolver) -> PipelineWithVariables | None: - """ - Walk the pipeline steps and replace any reference by its corresponding pipeline. - The sub-pipelines added should also be handled, so that they will be no references anymore in the result. - """ - resolved_steps: list[PipelineStepWithRefs | PipelineStepWithVariables | PipelineStep] = [] - for step in self.steps: - resolved_step = ( - await step.resolve_references(reference_resolver) if hasattr(step, "resolve_references") else step - ) - if isinstance(resolved_step, PipelineWithVariables): - resolved_steps.extend(resolved_step.steps) - elif resolved_step is not None: # None means the step should be skipped - resolved_steps.append(resolved_step) - - return PipelineWithVariables(steps=resolved_steps) - - -PipelineWithVariables.model_rebuild() - - class ReferenceUnresolved(Exception): """ Raised when a mandatory reference is not resolved diff --git a/server/src/weaverbird/pipeline/steps/__init__.py b/server/src/weaverbird/pipeline/steps/__init__.py index 86e12dc4e2..a6e078f87c 100644 --- a/server/src/weaverbird/pipeline/steps/__init__.py +++ b/server/src/weaverbird/pipeline/steps/__init__.py @@ -4,7 +4,7 @@ from .absolutevalue import AbsoluteValueStep, AbsoluteValueStepWithVariable from .addmissingdates import AddMissingDatesStep, AddMissingDatesStepWithVariables from .aggregate import AggregateStep, AggregateStepWithVariables, Aggregation -from .append import AppendStep, AppendStepWithRefs, AppendStepWithVariable +from .append import AppendStep, AppendStepWithVariable from .argmax import ArgmaxStep, ArgmaxStepWithVariable from .argmin import ArgminStep, ArgminStepWithVariable from .comparetext import CompareTextStep, 
CompareTextStepWithVariables @@ -16,7 +16,7 @@ from .date_extract import DateExtractStep, DateExtractStepWithVariable from .delete import DeleteStep from .dissolve import DissolveStep -from .domain import DomainStep, DomainStepWithRef +from .domain import DomainStep from .duplicate import DuplicateStep from .duration import DurationStep, DurationStepWithVariable from .evolution import EvolutionStep, EvolutionStepWithVariable @@ -26,7 +26,7 @@ from .fromdate import FromdateStep from .hierarchy import HierarchyStep from .ifthenelse import IfthenelseStep, IfThenElseStepWithVariables -from .join import JoinStep, JoinStepWithRef, JoinStepWithVariable +from .join import JoinStep, JoinStepWithVariable from .lowercase import LowercaseStep from .moving_average import MovingAverageStep from .percentage import PercentageStep diff --git a/server/src/weaverbird/pipeline/steps/append.py b/server/src/weaverbird/pipeline/steps/append.py index 85569075c4..cf7faa767c 100644 --- a/server/src/weaverbird/pipeline/steps/append.py +++ b/server/src/weaverbird/pipeline/steps/append.py @@ -1,11 +1,16 @@ -from typing import Literal +from typing import TYPE_CHECKING, Literal, Self, TypeVar from weaverbird.pipeline.steps.utils.base import BaseStep from weaverbird.pipeline.steps.utils.render_variables import StepWithVariablesMixin +if TYPE_CHECKING: + from weaverbird.pipeline.pipeline import Pipeline, PipelineWithVariables + + PipelineType = TypeVar("PipelineType", bound=Pipeline | PipelineWithVariables) + from .utils.combination import ( - PipelineOrDomainName, - PipelineWithRefsOrDomainNameOrReference, + PipelineOrDomainNameOrReference, + PipelineWithVariablesOrDomainNameOrReference, ReferenceResolver, resolve_if_reference, ) @@ -16,22 +21,20 @@ class BaseAppendStep(BaseStep): class AppendStep(BaseAppendStep): - pipelines: list[PipelineOrDomainName] - - -class AppendStepWithVariable(AppendStep, StepWithVariablesMixin): - ... 
+ pipelines: list[PipelineOrDomainNameOrReference] - -class AppendStepWithRefs(BaseAppendStep): - pipelines: list[PipelineWithRefsOrDomainNameOrReference] - - async def resolve_references(self, reference_resolver: ReferenceResolver) -> AppendStepWithVariable | None: + async def resolve_references( + self, reference_resolver: ReferenceResolver, parent_pipeline: "PipelineType" + ) -> Self | None: resolved_pipelines = [await resolve_if_reference(reference_resolver, p) for p in self.pipelines] resolved_pipelines_without_nones = [p for p in resolved_pipelines if p is not None] if len(resolved_pipelines_without_nones) == 0: return None # skip the step - return AppendStepWithVariable( + return self.__class__( name=self.name, pipelines=resolved_pipelines_without_nones, ) + + +class AppendStepWithVariable(AppendStep, StepWithVariablesMixin): + pipelines: list[PipelineWithVariablesOrDomainNameOrReference] diff --git a/server/src/weaverbird/pipeline/steps/domain.py b/server/src/weaverbird/pipeline/steps/domain.py index 625106ecc8..ca666fd8a2 100644 --- a/server/src/weaverbird/pipeline/steps/domain.py +++ b/server/src/weaverbird/pipeline/steps/domain.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Literal, Union +from typing import TYPE_CHECKING, Literal, TypeVar from weaverbird.pipeline.steps.utils.base import BaseStep from weaverbird.pipeline.steps.utils.combination import ( @@ -8,33 +8,28 @@ ) if TYPE_CHECKING: - from weaverbird.pipeline.pipeline import PipelineWithVariables + from weaverbird.pipeline.pipeline import Pipeline, PipelineWithVariables + PipelineType = TypeVar("PipelineType", bound=Pipeline | PipelineWithVariables) -class BaseDomainStep(BaseStep): - name: Literal["domain"] = "domain" - - -class DomainStep(BaseDomainStep): - domain: str - -class DomainStepWithRef(BaseDomainStep): +class DomainStep(BaseStep): + name: Literal["domain"] = "domain" domain: str | Reference async def resolve_references( - self, reference_resolver: ReferenceResolver - ) -> 
Union[DomainStep, "PipelineWithVariables"]: + self, reference_resolver: ReferenceResolver, parent_pipeline: "PipelineType" + ) -> "DomainStep | PipelineType": """ This resolution can return a whole pipeline, which needs to replace the step. Not that the resulting array must be flattened: it should look like [step 1, step 2, step 3], not [[step 1, step 2], step 3] """ - from weaverbird.pipeline.pipeline import PipelineWithRefs, ReferenceUnresolved + from weaverbird.pipeline.pipeline import ReferenceUnresolved resolved = await resolve_if_reference(reference_resolver, self.domain) if isinstance(resolved, list): - return await PipelineWithRefs(steps=resolved).resolve_references(reference_resolver) + return await parent_pipeline.__class__(steps=resolved).resolve_references(reference_resolver) elif resolved is None: raise ReferenceUnresolved() else: diff --git a/server/src/weaverbird/pipeline/steps/join.py b/server/src/weaverbird/pipeline/steps/join.py index 4f2ca67b63..980674c22b 100644 --- a/server/src/weaverbird/pipeline/steps/join.py +++ b/server/src/weaverbird/pipeline/steps/join.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import TYPE_CHECKING, Literal, Self, TypeVar from pydantic import Field @@ -6,9 +6,14 @@ from weaverbird.pipeline.steps.utils.render_variables import StepWithVariablesMixin from weaverbird.pipeline.types import ColumnName +if TYPE_CHECKING: + from weaverbird.pipeline.pipeline import Pipeline, PipelineWithVariables + + PipelineType = TypeVar("PipelineType", bound=Pipeline | PipelineWithVariables) + from .utils.combination import ( - PipelineOrDomainName, - PipelineWithRefsOrDomainNameOrReference, + PipelineOrDomainNameOrReference, + PipelineWithVariablesOrDomainNameOrReference, ReferenceResolver, resolve_if_reference, ) @@ -23,25 +28,23 @@ class BaseJoinStep(BaseStep): class JoinStep(BaseJoinStep): - right_pipeline: PipelineOrDomainName - - -class JoinStepWithVariable(JoinStep, StepWithVariablesMixin): - ... 
+ right_pipeline: PipelineOrDomainNameOrReference - -class JoinStepWithRef(BaseJoinStep): - right_pipeline: PipelineWithRefsOrDomainNameOrReference - - async def resolve_references(self, reference_resolver: ReferenceResolver) -> JoinStepWithVariable | None: + async def resolve_references( + self, reference_resolver: ReferenceResolver, parent_pipeline: "PipelineType" + ) -> Self | None: right_pipeline = await resolve_if_reference(reference_resolver, self.right_pipeline) if right_pipeline is None: from weaverbird.pipeline.pipeline import ReferenceUnresolved raise ReferenceUnresolved() - return JoinStepWithVariable( + return self.__class__( name=self.name, type=self.type, on=self.on, right_pipeline=right_pipeline, ) + + +class JoinStepWithVariable(JoinStep, StepWithVariablesMixin): + right_pipeline: PipelineWithVariablesOrDomainNameOrReference diff --git a/server/src/weaverbird/pipeline/steps/utils/combination.py b/server/src/weaverbird/pipeline/steps/utils/combination.py index 0a6fb0fb34..73247d310c 100644 --- a/server/src/weaverbird/pipeline/steps/utils/combination.py +++ b/server/src/weaverbird/pipeline/steps/utils/combination.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, BeforeValidator, TypeAdapter if TYPE_CHECKING: - from weaverbird.pipeline.pipeline import PipelineStep, PipelineStepWithRefs, PipelineStepWithVariables + from weaverbird.pipeline.pipeline import PipelineStep, PipelineStepWithVariables class Reference(BaseModel): @@ -55,14 +55,12 @@ def _pipelinestep_adapter() -> TypeAdapter["str | list[PipelineStep]"]: @cache -def _pipelinestepwithref_adapter() -> ( - TypeAdapter["str | list[PipelineStepWithRefs | PipelineStepWithVariables | PipelineStep]"] -): - from weaverbird.pipeline.pipeline import PipelineStep, PipelineStepWithRefs, PipelineStepWithVariables +def _pipelinestepwithvariables_adapter() -> TypeAdapter["str | list[PipelineStepWithVariables | PipelineStep]"]: + from weaverbird.pipeline.pipeline import PipelineStep, PipelineStepWithVariables - 
# Note: PipelineStep must appear _before_ PipelineStepWithVariables, to avoid fields that could contain variables - # to always be matched as strings - return TypeAdapter(str | list[PipelineStepWithRefs | PipelineStep | PipelineStepWithVariables]) # type: ignore[arg-type] + # mypy is confused by the type with a postponed annotation above, so it expects str | list[Any] + # here + return TypeAdapter(str | list[PipelineStep | PipelineStepWithVariables]) # type: ignore[arg-type] def _ensure_is_pipeline_step( @@ -71,25 +69,21 @@ def _ensure_is_pipeline_step( return _pipelinestep_adapter().validate_python(v) -# can be either a domain name or a complete pipeline -PipelineOrDomainName = Annotated[str | list["PipelineStep"], BeforeValidator(_ensure_is_pipeline_step)] - - -def _ensure_is_pipeline_step_with_ref( - v: str | list[dict] | list["PipelineStep | PipelineStepWithVariables | PipelineStepWithRefs"], -) -> str | list["PipelineStep | PipelineStepWithVariables | PipelineStepWithRefs"]: - return _pipelinestepwithref_adapter().validate_python(v) +def _ensure_is_pipeline_step_with_variables( + v: str | list[dict] | list["PipelineStepWithVariables | PipelineStep"], +) -> str | list["PipelineStepWithVariables | PipelineStep"]: + return _pipelinestepwithvariables_adapter().validate_python(v) # can be either a domain name or a complete pipeline -PipelineWithRefsOrDomainName = Annotated[ - str | list["PipelineStepWithRefs | PipelineStepWithVariables | PipelineStep"], - BeforeValidator(_ensure_is_pipeline_step_with_ref), +PipelineOrDomainName = Annotated[str | list["PipelineStep"], BeforeValidator(_ensure_is_pipeline_step)] +PipelineWithVariablesOrDomainName = Annotated[ + str | list["PipelineStepWithVariables | PipelineStep"], BeforeValidator(_ensure_is_pipeline_step_with_variables) ] PipelineOrDomainNameOrReference = PipelineOrDomainName | Reference -PipelineWithRefsOrDomainNameOrReference = PipelineWithRefsOrDomainName | Reference +PipelineWithVariablesOrDomainNameOrReference 
= PipelineWithVariablesOrDomainName | Reference # A reference returning None means that it should be skipped ReferenceResolver = Callable[[Reference], Awaitable[PipelineOrDomainName | None]] @@ -119,12 +113,12 @@ async def resolve_if_reference( async def _resolve_references_in_pipeline( reference_resolver: ReferenceResolver, - pipeline: list["PipelineStepWithRefs | PipelineStep"], + pipeline: list["PipelineStep"], ) -> PipelineOrDomainName | None: - from weaverbird.pipeline.pipeline import PipelineWithRefs, ReferenceUnresolved + from weaverbird.pipeline.pipeline import Pipeline, ReferenceUnresolved # Recursively resolve any reference in sub-pipelines - pipeline_with_refs = PipelineWithRefs(steps=pipeline) + pipeline_with_refs = Pipeline(steps=pipeline) try: pipeline_without_refs = await pipeline_with_refs.resolve_references(reference_resolver) return pipeline_without_refs.model_dump()["steps"] diff --git a/server/tests/backends/sql_translator_unit_tests/test_base_translator.py b/server/tests/backends/sql_translator_unit_tests/test_base_translator.py index 0eff59e2e4..037a5ffdf2 100644 --- a/server/tests/backends/sql_translator_unit_tests/test_base_translator.py +++ b/server/tests/backends/sql_translator_unit_tests/test_base_translator.py @@ -345,7 +345,7 @@ def test_domain_with_wrong_domain_name(base_translator: BaseTranslator): def test_domain_with_reference(base_translator: BaseTranslator): uid = "to be or not to be a query ?" 
type = "ref" - step = steps.DomainStepWithRef(domain=Reference(type=type, uid=uid)) + step = steps.DomainStep(domain=Reference(type=type, uid=uid)) with pytest.raises(NotImplementedError): base_translator._domain(step=step) diff --git a/server/tests/pipeline/steps/utils/test_combination.py b/server/tests/pipeline/steps/utils/test_combination.py index ec200d72c3..9c0ad1a535 100644 --- a/server/tests/pipeline/steps/utils/test_combination.py +++ b/server/tests/pipeline/steps/utils/test_combination.py @@ -1,8 +1,7 @@ from typing import Any import pytest -from pydantic import ValidationError -from weaverbird.pipeline.steps import AppendStep, AppendStepWithRefs +from weaverbird.pipeline.steps import AppendStep @pytest.fixture @@ -80,13 +79,6 @@ def raw_append_step_with_refs() -> dict[str, Any]: } -def test_pipelinestep_validation(raw_append_step: dict[str, Any], raw_append_step_with_refs: dict[str, Any]) -> None: - step = AppendStep(**raw_append_step) - assert isinstance(step, AppendStep) - with pytest.raises(ValidationError): - AppendStep(**raw_append_step_with_refs) - - def test_pipelinestepwith_refs_validation(raw_append_step_with_refs: dict[str, Any]) -> None: - step = AppendStepWithRefs(**raw_append_step_with_refs) - assert isinstance(step, AppendStepWithRefs) + step = AppendStep(**raw_append_step_with_refs) + assert isinstance(step, AppendStep) diff --git a/server/tests/pipeline/test_references.py b/server/tests/pipeline/test_references.py index ce5b256e9c..b415ae8fdb 100644 --- a/server/tests/pipeline/test_references.py +++ b/server/tests/pipeline/test_references.py @@ -1,19 +1,17 @@ import pytest from weaverbird.pipeline.pipeline import ( Pipeline, - PipelineWithRefs, PipelineWithVariables, ReferenceUnresolved, ) from weaverbird.pipeline.steps import ( - AppendStepWithRefs, + AppendStep, DomainStep, - DomainStepWithRef, - JoinStepWithRef, + FilterStepWithVariables, + JoinStep, TextStep, + TextStepWithVariable, ) -from weaverbird.pipeline.steps.append import 
AppendStepWithVariable -from weaverbird.pipeline.steps.join import JoinStepWithVariable from weaverbird.pipeline.steps.utils.combination import Reference PIPELINES_LIBRARY: dict[str, list[dict]] = { @@ -22,9 +20,9 @@ DomainStep(domain="source"), ] ).dict()["steps"], - "intermediate_pipeline": PipelineWithRefs( + "intermediate_pipeline": Pipeline( steps=[ - DomainStepWithRef(domain=Reference(uid="source_pipeline")), + DomainStep(domain=Reference(uid="source_pipeline")), TextStep(new_column="meow", text="Cat"), ] ).dict()["steps"], @@ -34,9 +32,9 @@ TextStep(new_column="comes_from", text="other"), ] ).dict()["steps"], - "pipeline_with_unresolved_ref": PipelineWithRefs( + "pipeline_with_unresolved_ref": Pipeline( steps=[ - DomainStepWithRef(domain=Reference(uid="unresolved")), + DomainStep(domain=Reference(uid="unresolved")), TextStep(new_column="fail", text="yes"), ] ).dict()["steps"], @@ -49,13 +47,13 @@ async def reference_resolver(ref: Reference) -> list[dict] | None: @pytest.mark.asyncio async def test_resolve_references_domain(): - pipeline_with_refs = PipelineWithRefs( + pipeline_with_refs = Pipeline( steps=[ - DomainStepWithRef(domain=Reference(uid="source_pipeline")), + DomainStep(domain=Reference(uid="source_pipeline")), TextStep(new_column="text", text="Lorem ipsum"), ] ) - assert await pipeline_with_refs.resolve_references(reference_resolver) == PipelineWithVariables( + assert await pipeline_with_refs.resolve_references(reference_resolver) == Pipeline( steps=[ DomainStep(domain="source"), TextStep(new_column="text", text="Lorem ipsum"), @@ -65,13 +63,13 @@ async def test_resolve_references_domain(): @pytest.mark.asyncio async def test_resolve_references_recursive(): - pipeline_with_refs = PipelineWithRefs( + pipeline_with_refs = Pipeline( steps=[ - DomainStepWithRef(domain=Reference(uid="intermediate_pipeline")), + DomainStep(domain=Reference(uid="intermediate_pipeline")), TextStep(new_column="text", text="Lorem ipsum"), ] ) - assert await 
pipeline_with_refs.resolve_references(reference_resolver) == PipelineWithVariables( + assert await pipeline_with_refs.resolve_references(reference_resolver) == Pipeline( steps=[ DomainStep(domain="source"), TextStep(new_column="meow", text="Cat"), @@ -82,10 +80,10 @@ async def test_resolve_references_recursive(): @pytest.mark.asyncio async def test_resolve_references_join(): - pipeline_with_refs = PipelineWithRefs( + pipeline_with_refs = Pipeline( steps=[ - DomainStepWithRef(domain="source"), - JoinStepWithRef( + DomainStep(domain="source"), + JoinStep( on=[("key_left", "key_right")], right_pipeline=Reference(uid="other_pipeline"), type="left", @@ -94,10 +92,10 @@ async def test_resolve_references_join(): ] ) - expected = PipelineWithVariables( + expected = Pipeline( steps=[ DomainStep(domain="source"), - JoinStepWithVariable( + JoinStep( on=[("key_left", "key_right")], right_pipeline=PIPELINES_LIBRARY["other_pipeline"], type="left", @@ -111,10 +109,10 @@ async def test_resolve_references_join(): @pytest.mark.asyncio async def test_resolve_references_append(): - pipeline_with_refs = PipelineWithRefs( + pipeline_with_refs = Pipeline( steps=[ DomainStep(domain="source"), - AppendStepWithRefs( + AppendStep( pipelines=[ Reference(uid="other_pipeline"), ] @@ -123,33 +121,33 @@ async def test_resolve_references_append(): ] ) - expected = PipelineWithVariables( + expected = Pipeline( steps=[ DomainStep(domain="source"), - AppendStepWithVariable(pipelines=[PIPELINES_LIBRARY["other_pipeline"]]), + AppendStep(pipelines=[PIPELINES_LIBRARY["other_pipeline"]]), TextStep(new_column="text", text="Lorem ipsum"), ] ) assert await pipeline_with_refs.resolve_references(reference_resolver) == expected - pipeline_with_refs = PipelineWithRefs( + pipeline_with_refs = Pipeline( steps=[ DomainStep(domain="source"), - AppendStepWithRefs( + AppendStep( pipelines=[ [ - DomainStepWithRef(domain=Reference(uid="other_pipeline")), + DomainStep(domain=Reference(uid="other_pipeline")), 
TextStep(new_column="text", text="Lorem ipsum"), ], ] ), ] ) - expected = PipelineWithVariables( + expected = Pipeline( steps=[ DomainStep(domain="source"), - AppendStepWithVariable( + AppendStep( pipelines=[[*PIPELINES_LIBRARY["other_pipeline"], TextStep(new_column="text", text="Lorem ipsum")]] ), ] @@ -163,10 +161,10 @@ async def test_resolve_references_unresolved_append(): """ It should skip pipelines that are not resolved in an append step """ - pipeline_with_refs = PipelineWithRefs( + pipeline_with_refs = Pipeline( steps=[ DomainStep(domain="source"), - AppendStepWithRefs( + AppendStep( pipelines=[ Reference(uid="unresolved"), Reference(uid="other_pipeline"), @@ -177,10 +175,10 @@ async def test_resolve_references_unresolved_append(): ] ) - expected = PipelineWithVariables( + expected = Pipeline( steps=[ DomainStep(domain="source"), - AppendStepWithVariable(pipelines=[PIPELINES_LIBRARY["other_pipeline"]]), + AppendStep(pipelines=[PIPELINES_LIBRARY["other_pipeline"]]), TextStep(new_column="text", text="Lorem ipsum"), ] ) @@ -192,10 +190,10 @@ async def test_resolve_references_unresolved_append_all(): """ It should skip the append step if all its pipelines are not resolved """ - pipeline_with_refs = PipelineWithRefs( + pipeline_with_refs = Pipeline( steps=[ DomainStep(domain="source"), - AppendStepWithRefs( + AppendStep( pipelines=[ Reference(uid="unresolved"), Reference(uid="unresolved_2"), @@ -205,7 +203,7 @@ async def test_resolve_references_unresolved_append_all(): ] ) - assert await pipeline_with_refs.resolve_references(reference_resolver) == PipelineWithVariables( + assert await pipeline_with_refs.resolve_references(reference_resolver) == Pipeline( steps=[ DomainStep(domain="source"), TextStep(new_column="text", text="Lorem ipsum"), @@ -218,9 +216,9 @@ async def test_resolve_references_unresolved_domain(): """ It should raise an error if the domain reference is not resolved """ - pipeline_with_refs = PipelineWithRefs( + pipeline_with_refs = Pipeline( 
steps=[ - DomainStepWithRef(domain=Reference(uid="unresolved")), + DomainStep(domain=Reference(uid="unresolved")), TextStep(new_column="text", text="Lorem ipsum"), ] ) @@ -233,10 +231,10 @@ async def test_resolve_references_unresolved_append_subpipeline_error(): """ It should skip pipelines that trigger a resolution error """ - pipeline_with_refs = PipelineWithRefs( + pipeline_with_refs = Pipeline( steps=[ DomainStep(domain="source"), - AppendStepWithRefs( + AppendStep( pipelines=[ Reference(uid="pipeline_with_unresolved_ref"), Reference(uid="other_pipeline"), @@ -246,10 +244,10 @@ async def test_resolve_references_unresolved_append_subpipeline_error(): ] ) - expected = PipelineWithVariables( + expected = Pipeline( steps=[ DomainStep(domain="source"), - AppendStepWithVariable(pipelines=[PIPELINES_LIBRARY["other_pipeline"]]), + AppendStep(pipelines=[PIPELINES_LIBRARY["other_pipeline"]]), TextStep(new_column="text", text="Lorem ipsum"), ] ) @@ -262,10 +260,10 @@ async def test_resolve_references_unresolved_join(): """ It should raise an error if the joined pipeline is not resolved """ - pipeline_with_refs = PipelineWithRefs( + pipeline_with_refs = Pipeline( steps=[ - DomainStepWithRef(domain="source"), - JoinStepWithRef( + DomainStep(domain="source"), + JoinStep( on=[("key_left", "key_right")], right_pipeline=Reference(uid="unresolved"), type="left", @@ -282,10 +280,10 @@ async def test_resolve_references_unresolved_join_subpipeline_error(): """ It should raise an error if the joined pipeline raises a resolution error """ - pipeline_with_refs = PipelineWithRefs( + pipeline_with_refs = Pipeline( steps=[ - DomainStepWithRef(domain="source"), - JoinStepWithRef( + DomainStep(domain="source"), + JoinStep( on=[("key_left", "key_right")], right_pipeline=Reference(uid="pipeline_with_unresolved_step"), type="left", @@ -295,3 +293,42 @@ async def test_resolve_references_unresolved_join_subpipeline_error(): ) with pytest.raises(ReferenceUnresolved): await 
pipeline_with_refs.resolve_references(reference_resolver) + + +@pytest.mark.asyncio +async def test_resolve_references_with_variables(): + pipeline_with_refs = PipelineWithVariables( + steps=[ + DomainStep(domain=Reference(uid="intermediate_pipeline")), + FilterStepWithVariables( + condition={ + "column": "date", + "operator": "from", + "value": { + "quantity": 1, + "duration": "year", + "operator": "before", + "date": "{{ TODAY }}", + }, + } + ), + ] + ) + assert await pipeline_with_refs.resolve_references(reference_resolver) == PipelineWithVariables( + steps=[ + DomainStep(domain="source"), + TextStepWithVariable(new_column="meow", text="Cat"), + FilterStepWithVariables( + condition={ + "column": "date", + "operator": "from", + "value": { + "quantity": 1, + "duration": "year", + "operator": "before", + "date": "{{ TODAY }}", + }, + } + ), + ] + ) diff --git a/server/tests/test_pipeline.py b/server/tests/test_pipeline.py index fcf31fe50f..2974e29f04 100644 --- a/server/tests/test_pipeline.py +++ b/server/tests/test_pipeline.py @@ -12,7 +12,6 @@ Pipeline, PipelineStep, PipelineStepWithVariables, - PipelineWithRefs, PipelineWithVariables, remove_void_conditions_from_filter_steps, remove_void_conditions_from_mongo_steps, @@ -602,7 +601,7 @@ def test_remove_void_conditions_from_filter_steps_with_combinations( def test_pipeline_with_refs_variables_and_date_validity(): - PipelineWithRefs( + PipelineWithVariables( steps=[ {"domain": {"type": "ref", "uid": "83ff1fa2-d186-4a7b-a53b-47c901a076c7"}, "name": "domain"}, {