diff --git a/airflow/api_connexion/endpoints/xcom_endpoint.py b/airflow/api_connexion/endpoints/xcom_endpoint.py index 5ba0ffa71594d..1019d0199ec00 100644 --- a/airflow/api_connexion/endpoints/xcom_endpoint.py +++ b/airflow/api_connexion/endpoints/xcom_endpoint.py @@ -34,6 +34,7 @@ from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models import DagRun as DR, XCom from airflow.settings import conf +from airflow.utils.api_migration import mark_fastapi_migration_done from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session from airflow.www.extensions.init_auth_manager import get_auth_manager @@ -83,6 +84,7 @@ def get_xcom_entries( return xcom_collection_schema.dump(XComCollection(xcom_entries=query, total_entries=total_entries)) +@mark_fastapi_migration_done @security.requires_access_dag("GET", DagAccessEntity.XCOM) @provide_session def get_xcom_entry( diff --git a/airflow/api_fastapi/core_api/datamodels/xcom.py b/airflow/api_fastapi/core_api/datamodels/xcom.py new file mode 100644 index 0000000000000..5533d3cdc7a65 --- /dev/null +++ b/airflow/api_fastapi/core_api/datamodels/xcom.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
from __future__ import annotations

from datetime import datetime
from typing import Any

from pydantic import BaseModel, field_validator


class XComResponse(BaseModel):
    """Serializer for an XCom item."""

    # Metadata fields shared by every XCom response; the ``value`` field is
    # declared by the subclasses below.
    key: str
    timestamp: datetime
    execution_date: datetime
    map_index: int
    task_id: str
    dag_id: str


class XComResponseNative(XComResponse):
    """XCom response serializer with native return type."""

    # No validator is attached, so the value is passed through unchanged.
    value: Any


class XComResponseString(XComResponse):
    """XCom response serializer with string return type."""

    # Declared as ``Any`` because the incoming value can be any object; the
    # validator below converts it to its string form after validation.
    value: Any

    @field_validator("value")
    @classmethod
    def value_to_string(cls, v: Any) -> str | None:
        """Return ``str(v)``, keeping ``None`` as ``None``."""
        return str(v) if v is not None else None
+ operationId: get_xcom_entry + parameters: + - name: dag_id + in: path + required: true + schema: + type: string + title: Dag Id + - name: task_id + in: path + required: true + schema: + type: string + title: Task Id + - name: dag_run_id + in: path + required: true + schema: + type: string + title: Dag Run Id + - name: xcom_key + in: path + required: true + schema: + type: string + title: Xcom Key + - name: map_index + in: query + required: false + schema: + type: integer + default: -1 + title: Map Index + - name: deserialize + in: query + required: false + schema: + type: boolean + default: false + title: Deserialize + - name: stringify + in: query + required: false + schema: + type: boolean + default: true + title: Stringify + responses: + '200': + description: Successful Response + content: + application/json: + schema: + anyOf: + - $ref: '#/components/schemas/XComResponseNative' + - $ref: '#/components/schemas/XComResponseString' + title: Response Get Xcom Entry + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Bad Request + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' components: schemas: AppBuilderMenuItemResponse: @@ -5134,3 +5227,73 @@ components: - git_version title: VersionInfo description: Version information serializer for responses. 
+ XComResponseNative: + properties: + key: + type: string + title: Key + timestamp: + type: string + format: date-time + title: Timestamp + execution_date: + type: string + format: date-time + title: Execution Date + map_index: + type: integer + title: Map Index + task_id: + type: string + title: Task Id + dag_id: + type: string + title: Dag Id + value: + title: Value + type: object + required: + - key + - timestamp + - execution_date + - map_index + - task_id + - dag_id + - value + title: XComResponseNative + description: XCom response serializer with native return type. + XComResponseString: + properties: + key: + type: string + title: Key + timestamp: + type: string + format: date-time + title: Timestamp + execution_date: + type: string + format: date-time + title: Execution Date + map_index: + type: integer + title: Map Index + task_id: + type: string + title: Task Id + dag_id: + type: string + title: Dag Id + value: + title: Value + type: object + required: + - key + - timestamp + - execution_date + - map_index + - task_id + - dag_id + - value + title: XComResponseString + description: XCom response serializer with string return type. 
diff --git a/airflow/api_fastapi/core_api/routes/public/__init__.py b/airflow/api_fastapi/core_api/routes/public/__init__.py index b7c8affe4a9cb..d799283a9fb1b 100644 --- a/airflow/api_fastapi/core_api/routes/public/__init__.py +++ b/airflow/api_fastapi/core_api/routes/public/__init__.py @@ -34,6 +34,7 @@ from airflow.api_fastapi.core_api.routes.public.task_instances import task_instances_router from airflow.api_fastapi.core_api.routes.public.variables import variables_router from airflow.api_fastapi.core_api.routes.public.version import version_router +from airflow.api_fastapi.core_api.routes.public.xcom import xcom_router public_router = AirflowRouter(prefix="/public") @@ -56,3 +57,4 @@ public_router.include_router(variables_router) public_router.include_router(version_router) public_router.include_router(dag_stats_router) +public_router.include_router(xcom_router) diff --git a/airflow/api_fastapi/core_api/routes/public/xcom.py b/airflow/api_fastapi/core_api/routes/public/xcom.py new file mode 100644 index 0000000000000..d5e0e46197f25 --- /dev/null +++ b/airflow/api_fastapi/core_api/routes/public/xcom.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
from __future__ import annotations

import copy

from fastapi import Depends, HTTPException
from sqlalchemy import and_, select
from sqlalchemy.orm import Session
from typing_extensions import Annotated

from airflow.api_fastapi.common.db.common import get_session
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.xcom import (
    XComResponseNative,
    XComResponseString,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.models import DagRun as DR, XCom
from airflow.settings import conf

xcom_router = AirflowRouter(
    tags=["XCom"], prefix="/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries"
)


# The handler docstring is deliberately kept to a single line: the generated
# OpenAPI spec carries it verbatim as the operation description.
#
# Behaviour summary:
#   * ``deserialize=True`` is only honoured when the ``[api]
#     enable_xcom_deserialize_support`` option is enabled; otherwise the
#     request is rejected with HTTP 400.
#   * HTTP 404 is returned when no XCom row matches the dag, run, task, key
#     and map index.
#   * The value is returned stringified when ``stringify`` is true or when
#     ``[core] enable_xcom_pickling`` is enabled; otherwise it is returned in
#     its native form.
@xcom_router.get(
    "/{xcom_key}",
    responses=create_openapi_http_exception_doc([400, 401, 403, 404]),
)
def get_xcom_entry(
    dag_id: str,
    task_id: str,
    dag_run_id: str,
    xcom_key: str,
    session: Annotated[Session, Depends(get_session)],
    map_index: int = -1,
    deserialize: bool = False,
    stringify: bool = True,
) -> XComResponseNative | XComResponseString:
    """Get an XCom entry."""
    if deserialize:
        if not conf.getboolean("api", "enable_xcom_deserialize_support", fallback=False):
            raise HTTPException(400, "XCom deserialization is disabled in configuration.")
        # Select the stored value column next to the entity so it can be fed
        # to ``XCom.deserialize_value`` further down.
        query = select(XCom, XCom.value)
    else:
        query = select(XCom)

    query = query.where(
        XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.key == xcom_key, XCom.map_index == map_index
    )
    # The XCom row is matched to its DagRun so the lookup can be done by run id.
    query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id))
    query = query.where(DR.run_id == dag_run_id)

    if deserialize:
        # Two columns were selected, so a full row (entity, value) is needed.
        item = session.execute(query).one_or_none()
    else:
        item = session.scalars(query).one_or_none()

    if item is None:
        raise HTTPException(404, f"XCom entry with key: `{xcom_key}` not found")

    if deserialize:
        xcom, value = item
        # Work on a shallow copy so the object loaded by the session is not
        # the one being reassigned.
        xcom_stub = copy.copy(xcom)
        # First restore the separately selected column value, then replace it
        # with the backend's deserialized result.
        xcom_stub.value = value
        xcom_stub.value = XCom.deserialize_value(xcom_stub)
        item = xcom_stub

    if stringify or conf.getboolean("core", "enable_xcom_pickling"):
        return XComResponseString.model_validate(item, from_attributes=True)

    return XComResponseNative.model_validate(item, from_attributes=True)
from __future__ import annotations

from unittest import mock

import pytest

from airflow.models import XCom
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.models.xcom import BaseXCom, resolve_xcom_backend
from airflow.operators.empty import EmptyOperator
from airflow.utils import timezone
from airflow.utils.session import provide_session
from airflow.utils.types import DagRunType

from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom

pytestmark = pytest.mark.db_test

TEST_XCOM_KEY = "test_xcom_key"
TEST_XCOM_VALUE = {"key": "value"}
TEST_XCOM_KEY2 = "test_xcom_key_non_serializable"
# Uses a tuple as a dict key; exercised by the pickling test below.
TEST_XCOM_VALUE2 = {"key": {("201009_NB502104_0421_AHJY23BGXG (SEQ_WF: 138898)", None): 82359}}
# Never written to the database; used for the 404 case.
TEST_XCOM_KEY3 = "test_xcom_key_non_existing"

TEST_DAG_ID = "test-dag-id"
TEST_TASK_ID = "test-task-id"
TEST_EXECUTION_DATE = "2005-04-02T00:00:00+00:00"

execution_date_parsed = timezone.parse(TEST_EXECUTION_DATE)
run_id = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed)


@provide_session
def _create_xcom(key, value, backend, session=None) -> None:
    """Store ``value`` under ``key`` for the test dag/task/run using ``backend``."""
    backend.set(
        key=key,
        value=value,
        dag_id=TEST_DAG_ID,
        task_id=TEST_TASK_ID,
        run_id=run_id,
        session=session,
    )


@provide_session
def _create_dag_run(session=None) -> None:
    """Add the test DagRun and a TaskInstance for the test task to the session."""
    dagrun = DagRun(
        dag_id=TEST_DAG_ID,
        run_id=run_id,
        execution_date=execution_date_parsed,
        start_date=execution_date_parsed,
        run_type=DagRunType.MANUAL,
    )
    session.add(dagrun)
    ti = TaskInstance(EmptyOperator(task_id=TEST_TASK_ID), run_id=run_id)
    ti.dag_id = TEST_DAG_ID
    session.add(ti)


class CustomXCom(BaseXCom):
    """XCom backend that tags values so tests can tell which deserialization path ran."""

    @classmethod
    def deserialize_value(cls, xcom: XCom):
        return f"real deserialized {super().deserialize_value(xcom)}"

    def orm_deserialize_value(self):
        return f"orm deserialized {super().orm_deserialize_value()}"


class TestXComEndpoint:
    """Shared database setup/cleanup for the XCom endpoint tests."""

    @staticmethod
    def clear_db():
        clear_db_dags()
        clear_db_runs()
        clear_db_xcom()

    @pytest.fixture(autouse=True)
    def setup(self) -> None:
        self.clear_db()

    def teardown_method(self) -> None:
        self.clear_db()

    def create_xcom(self, key, value, backend=XCom) -> None:
        """Create the dag run / task instance and then the XCom entry itself."""
        _create_dag_run()
        _create_xcom(key, value, backend)


class TestGetXComEntry(TestXComEndpoint):
    """Tests for the GET single XCom entry route."""

    def test_should_respond_200_stringify(self, test_client):
        self.create_xcom(TEST_XCOM_KEY, TEST_XCOM_VALUE)
        response = test_client.get(
            f"/public/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{TEST_XCOM_KEY}"
        )
        assert response.status_code == 200

        current_data = response.json()
        assert current_data == {
            "dag_id": TEST_DAG_ID,
            "execution_date": execution_date_parsed.strftime("%Y-%m-%dT%H:%M:%SZ"),
            "key": TEST_XCOM_KEY,
            "task_id": TEST_TASK_ID,
            "map_index": -1,
            # The timestamp is generated at insert time, so it is compared with itself.
            "timestamp": current_data["timestamp"],
            "value": str(TEST_XCOM_VALUE),
        }

    def test_should_respond_200_native(self, test_client):
        self.create_xcom(TEST_XCOM_KEY, TEST_XCOM_VALUE)
        response = test_client.get(
            f"/public/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{TEST_XCOM_KEY}?stringify=false"
        )
        assert response.status_code == 200

        current_data = response.json()
        assert current_data == {
            "dag_id": TEST_DAG_ID,
            "execution_date": execution_date_parsed.strftime("%Y-%m-%dT%H:%M:%SZ"),
            "key": TEST_XCOM_KEY,
            "task_id": TEST_TASK_ID,
            "map_index": -1,
            "timestamp": current_data["timestamp"],
            "value": TEST_XCOM_VALUE,
        }

    @conf_vars({("core", "enable_xcom_pickling"): "True"})
    def test_should_respond_200_pickled(self, test_client):
        self.create_xcom(TEST_XCOM_KEY2, TEST_XCOM_VALUE2)
        response = test_client.get(
            f"/public/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{TEST_XCOM_KEY2}"
        )
        assert response.status_code == 200

        current_data = response.json()
        assert current_data == {
            "dag_id": TEST_DAG_ID,
            "execution_date": execution_date_parsed.strftime("%Y-%m-%dT%H:%M:%SZ"),
            "key": TEST_XCOM_KEY2,
            "task_id": TEST_TASK_ID,
            "map_index": -1,
            "timestamp": current_data["timestamp"],
            "value": str(TEST_XCOM_VALUE2),
        }

    def test_should_raise_404_for_non_existent_xcom(self, test_client):
        response = test_client.get(
            f"/public/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{TEST_XCOM_KEY3}"
        )
        assert response.status_code == 404
        assert response.json()["detail"] == f"XCom entry with key: `{TEST_XCOM_KEY3}` not found"

    @pytest.mark.parametrize(
        "support_deserialize, params, expected_status_or_value",
        [
            pytest.param(
                True,
                {"deserialize": True},
                f"real deserialized {TEST_XCOM_VALUE}",
                id="enabled deserialize-true",
            ),
            pytest.param(
                False,
                {"deserialize": True},
                400,
                id="disabled deserialize-true",
            ),
            pytest.param(
                True,
                {"deserialize": False},
                f"orm deserialized {TEST_XCOM_VALUE}",
                id="enabled deserialize-false",
            ),
            pytest.param(
                False,
                {"deserialize": False},
                f"orm deserialized {TEST_XCOM_VALUE}",
                id="disabled deserialize-false",
            ),
            pytest.param(
                True,
                {},
                f"orm deserialized {TEST_XCOM_VALUE}",
                id="enabled default",
            ),
            pytest.param(
                False,
                {},
                f"orm deserialized {TEST_XCOM_VALUE}",
                id="disabled default",
            ),
        ],
    )
    @conf_vars({("core", "xcom_backend"): "tests.api_fastapi.core_api.routes.public.test_xcom.CustomXCom"})
    def test_custom_xcom_deserialize(
        self,
        support_deserialize: bool,
        params: dict[str, bool],
        expected_status_or_value: int | str,
        test_client,
    ):
        """Check which deserialization path is used for each config/query-param combination.

        An ``int`` expectation is compared with the response status code; a
        ``str`` expectation is compared with the returned ``value`` of a 200
        response.
        """
        # A distinct local name avoids shadowing the module-level ``XCom`` import.
        custom_backend = resolve_xcom_backend()
        self.create_xcom(TEST_XCOM_KEY, TEST_XCOM_VALUE, backend=custom_backend)

        url = f"/public/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{TEST_XCOM_KEY}"
        # Swap the name the route module looks up for the resolved custom backend.
        with mock.patch("airflow.api_fastapi.core_api.routes.public.xcom.XCom", custom_backend):
            with conf_vars({("api", "enable_xcom_deserialize_support"): str(support_deserialize)}):
                response = test_client.get(url, params=params)

        if isinstance(expected_status_or_value, int):
            assert response.status_code == expected_status_or_value
        else:
            assert response.status_code == 200
            assert response.json()["value"] == expected_status_or_value