From 894679921526ca366696544ab039e41aa16363fa Mon Sep 17 00:00:00 2001 From: Jed Cunningham Date: Wed, 1 Jan 2025 23:45:58 -0500 Subject: [PATCH 1/3] Add a couple DAG bundle related helpers This is broken out of the larger changes adding DAG bundle parsing, to make reviewing that (eventual) PR a bit easier. --- airflow/models/dag.py | 12 ++++++++++++ airflow/models/dagbundle.py | 11 +++++++++++ 2 files changed, 23 insertions(+) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 4abf24af52537..24aa5dac7eee1 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -770,6 +770,18 @@ def get_is_paused(self, session=NEW_SESSION) -> None: """Return a boolean indicating whether this DAG is paused.""" return session.scalar(select(DagModel.is_paused).where(DagModel.dag_id == self.dag_id)) + @property + @provide_session + def bundle_name(self, session=NEW_SESSION) -> str: + """Return the name of the bundle this DAG is in.""" + return session.scalar(select(DagModel.bundle_name).where(DagModel.dag_id == self.dag_id)) + + @property + @provide_session + def latest_bundle_version(self, session=NEW_SESSION) -> str | None: + """Return the latest version of the bundle this DAG is in.""" + return session.scalar(select(DagModel.latest_bundle_version).where(DagModel.dag_id == self.dag_id)) + @methodtools.lru_cache(maxsize=None) @classmethod def get_serialized_fields(cls): diff --git a/airflow/models/dagbundle.py b/airflow/models/dagbundle.py index 08429db0b0bcb..43f6396c5bf5f 100644 --- a/airflow/models/dagbundle.py +++ b/airflow/models/dagbundle.py @@ -16,11 +16,17 @@ # under the License. from __future__ import annotations +from typing import TYPE_CHECKING + from sqlalchemy import Boolean, Column, String from airflow.models.base import Base, StringID +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import UtcDateTime +if TYPE_CHECKING: + from sqlalchemy.orm import Session + class DagBundleModel(Base): """ @@ -41,3 +47,8 @@ class DagBundleModel(Base): def __init__(self, *, name: str): self.name = name + + @staticmethod + @provide_session + def get(name: str, session: Session = NEW_SESSION) -> DagBundleModel: + return session.query(DagBundleModel).get(name) From 38ce26b9486bba9bb7069c44d63386e533e1515e Mon Sep 17 00:00:00 2001 From: Jed Cunningham Date: Thu, 2 Jan 2025 11:56:41 -0500 Subject: [PATCH 2/3] move away from property so session can be passed --- airflow/models/dag.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 24aa5dac7eee1..f4e593dde91da 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -770,15 +770,13 @@ def get_is_paused(self, session=NEW_SESSION) -> None: """Return a boolean indicating whether this DAG is paused.""" return session.scalar(select(DagModel.is_paused).where(DagModel.dag_id == self.dag_id)) - @property @provide_session - def bundle_name(self, session=NEW_SESSION) -> str: + def get_bundle_name(self, session=NEW_SESSION) -> str: """Return the name of the bundle this DAG is in.""" return session.scalar(select(DagModel.bundle_name).where(DagModel.dag_id == self.dag_id)) - @property @provide_session - def latest_bundle_version(self, session=NEW_SESSION) -> str | None: + def get_latest_bundle_version(self, session=NEW_SESSION) -> str | None: """Return the latest version of the bundle this DAG is in.""" return session.scalar(select(DagModel.latest_bundle_version).where(DagModel.dag_id == self.dag_id)) From c0cf93683cc23093f5ef8567cf6739b35eb44daa Mon Sep 17 00:00:00 2001 From: Jed Cunningham Date: Sat, 4 Jan 2025 14:55:32 -0600 Subject: [PATCH 3/3] Remove DagBundleModel.get() --- airflow/models/dagbundle.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/airflow/models/dagbundle.py b/airflow/models/dagbundle.py index 43f6396c5bf5f..08429db0b0bcb 100644 --- a/airflow/models/dagbundle.py +++ b/airflow/models/dagbundle.py @@ -16,17 +16,11 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING - from sqlalchemy import Boolean, Column, String from airflow.models.base import Base, StringID -from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import UtcDateTime -if TYPE_CHECKING: - from sqlalchemy.orm import Session - class DagBundleModel(Base): """ @@ -47,8 +41,3 @@ class DagBundleModel(Base): def __init__(self, *, name: str): self.name = name - - @staticmethod - @provide_session - def get(name: str, session: Session = NEW_SESSION) -> DagBundleModel: - return session.query(DagBundleModel).get(name)