Skip to content

Commit

Permalink
Allow setting run_id in xcom_pull method (apache#41343)
Browse files Browse the repository at this point in the history
  • Loading branch information
fredthomsen authored Nov 12, 2024
1 parent 28d0a7e commit d07ff50
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
11 changes: 10 additions & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ def _xcom_pull(
session: Session = NEW_SESSION,
map_indexes: int | Iterable[int] | None = None,
default: Any = None,
run_id: str | None = None,
) -> Any:
"""
Pull XComs that optionally meet certain criteria.
Expand All @@ -588,6 +589,8 @@ def _xcom_pull(
:param include_prior_dates: If False, only XComs from the current
execution_date are returned. If *True*, XComs from previous dates
are returned as well.
:param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id.
If *None* (default), the run_id of the calling task is used.
When pulling one single task (``task_id`` is *None* or a str) without
specifying ``map_indexes``, the return value is inferred from whether
Expand All @@ -603,10 +606,12 @@ def _xcom_pull(
"""
if dag_id is None:
dag_id = ti.dag_id
if run_id is None:
run_id = ti.run_id

query = XCom.get_many(
key=key,
run_id=ti.run_id,
run_id=run_id,
dag_ids=dag_id,
task_ids=task_ids,
map_indexes=map_indexes,
Expand Down Expand Up @@ -3472,6 +3477,7 @@ def xcom_pull(
*,
map_indexes: int | Iterable[int] | None = None,
default: Any = None,
run_id: str | None = None,
) -> Any:
"""
Pull XComs that optionally meet certain criteria.
Expand All @@ -3491,6 +3497,8 @@ def xcom_pull(
:param include_prior_dates: If False, only XComs from the current
execution_date are returned. If *True*, XComs from previous dates
are returned as well.
:param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id.
If *None* (default), the run_id of the calling task is used.
When pulling one single task (``task_id`` is *None* or a str) without
specifying ``map_indexes``, the return value is inferred from whether
Expand All @@ -3513,6 +3521,7 @@ def xcom_pull(
session=session,
map_indexes=map_indexes,
default=default,
run_id=run_id,
)

@provide_session
Expand Down
34 changes: 34 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1762,6 +1762,40 @@ def test_xcom_pull_different_execution_date(self, create_task_instance):
# We *should* get a value using 'include_prior_dates'
assert ti.xcom_pull(task_ids="test_xcom", key=key, include_prior_dates=True) == value

def test_xcom_pull_different_run_ids(self, create_task_instance):
"""
tests xcom fetch behavior w/different run ids
"""
key = "xcom_key"
task_id = "test_xcom"
diff_run_id = "diff_run_id"
same_run_id_value = "xcom_value_same_run_id"
diff_run_id_value = "xcom_value_different_run_id"

ti_same_run_id = create_task_instance(
dag_id="test_xcom",
task_id=task_id,
)
ti_same_run_id.run(mark_success=True)
ti_same_run_id.xcom_push(key=key, value=same_run_id_value)

ti_diff_run_id = create_task_instance(
dag_id="test_xcom",
task_id=task_id,
run_id=diff_run_id,
)
ti_diff_run_id.run(mark_success=True)
ti_diff_run_id.xcom_push(key=key, value=diff_run_id_value)

assert (
ti_same_run_id.xcom_pull(run_id=ti_same_run_id.dag_run.run_id, task_ids=task_id, key=key)
== same_run_id_value
)
assert (
ti_same_run_id.xcom_pull(run_id=ti_diff_run_id.dag_run.run_id, task_ids=task_id, key=key)
== diff_run_id_value
)

def test_xcom_push_flag(self, dag_maker):
"""
Tests the option for Operators to push XComs
Expand Down

0 comments on commit d07ff50

Please sign in to comment.