diff --git a/UPDATING.md b/UPDATING.md index e5a08cda4d09c..7942ba41b2033 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -228,6 +228,16 @@ Now that the DAG parser syncs DAG permissions there is no longer a need for manu In addition, the `/refresh` and `/refresh_all` webserver endpoints have also been removed. +### TaskInstances now *require* a DagRun + +Under normal operation every TaskInstance row in the database would have a DagRun row too, but it was possible to manually delete the DagRun and Airflow would still schedule the TaskInstances. + +In Airflow 2.2 we have changed this and now there is a database-level foreign key constraint ensuring that every TaskInstance has a DagRun row. + +Before updating to the 2.2 release you will have to manually resolve any inconsistencies (add back DagRun rows, or delete the TaskInstances) if you have any "dangling" TaskInstance rows. + +As part of this change the `clean_tis_without_dagrun_interval` config option under the `[scheduler]` section has been removed and no longer has any effect. + ## Airflow 2.1.3 No breaking changes. diff --git a/airflow/api/common/experimental/mark_tasks.py b/airflow/api/common/experimental/mark_tasks.py index 08b7363b8a388..945a9cc4a3596 100644 --- a/airflow/api/common/experimental/mark_tasks.py +++ b/airflow/api/common/experimental/mark_tasks.py @@ -21,6 +21,7 @@ from typing import Iterable from sqlalchemy import or_ +from sqlalchemy.orm import contains_eager from airflow.models.baseoperator import BaseOperator from airflow.models.dagrun import DagRun @@ -148,12 +149,14 @@ def get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates): """Get all tasks of the main dag that will be affected by a state change""" qry_dag = ( session.query(TaskInstance) + .join(TaskInstance.dag_run) .filter( TaskInstance.dag_id == dag.dag_id, - TaskInstance.execution_date.in_(confirmed_dates), + DagRun.execution_date.in_(confirmed_dates), TaskInstance.task_id.in_(task_ids), ) .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)) + .options(contains_eager(TaskInstance.dag_run)) ) return qry_dag diff --git a/airflow/api_connexion/endpoints/log_endpoint.py b/airflow/api_connexion/endpoints/log_endpoint.py index 5fcff8ef4fc47..ddd6cfc3d6f35 100644 --- a/airflow/api_connexion/endpoints/log_endpoint.py +++ b/airflow/api_connexion/endpoints/log_endpoint.py @@ -18,12 +18,13 @@ from flask import Response, current_app, request from itsdangerous.exc import BadSignature from itsdangerous.url_safe import URLSafeSerializer +from sqlalchemy.orm import eagerload from airflow.api_connexion import security from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.schemas.log_schema import LogResponseObject, logs_schema from airflow.exceptions import TaskNotFound -from airflow.models import DagRun +from airflow.models import TaskInstance from airflow.security import permissions from airflow.utils.log.log_reader import TaskLogReader from airflow.utils.session import provide_session @@ -60,15 +61,16 @@ def get_log(session, dag_id, dag_run_id, task_id, task_try_number, full_content= if not task_log_reader.supports_read: raise BadRequest("Task log handler does not support read logs.") - query = session.query(DagRun).filter(DagRun.dag_id == dag_id) - dag_run = query.filter(DagRun.run_id == dag_run_id).first() - if not dag_run: - raise NotFound("DAG Run not found") - - ti = dag_run.get_task_instance(task_id, session) + ti = ( + session.query(TaskInstance) + .filter(TaskInstance.task_id == task_id, TaskInstance.run_id == 
dag_run_id) + .join(TaskInstance.dag_run) + .options(eagerload(TaskInstance.dag_run)) + .one_or_none() + ) if ti is None: metadata['end_of_log'] = True - raise BadRequest(detail="Task instance did not exist in the DB") + raise NotFound(title="TaskInstance not found") dag = current_app.dag_bag.get_dag(dag_id) if dag: diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index e2a6ce9850eb9..361d29e846295 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -19,6 +19,7 @@ from flask import current_app, request from marshmallow import ValidationError from sqlalchemy import and_, func +from sqlalchemy.orm import eagerload from airflow.api_connexion import security from airflow.api_connexion.exceptions import BadRequest, NotFound @@ -54,15 +55,14 @@ def get_task_instance(dag_id: str, dag_run_id: str, task_id: str, session=None): """Get task instance""" query = ( session.query(TI) - .filter(TI.dag_id == dag_id) - .join(DR, and_(TI.dag_id == DR.dag_id, TI.execution_date == DR.execution_date)) - .filter(DR.run_id == dag_run_id) - .filter(TI.task_id == task_id) + .filter(TI.dag_id == dag_id, DR.run_id == dag_run_id, TI.task_id == task_id) + .join(TI.dag_run) + .options(eagerload(TI.dag_run)) .outerjoin( SlaMiss, and_( SlaMiss.dag_id == TI.dag_id, - SlaMiss.execution_date == TI.execution_date, + SlaMiss.execution_date == DR.execution_date, SlaMiss.task_id == TI.task_id, ), ) @@ -127,13 +127,12 @@ def get_task_instances( session=None, ): """Get list of task instances.""" - base_query = session.query(TI) + base_query = session.query(TI).join(TI.dag_run).options(eagerload(TI.dag_run)) if dag_id != "~": base_query = base_query.filter(TI.dag_id == dag_id) if dag_run_id != "~": - base_query = base_query.join(DR, and_(TI.dag_id == DR.dag_id, TI.execution_date == DR.execution_date)) - base_query = base_query.filter(DR.run_id == dag_run_id) + base_query = base_query.filter(TI.run_id == dag_run_id) base_query = _apply_range_filter( base_query, key=DR.execution_date, @@ -156,7 +155,7 @@ def get_task_instances( and_( SlaMiss.dag_id == TI.dag_id, SlaMiss.task_id == TI.task_id, - SlaMiss.execution_date == TI.execution_date, + SlaMiss.execution_date == DR.execution_date, ), isouter=True, ) @@ -183,12 +182,12 @@ def get_task_instances_batch(session=None): data = task_instance_batch_form.load(body) except ValidationError as err: raise BadRequest(detail=str(err.messages)) - base_query = session.query(TI) + base_query = session.query(TI).join(TI.dag_run).options(eagerload(TI.dag_run)) base_query = _apply_array_filter(base_query, key=TI.dag_id, values=data["dag_ids"]) base_query = _apply_range_filter( base_query, - key=TI.execution_date, + key=DR.execution_date, value_range=(data["execution_date_gte"], data["execution_date_lte"]), ) base_query = _apply_range_filter( @@ -214,7 +213,7 @@ def get_task_instances_batch(session=None): and_( SlaMiss.dag_id == TI.dag_id, SlaMiss.task_id == TI.task_id, - SlaMiss.execution_date == TI.execution_date, + SlaMiss.execution_date == DR.execution_date, ), isouter=True, ) @@ -254,9 +253,7 @@ def post_clear_task_instances(dag_id: str, session=None): clear_task_instances( task_instances.all(), session, dag=dag, dag_run_state=State.RUNNING if reset_dag_runs else False ) - task_instances = task_instances.join( - DR, and_(DR.dag_id == TI.dag_id, DR.execution_date == TI.execution_date) - ).add_column(DR.run_id) + task_instances = 
task_instances.join(TI.dag_run).options(eagerload(TI.dag_run)) return task_instance_reference_collection_schema.dump( TaskInstanceReferenceCollection(task_instances=task_instances.all()) ) @@ -303,14 +300,6 @@ def post_set_task_instances_state(dag_id, session): future=data["include_future"], past=data["include_past"], commit=not data["dry_run"], + session=session, ) - execution_dates = {ti.execution_date for ti in tis} - execution_date_to_run_id_map = dict( - session.query(DR.execution_date, DR.run_id).filter( - DR.dag_id == dag_id, DR.execution_date.in_(execution_dates) - ) - ) - tis_with_run_id = [(ti, execution_date_to_run_id_map.get(ti.execution_date)) for ti in tis] - return task_instance_reference_collection_schema.dump( - TaskInstanceReferenceCollection(task_instances=tis_with_run_id) - ) + return task_instance_reference_collection_schema.dump(TaskInstanceReferenceCollection(task_instances=tis)) diff --git a/airflow/api_connexion/schemas/task_instance_schema.py b/airflow/api_connexion/schemas/task_instance_schema.py index 95fc475590c44..89ae9a660363e 100644 --- a/airflow/api_connexion/schemas/task_instance_schema.py +++ b/airflow/api_connexion/schemas/task_instance_schema.py @@ -134,18 +134,10 @@ class TaskInstanceReferenceSchema(Schema): """Schema for the task instance reference schema""" task_id = fields.Str() - dag_run_id = fields.Str() + run_id = fields.Str(data_key="dag_run_id") dag_id = fields.Str() execution_date = fields.DateTime() - def get_attribute(self, obj, attr, default): - """Overwritten marshmallow function""" - task_instance_attr = ['task_id', 'execution_date', 'dag_id'] - if attr in task_instance_attr: - obj = obj[0] # As object is a tuple of task_instance and dag_run_id - return get_value(obj, attr, default) - return obj[1] - class TaskInstanceReferenceCollection(NamedTuple): """List of objects with metadata about taskinstance and dag_run_id""" diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py index dfd8c3c685b27..fffcf68d57541 100644 --- a/airflow/cli/commands/dag_command.py +++ b/airflow/cli/commands/dag_command.py @@ -89,9 +89,11 @@ def dag_backfill(args, dag=None): if args.dry_run: print(f"Dry run of DAG {args.dag_id} on {args.start_date}") + dr = DagRun(dag.dag_id, execution_date=args.start_date) for task in dag.tasks: print(f"Task {task.task_id}") - ti = TaskInstance(task, args.start_date) + ti = TaskInstance(task, run_id=None) + ti.dag_run = dr ti.dry_run() else: if args.reset_dagruns: diff --git a/airflow/cli/commands/kubernetes_command.py b/airflow/cli/commands/kubernetes_command.py index 2660daeb38bfb..d7481e443eb22 100644 --- a/airflow/cli/commands/kubernetes_command.py +++ b/airflow/cli/commands/kubernetes_command.py @@ -26,7 +26,7 @@ from airflow.kubernetes import pod_generator from airflow.kubernetes.kube_client import get_kube_client from airflow.kubernetes.pod_generator import PodGenerator -from airflow.models import TaskInstance +from airflow.models import DagRun, TaskInstance from airflow.settings import pod_mutation_hook from airflow.utils import cli as cli_utils, yaml from airflow.utils.cli import get_dag @@ -38,9 +38,11 @@ def generate_pod_yaml(args): execution_date = args.execution_date dag = get_dag(subdir=args.subdir, dag_id=args.dag_id) yaml_output_path = args.output_path + dr = DagRun(dag.dag_id, execution_date=execution_date) kube_config = KubeConfig() for task in dag.tasks: - ti = TaskInstance(task, execution_date) + ti = TaskInstance(task, None) + ti.dag_run = dr pod = PodGenerator.construct_pod( 
dag_id=args.dag_id, task_id=ti.task_id, diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index 4dd02830d9f3a..91a937404e68b 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -21,15 +21,16 @@ import logging import os import textwrap -from contextlib import contextmanager, redirect_stderr, redirect_stdout +from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress from typing import List from pendulum.parsing.exceptions import ParserError +from sqlalchemy.orm.exc import NoResultFound from airflow import settings from airflow.cli.simple_table import AirflowConsole from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, DagRunNotFound from airflow.executors.executor_loader import ExecutorLoader from airflow.jobs.local_task_job import LocalTaskJob from airflow.models import DagPickle, TaskInstance @@ -51,18 +52,43 @@ from airflow.utils.session import create_session, provide_session -def _get_ti(task, exec_date_or_run_id): +def _get_dag_run(dag, exec_date_or_run_id, create_if_necssary, session): + dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session) + if dag_run: + return dag_run + + execution_date = None + with suppress(ParserError, TypeError): + execution_date = timezone.parse(exec_date_or_run_id) + + if create_if_necssary and not execution_date: + return DagRun(dag_id=dag.dag_id, run_id=exec_date_or_run_id) + try: + return ( + session.query(DagRun) + .filter( + DagRun.dag_id == dag.dag_id, + DagRun.execution_date == execution_date, + ) + .one() + ) + except NoResultFound: + if create_if_necssary: + return DagRun(dag.dag_id, execution_date=execution_date) + raise DagRunNotFound( + f"DagRun for {dag.dag_id} with run_id or execution_date of {exec_date_or_run_id!r} not found" + ) from None + + +@provide_session +def _get_ti(task, exec_date_or_run_id, create_if_necssary=False, session=None): """Get the task instance through DagRun.run_id, if that fails, get the TI the old way""" - dag_run = task.dag.get_dagrun(run_id=exec_date_or_run_id) - if not dag_run: - try: - execution_date = timezone.parse(exec_date_or_run_id) - ti = TaskInstance(task, execution_date) - ti.refresh_from_db() - return ti - except (ParserError, TypeError): - raise AirflowException(f"DagRun with run_id: {exec_date_or_run_id} not found") + dag_run = _get_dag_run(task.dag, exec_date_or_run_id, create_if_necssary, session) + ti = dag_run.get_task_instance(task.task_id) + if not ti and create_if_necssary: + ti = TaskInstance(task, run_id=None) + ti.dag_run = dag_run ti.refresh_from_task(task) return ti @@ -75,11 +101,6 @@ def _run_task_by_selected_method(args, dag: DAG, ti: TaskInstance) -> None: - as raw task - by executor """ - if args.local and args.raw: - raise AirflowException( - "Option --raw and --local are mutually exclusive. " - "Please remove one option to execute the command." - ) if args.local: _run_task_by_local_task_job(args, ti) elif args.raw: @@ -155,17 +176,6 @@ def _run_task_by_local_task_job(args, ti): def _run_raw_task(args, ti: TaskInstance) -> None: """Runs the main task handling code""" - unsupported_options = [o for o in RAW_TASK_UNSUPPORTED_OPTION if getattr(args, o)] - - if unsupported_options: - raise AirflowException( - "Option --raw does not work with some of the other options on this command. You " - "can't use --raw option and the following options: {}. You provided the option {}. 
" - "Delete it to execute the command".format( - ", ".join(f"--{o}" for o in RAW_TASK_UNSUPPORTED_OPTION), - ", ".join(f"--{o}" for o in unsupported_options), - ) - ) ti._run_raw_task( mark_success=args.mark_success, job_id=args.job_id, @@ -213,6 +223,27 @@ def _capture_task_logs(ti): def task_run(args, dag=None): """Runs a single task instance""" # Load custom airflow config + + if args.local and args.raw: + raise AirflowException( + "Option --raw and --local are mutually exclusive. " + "Please remove one option to execute the command." + ) + + if args.raw: + unsupported_options = [o for o in RAW_TASK_UNSUPPORTED_OPTION if getattr(args, o)] + + if unsupported_options: + raise AirflowException( + "Option --raw does not work with some of the other options on this command. You " + "can't use --raw option and the following options: {}. You provided the option {}. " + "Delete it to execute the command".format( + ", ".join(f"--{o}" for o in RAW_TASK_UNSUPPORTED_OPTION), + ", ".join(f"--{o}" for o in unsupported_options), + ) + ) + if dag and args.pickle: + raise AirflowException("You cannot use the --pickle option when using DAG.cli() method.") if args.cfg_path: with open(args.cfg_path) as conf_file: conf_dict = json.load(conf_file) @@ -231,9 +262,7 @@ def task_run(args, dag=None): # processing hundreds of simultaneous tasks. settings.configure_orm(disable_connection_pool=True) - if dag and args.pickle: - raise AirflowException("You cannot use the --pickle option when using DAG.cli() method.") - elif args.pickle: + if args.pickle: print(f'Loading pickle id: {args.pickle}') dag = get_dag_by_pickle(args.pickle) elif not dag: @@ -359,14 +388,17 @@ def task_states_for_dag_run(args, session=None): raise AirflowException(f"Error parsing the supplied execution_date. 
Error: {str(err)}") if dag_run is None: - raise AirflowException("DagRun does not exist.") - tis = dag_run.get_task_instances() + raise DagRunNotFound( + f"DagRun for {args.dag_id} with run_id or execution_date of {args.execution_date_or_run_id!r} " + "not found" + ) + AirflowConsole().print_as( - data=tis, + data=dag_run.task_instances, output=args.output, mapper=lambda ti: { "dag_id": ti.dag_id, - "execution_date": ti.execution_date.isoformat(), + "execution_date": dag_run.execution_date.isoformat(), "task_id": ti.task_id, "state": ti.state, "start_date": ti.start_date.isoformat() if ti.start_date else "", @@ -405,7 +437,7 @@ def task_test(args, dag=None): if args.task_params: passed_in_params = json.loads(args.task_params) task.params.update(passed_in_params) - ti = _get_ti(task, args.execution_date_or_run_id) + ti = _get_ti(task, args.execution_date_or_run_id, create_if_necssary=True) try: if args.dry_run: @@ -431,7 +463,7 @@ def task_render(args): """Renders and displays templated fields for a given task""" dag = get_dag(args.subdir, args.dag_id) task = dag.get_task(task_id=args.task_id) - ti = _get_ti(task, args.execution_date_or_run_id) + ti = _get_ti(task, args.execution_date_or_run_id, create_if_necssary=True) ti.render_templates() for attr in task.__class__.template_fields: print( diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 95e1bcd644791..da353821cbcb6 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -1693,14 +1693,6 @@ type: string example: ~ default: "5" - - name: clean_tis_without_dagrun_interval - description: | - How often (in seconds) to check and tidy up 'running' TaskInstancess - that no longer have a matching DagRun - version_added: 2.0.0 - type: float - example: ~ - default: "15.0" - name: scheduler_heartbeat_sec description: | The scheduler constantly tries to trigger new tasks (look at the diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 9bc12528c8ba9..e36a6eb69836e 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -845,10 +845,6 @@ tls_key = # listen (in seconds). job_heartbeat_sec = 5 -# How often (in seconds) to check and tidy up 'running' TaskInstancess -# that no longer have a matching DagRun -clean_tis_without_dagrun_interval = 15.0 - # The scheduler constantly tries to trigger new tasks (look at the # scheduler section in the docs for more information). This defines # how often the scheduler should run (in seconds). 
diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index 55efd8e0bf29f..ce6b9553ec462 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -24,7 +24,7 @@ from contextlib import redirect_stderr, redirect_stdout, suppress from datetime import timedelta from multiprocessing.connection import Connection as MultiprocessingConnection -from typing import List, Optional, Set, Tuple +from typing import Iterator, List, Optional, Set, Tuple from setproctitle import setproctitle from sqlalchemy import func, or_ @@ -49,6 +49,7 @@ from airflow.utils.session import provide_session from airflow.utils.state import State +DR = models.DagRun TI = models.TaskInstance @@ -378,7 +379,8 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: return qry = ( - session.query(TI.task_id, func.max(TI.execution_date).label('max_ti')) + session.query(TI.task_id, func.max(DR.execution_date).label('max_ti')) + .join(TI.dag_run) .with_hint(TI, 'USE INDEX (PRIMARY)', dialect_name='mysql') .filter(TI.dag_id == dag.dag_id) .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED)) @@ -387,14 +389,14 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: .subquery('sq') ) - max_tis: List[TI] = ( + max_tis: Iterator[TI] = ( session.query(TI) + .join(TI.dag_run) .filter( TI.dag_id == dag.dag_id, TI.task_id == qry.c.task_id, - TI.execution_date == qry.c.max_ti, + DR.execution_date == qry.c.max_ti, ) - .all() ) ts = timezone.utcnow() @@ -558,7 +560,7 @@ def execute_callbacks( @provide_session def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, session: Session): dag = dagbag.dags[request.dag_id] - dag_run = dag.get_dagrun(execution_date=request.execution_date, session=session) + dag_run = dag.get_dagrun(run_id=request.run_id, session=session) dag.handle_callback( dagrun=dag_run, success=not request.is_failure_callback, reason=request.msg, session=session ) @@ -570,7 +572,7 @@ def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest): if simple_ti.task_id in dag.task_ids: task = dag.get_task(simple_ti.task_id) if request.is_failure_callback: - ti = TI(task, simple_ti.execution_date) + ti = TI(task, run_id=simple_ti.run_id) # TODO: Use simple_ti to improve performance here in the future ti.refresh_from_db() ti.handle_failure_with_callback(error=request.msg, test_mode=self.UNIT_TEST_MODE) diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index 91656879468c7..993787cc36d30 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -151,7 +151,8 @@ def _run( task_instance_related_annotations = { 'dag_id': annotations['dag_id'], 'task_id': annotations['task_id'], - 'execution_date': annotations['execution_date'], + 'execution_date': annotations.get('execution_date'), + 'run_id': annotations.get('run_id'), 'try_number': annotations['try_number'], } @@ -291,7 +292,7 @@ def run_next(self, next_job: KubernetesJobType) -> None: """ self.log.info('Kubernetes job is %s', str(next_job)) key, command, kube_executor_config, pod_template_file = next_job - dag_id, task_id, execution_date, try_number = key + dag_id, task_id, run_id, try_number = key if command[0:3] != ["airflow", "tasks", "run"]: raise ValueError('The command must start with ["airflow", "tasks", "run"].') @@ -311,7 +312,8 @@ def run_next(self, next_job: KubernetesJobType) -> None: task_id=task_id, 
kube_image=self.kube_config.kube_image, try_number=try_number, - date=execution_date, + date=None, + run_id=run_id, args=command, pod_override_object=kube_executor_config, base_worker_pod=base_worker_pod, @@ -453,27 +455,34 @@ def clear_not_launched_queued_tasks(self, session=None) -> None: for task in queued_tasks: self.log.debug("Checking task %s", task) - dict_string = "dag_id={},task_id={},execution_date={},airflow-worker={}".format( + dict_string = "dag_id={},task_id={},airflow-worker={}".format( pod_generator.make_safe_label_value(task.dag_id), pod_generator.make_safe_label_value(task.task_id), - pod_generator.datetime_to_label_safe_datestring(task.execution_date), pod_generator.make_safe_label_value(str(self.scheduler_job_id)), ) kwargs = dict(label_selector=dict_string) if self.kube_config.kube_client_request_args: - for key, value in self.kube_config.kube_client_request_args.items(): - kwargs[key] = value + kwargs.update(**self.kube_config.kube_client_request_args) + + # Try run_id first + kwargs['label_selector'] += ',run_id=' + pod_generator.make_safe_label_value(task.run_id) pod_list = self.kube_client.list_namespaced_pod(self.kube_config.kube_namespace, **kwargs) - if not pod_list.items: - self.log.info( - 'TaskInstance: %s found in queued state but was not launched, rescheduling', task - ) - session.query(TaskInstance).filter( - TaskInstance.dag_id == task.dag_id, - TaskInstance.task_id == task.task_id, - TaskInstance.execution_date == task.execution_date, - ).update({TaskInstance.state: State.NONE}) + if pod_list.items: + continue + # Fallback to old style of using execution_date + kwargs['label_selector'] = dict_string + ',exectuion_date={}'.format( + pod_generator.datetime_to_label_safe_datestring(task.execution_date) + ) + pod_list = self.kube_client.list_namespaced_pod(self.kube_config.kube_namespace, **kwargs) + if pod_list.items: + continue + self.log.info('TaskInstance: %s found in queued state but was not launched, rescheduling', task) + session.query(TaskInstance).filter( + TaskInstance.dag_id == task.dag_id, + TaskInstance.task_id == task.task_id, + TaskInstance.run_id == task.run_id, + ).update({TaskInstance.state: State.NONE}) def start(self) -> None: """Starts the executor""" diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py index 76d11007481a6..e105eb5118cda 100644 --- a/airflow/jobs/backfill_job.py +++ b/airflow/jobs/backfill_job.py @@ -22,7 +22,7 @@ from typing import Optional, Set import pendulum -from sqlalchemy import and_ +from sqlalchemy.orm import eagerload from sqlalchemy.orm.session import Session, make_transient from tabulate import tabulate @@ -335,7 +335,6 @@ def _get_dag_run(self, dagrun_info: DagRunInfo, dag: DAG, session: Session = Non # explicitly mark as backfill and running run.state = State.RUNNING - run.run_id = run.generate_run_id(DagRunType.BACKFILL_JOB, run_date) run.run_type = DagRunType.BACKFILL_JOB run.verify_integrity(session=session) return run @@ -434,15 +433,12 @@ def _process_backfill_task_instances( # determined deadlocked while they are actually # waiting for their upstream to finish @provide_session - def _per_task_process(key, ti, session=None): + def _per_task_process(key, ti: TaskInstance, session=None): ti.refresh_from_db(lock_for_update=True, session=session) task = self.dag.get_task(ti.task_id, include_subdags=True) ti.task = task - ignore_depends_on_past = self.ignore_first_depends_on_past and ti.execution_date == ( - start_date or ti.start_date - ) self.log.debug("Task instance to run %s state 
%s", ti, ti.state) # The task was already marked successful or skipped by a @@ -487,6 +483,12 @@ def _per_task_process(key, ti, session=None): ti_status.running.pop(key) return + if self.ignore_first_depends_on_past: + dagrun = ti.get_dagrun(session=session) + ignore_depends_on_past = dagrun.execution_date == (start_date or ti.start_date) + else: + ignore_depends_on_past = False + backfill_context = DepContext( deps=BACKFILL_QUEUED_DEPS, ignore_depends_on_past=ignore_depends_on_past, @@ -580,6 +582,7 @@ def _per_task_process(key, ti, session=None): num_running_task_instances_in_dag = DAG.get_num_task_instances( self.dag_id, states=self.STATES_COUNT_AS_RUNNING, + session=session, ) if num_running_task_instances_in_dag >= self.dag.max_active_tasks: @@ -592,6 +595,7 @@ def _per_task_process(key, ti, session=None): dag_id=self.dag_id, task_ids=[task.task_id], states=self.STATES_COUNT_AS_RUNNING, + session=session, ) if num_running_task_instances_in_task >= task.max_active_tis_per_dag: @@ -645,17 +649,15 @@ def tabulate_ti_keys_set(set_ti_keys: Set[TaskInstanceKey]) -> str: # Sorting by execution date first sorted_ti_keys = sorted( set_ti_keys, - key=lambda ti_key: (ti_key.execution_date, ti_key.dag_id, ti_key.task_id, ti_key.try_number), + key=lambda ti_key: (ti_key.run_id, ti_key.dag_id, ti_key.task_id, ti_key.try_number), ) - return tabulate(sorted_ti_keys, headers=["DAG ID", "Task ID", "Execution date", "Try number"]) + return tabulate(sorted_ti_keys, headers=["DAG ID", "Task ID", "Run ID", "Try number"]) def tabulate_tis_set(set_tis: Set[TaskInstance]) -> str: # Sorting by execution date first - sorted_tis = sorted( - set_tis, key=lambda ti: (ti.execution_date, ti.dag_id, ti.task_id, ti.try_number) - ) - tis_values = ((ti.dag_id, ti.task_id, ti.execution_date, ti.try_number) for ti in sorted_tis) - return tabulate(tis_values, headers=["DAG ID", "Task ID", "Execution date", "Try number"]) + sorted_tis = sorted(set_tis, key=lambda ti: (ti.run_id, ti.dag_id, ti.task_id, ti.try_number)) + tis_values = ((ti.dag_id, ti.task_id, ti.run_id, ti.try_number) for ti in sorted_tis) + return tabulate(tis_values, headers=["DAG ID", "Task ID", "Run ID", "Try number"]) err = '' if ti_status.failed: @@ -861,17 +863,13 @@ def reset_state_for_orphaned_tasks(self, filter_by_dag_run=None, session=None): # also consider running as the state might not have changed in the db yet running_tis = self.executor.running + # Can't use an update here since it doesn't support joins. 
resettable_states = [State.SCHEDULED, State.QUEUED] if filter_by_dag_run is None: resettable_tis = ( session.query(TaskInstance) - .join( - DagRun, - and_( - TaskInstance.dag_id == DagRun.dag_id, - TaskInstance.execution_date == DagRun.execution_date, - ), - ) + .join(TaskInstance.dag_run) + .options(eagerload(TaskInstance.dag_run)) .filter( DagRun.state == State.RUNNING, DagRun.run_type != DagRunType.BACKFILL_JOB, @@ -880,12 +878,8 @@ def reset_state_for_orphaned_tasks(self, filter_by_dag_run=None, session=None): ).all() else: resettable_tis = filter_by_dag_run.get_task_instances(state=resettable_states, session=session) - tis_to_reset = [] - # Can't use an update here since it doesn't support joins - for ti in resettable_tis: - if ti.key not in queued_tis and ti.key not in running_tis: - tis_to_reset.append(ti) + tis_to_reset = [ti for ti in resettable_tis if ti.key not in queued_tis and ti.key not in running_tis] if not tis_to_reset: return 0 @@ -910,7 +904,7 @@ def query(result, items): reset_tis = helpers.reduce_in_chunks(query, tis_to_reset, [], self.max_tis_per_query) task_instance_str = '\n\t'.join(repr(x) for x in reset_tis) - session.commit() + session.flush() self.log.info("Reset the following %s TaskInstances:\n\t%s", len(reset_tis), task_instance_str) return len(reset_tis) diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index 203a7a82b20b3..7878216c7a9ba 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -227,7 +227,7 @@ def _run_mini_scheduler_on_child_tasks(self, session=None) -> None: dag_run = with_row_locks( session.query(DagRun).filter_by( dag_id=self.dag_id, - execution_date=self.task_instance.execution_date, + run_id=self.task_instance.run_id, ), session=session, ).one() diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index add07d96a2310..d7239a8e9bb1b 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -30,7 +30,7 @@ from sqlalchemy import and_, func, not_, or_, tuple_ from sqlalchemy.exc import OperationalError -from sqlalchemy.orm import load_only, selectinload +from sqlalchemy.orm import eagerload, load_only, selectinload from sqlalchemy.orm.session import Session, make_transient from airflow import models, settings @@ -189,80 +189,6 @@ def is_alive(self, grace_multiplier: Optional[float] = None) -> bool: and (timezone.utcnow() - self.latest_heartbeat).total_seconds() < scheduler_health_check_threshold ) - @provide_session - def _change_state_for_tis_without_dagrun( - self, old_states: List[TaskInstanceState], new_state: TaskInstanceState, session: Session = None - ) -> None: - """ - For all DAG IDs in the DagBag, look for task instances in the - old_states and set them to new_state if the corresponding DagRun - does not exist or exists but is not in the running or queued state. This - normally should not happen, but it can if the state of DagRuns are - changed manually. 
- - :param old_states: examine TaskInstances in this state - :type old_states: list[airflow.utils.state.State] - :param new_state: set TaskInstances to this state - :type new_state: airflow.utils.state.State - """ - tis_changed = 0 - query = ( - session.query(models.TaskInstance) - .outerjoin(models.TaskInstance.dag_run) - .filter(models.TaskInstance.dag_id.in_(list(self.dagbag.dag_ids))) - .filter(models.TaskInstance.state.in_(old_states)) - .filter( - or_( - models.DagRun.state.notin_([State.RUNNING, State.QUEUED]), - models.DagRun.state.is_(None), - ) - ) - ) - # We need to do this for mysql as well because it can cause deadlocks - # as discussed in https://issues.apache.org/jira/browse/AIRFLOW-2516 - if self.using_sqlite or self.using_mysql: - tis_to_change: List[TI] = with_row_locks( - query, of=TI, session=session, **skip_locked(session=session) - ).all() - for ti in tis_to_change: - ti.set_state(new_state, session=session) - tis_changed += 1 - else: - subq = query.subquery() - current_time = timezone.utcnow() - ti_prop_update = { - models.TaskInstance.state: new_state, - models.TaskInstance.start_date: current_time, - } - - # Only add end_date and duration if the new_state is 'success', 'failed' or 'skipped' - if new_state in State.finished: - ti_prop_update.update( - { - models.TaskInstance.end_date: current_time, - models.TaskInstance.duration: 0, - } - ) - - tis_changed = ( - session.query(models.TaskInstance) - .filter( - models.TaskInstance.dag_id == subq.c.dag_id, - models.TaskInstance.task_id == subq.c.task_id, - models.TaskInstance.execution_date == subq.c.execution_date, - ) - .update(ti_prop_update, synchronize_session=False) - ) - - if tis_changed > 0: - session.flush() - self.log.warning( - "Set %s task instances to state=%s as their associated DagRun was not in RUNNING state", - tis_changed, - new_state, - ) - Stats.gauge('scheduler.tasks.without_dagrun', tis_changed) - @provide_session def __get_concurrency_maps( self, states: List[TaskInstanceState], session: Session = None @@ -320,14 +246,14 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = # and the dag is not paused query = ( session.query(TI) - .outerjoin(TI.dag_run) - .filter(or_(DR.run_id.is_(None), DR.run_type != DagRunType.BACKFILL_JOB)) - .filter(or_(DR.state.is_(None), DR.state != DagRunState.QUEUED)) + .join(TI.dag_run) + .options(eagerload(TI.dag_run)) + .filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state != DagRunState.QUEUED) .join(TI.dag_model) .filter(not_(DM.is_paused)) .filter(TI.state == State.SCHEDULED) .options(selectinload('dag_model')) - .order_by(-TI.priority_weight, TI.execution_date) + .order_by(-TI.priority_weight, DR.execution_date) ) starved_pools = [pool_name for pool_name, stats in pools.items() if stats['open'] <= 0] if starved_pools: @@ -559,11 +485,10 @@ def _process_executor_events(self, session: Session = None) -> int: ti_primary_key_to_try_number_map[ti_key.primary] = ti_key.try_number self.log.info( - "Executor reports execution of %s.%s execution_date=%s " - "exited with status %s for try_number %s", + "Executor reports execution of %s.%s run_id=%s exited with status %s for try_number %s", ti_key.dag_id, ti_key.task_id, - ti_key.execution_date, + ti_key.run_id, state, ti_key.try_number, ) @@ -710,11 +635,6 @@ def _run_scheduler_loop(self) -> None: self._emit_pool_metrics, ) - timers.call_regular_interval( - conf.getfloat('scheduler', 'clean_tis_without_dagrun_interval', fallback=15.0), - self._clean_tis_without_dagrun, - ) - for loop_count in 
itertools.count(start=1): with Stats.timer() as timer: @@ -765,35 +685,6 @@ def _run_scheduler_loop(self) -> None: ) break - @provide_session - def _clean_tis_without_dagrun(self, session): - with prohibit_commit(session) as guard: - try: - self._change_state_for_tis_without_dagrun( - old_states=[State.UP_FOR_RETRY], new_state=State.FAILED, session=session - ) - - self._change_state_for_tis_without_dagrun( - old_states=[ - State.QUEUED, - State.SCHEDULED, - State.UP_FOR_RESCHEDULE, - State.SENSING, - State.DEFERRED, - ], - new_state=State.NONE, - session=session, - ) - - guard.commit() - except OperationalError as e: - if is_lock_not_available_error(error=e): - self.log.debug("Lock held by another Scheduler") - session.rollback() - else: - raise - guard.commit() - def _do_scheduling(self, session) -> int: """ This function is where the main scheduling decisions take places. It: @@ -1052,7 +943,7 @@ def _schedule_dag_run( callback_to_execute = DagCallbackRequest( full_filepath=dag.fileloc, dag_id=dag.dag_id, - execution_date=dag_run.execution_date, + run_id=dag_run.run_id, is_failure_callback=True, msg='timed_out', ) @@ -1182,7 +1073,7 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None): DagRun.run_type != DagRunType.BACKFILL_JOB, DagRun.state == State.RUNNING, ) - .options(load_only(TI.dag_id, TI.task_id, TI.execution_date)) + .options(load_only(TI.dag_id, TI.task_id, TI.run_id)) ) # Lock these rows, so that another scheduler can't try and adopt these too diff --git a/airflow/kubernetes/kubernetes_helper_functions.py b/airflow/kubernetes/kubernetes_helper_functions.py index fd740ac7e1cc6..bc68daf404149 100644 --- a/airflow/kubernetes/kubernetes_helper_functions.py +++ b/airflow/kubernetes/kubernetes_helper_functions.py @@ -18,7 +18,7 @@ import logging from typing import Dict, Optional -from dateutil import parser +import pendulum from slugify import slugify from airflow.models.taskinstance import TaskInstanceKey @@ -62,6 +62,26 @@ def annotations_to_key(annotations: Dict[str, str]) -> Optional[TaskInstanceKey] dag_id = annotations['dag_id'] task_id = annotations['task_id'] try_number = int(annotations['try_number']) - execution_date = parser.parse(annotations['execution_date']) + run_id = annotations.get('run_id') + if not run_id and 'execution_date' in annotations: + # Compat: Look up the run_id from the TI table! 
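Pods launched before this upgrade only carry dag_id, task_id, execution_date and try_number annotations, while pods created by the new executor code (see pod_generator.py further down) get a run_id annotation instead of an execution_date; the compat lookup that follows maps the old shape back to a run_id so a TaskInstanceKey can still be built. Roughly, the two annotation shapes this function now has to accept look like this (all values are made up):

    # Illustrative only: the two annotation shapes annotations_to_key() may see.
    new_style = {
        'dag_id': 'example_dag',
        'task_id': 'example_task',
        'try_number': '1',
        'run_id': 'scheduled__2021-07-01T00:00:00+00:00',  # used directly as the key
    }

    old_style = {
        'dag_id': 'example_dag',
        'task_id': 'example_task',
        'try_number': '1',
        'execution_date': '2021-07-01T00:00:00+00:00',  # pre-upgrade pod: no run_id,
                                                        # hence the DB lookup below
    }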
+ from airflow.models.dagrun import DagRun + from airflow.models.taskinstance import TaskInstance + from airflow.settings import Session - return TaskInstanceKey(dag_id, task_id, execution_date, try_number) + execution_date = pendulum.parse(annotations['execution_date']) + # Do _not_ use create-session, we don't want to expunge + session = Session() + + run_id: str = ( + session.query(TaskInstance.run_id) + .join(TaskInstance.dag_run) + .filter( + TaskInstance.dag_id == dag_id, + TaskInstance.task_id == task_id, + DagRun.execution_date == execution_date, + ) + .scalar() + ) + + return TaskInstanceKey(dag_id, task_id, run_id, try_number) diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py index c7611f62dd8a1..f99726417f851 100644 --- a/airflow/kubernetes/pod_generator.py +++ b/airflow/kubernetes/pod_generator.py @@ -332,12 +332,13 @@ def construct_pod( pod_id: str, try_number: int, kube_image: str, - date: datetime.datetime, + date: Optional[datetime.datetime], args: List[str], pod_override_object: Optional[k8s.V1Pod], base_worker_pod: k8s.V1Pod, namespace: str, scheduler_job_id: int, + run_id: Optional[str] = None, ) -> k8s.V1Pod: """ Construct a pod by gathering and consolidating the configuration from 3 places: @@ -352,25 +353,32 @@ def construct_pod( except Exception: image = kube_image + annotations = { + 'dag_id': dag_id, + 'task_id': task_id, + 'try_number': str(try_number), + } + labels = { + 'airflow-worker': make_safe_label_value(str(scheduler_job_id)), + 'dag_id': make_safe_label_value(dag_id), + 'task_id': make_safe_label_value(task_id), + 'try_number': str(try_number), + 'airflow_version': airflow_version.replace('+', '-'), + 'kubernetes_executor': 'True', + } + if date: + annotations['execution_date'] = date.isoformat() + labels['execution_date'] = datetime_to_label_safe_datestring(date) + if run_id: + annotations['run_id'] = run_id + labels['run_id'] = make_safe_label_value(run_id) + dynamic_pod = k8s.V1Pod( metadata=k8s.V1ObjectMeta( namespace=namespace, - annotations={ - 'dag_id': dag_id, - 'task_id': task_id, - 'execution_date': date.isoformat(), - 'try_number': str(try_number), - }, + annotations=annotations, name=PodGenerator.make_unique_pod_id(pod_id), - labels={ - 'airflow-worker': make_safe_label_value(str(scheduler_job_id)), - 'dag_id': make_safe_label_value(dag_id), - 'task_id': make_safe_label_value(task_id), - 'execution_date': datetime_to_label_safe_datestring(date), - 'try_number': str(try_number), - 'airflow_version': airflow_version.replace('+', '-'), - 'kubernetes_executor': 'True', - }, + labels=labels, ), spec=k8s.V1PodSpec( containers=[ diff --git a/airflow/migrations/versions/7b2661a43ba3_taskinstance_keyed_to_dagrun.py b/airflow/migrations/versions/7b2661a43ba3_taskinstance_keyed_to_dagrun.py new file mode 100644 index 0000000000000..8c62101b39f4b --- /dev/null +++ b/airflow/migrations/versions/7b2661a43ba3_taskinstance_keyed_to_dagrun.py @@ -0,0 +1,287 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""TaskInstance keyed to DagRun + +Revision ID: 7b2661a43ba3 +Revises: 142555e44c17 +Create Date: 2021-07-15 15:26:12.710749 + +""" + +from collections import defaultdict + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.sql import and_, column, select, table + +from airflow.models.base import COLLATION_ARGS + +ID_LEN = 250 + +# revision identifiers, used by Alembic. +revision = '7b2661a43ba3' +down_revision = '142555e44c17' +branch_labels = None +depends_on = None + + +def _mssql_datetime(): + from sqlalchemy.dialects import mssql + + return mssql.DATETIME2(precision=6) + + +# Just Enough Table to run the conditions for update. +task_instance = table( + 'task_instance', + column('task_id', sa.String), + column('dag_id', sa.String), + column('run_id', sa.String), + column('execution_date', sa.TIMESTAMP), +) +task_reschedule = table( + 'task_reschedule', + column('task_id', sa.String), + column('dag_id', sa.String), + column('run_id', sa.String), + column('execution_date', sa.TIMESTAMP), +) +dag_run = table( + 'dag_run', + column('dag_id', sa.String), + column('run_id', sa.String), + column('execution_date', sa.TIMESTAMP), +) + + +def get_table_constraints(conn, table_name): + """ + This function return primary and unique constraint + along with column name. Some tables like `task_instance` + is missing the primary key constraint name and the name is + auto-generated by the SQL server. so this function helps to + retrieve any primary or unique constraint name. 
+ :param conn: sql connection object + :param table_name: table name + :return: a dictionary of ((constraint name, constraint type), column name) of table + :rtype: defaultdict(list) + """ + query = """SELECT tc.CONSTRAINT_NAME , tc.CONSTRAINT_TYPE, ccu.COLUMN_NAME + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc + JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE AS ccu ON ccu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME + WHERE tc.TABLE_NAME = '{table_name}' AND + (tc.CONSTRAINT_TYPE = 'PRIMARY KEY' or UPPER(tc.CONSTRAINT_TYPE) = 'UNIQUE') + """.format( + table_name=table_name + ) + result = conn.execute(query).fetchall() + constraint_dict = defaultdict(lambda: defaultdict(list)) + for constraint, constraint_type, col_name in result: + constraint_dict[constraint_type][constraint].append(col_name) + return constraint_dict + + +def upgrade(): + """Apply TaskInstance keyed to DagRun""" + conn = op.get_bind() + dialect_name = conn.dialect.name + + run_id_col_type = sa.String(length=ID_LEN, **COLLATION_ARGS) + + if dialect_name == 'sqlite': + naming_convention = { + "uq": "%(table_name)s_%(column_0_N_name)s_key", + } + with op.batch_alter_table('dag_run', naming_convention=naming_convention, recreate="always"): + # The naming_convention force the previously un-named UNIQUE constraints to have the right name -- + # but we still need to enter the context manager to trigger it + pass + elif dialect_name == 'mysql': + with op.batch_alter_table('dag_run') as batch_op: + batch_op.alter_column('dag_id', existing_type=sa.String(length=ID_LEN), type_=run_id_col_type) + batch_op.alter_column('run_id', existing_type=sa.String(length=ID_LEN), type_=run_id_col_type) + batch_op.drop_constraint('dag_id', 'unique') + batch_op.drop_constraint('dag_id_2', 'unique') + batch_op.create_unique_constraint( + 'dag_run_dag_id_execution_date_key', ['dag_id', 'execution_date'] + ) + batch_op.create_unique_constraint('dag_run_dag_id_run_id_key', ['dag_id', 'run_id']) + elif dialect_name == 'mssql': + + # _Somehow_ mssql was missing these constraints entirely! 
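The named unique constraint on dag_run (dag_id, run_id) created in the branches that follow is what the new composite foreign key from task_instance will reference, which is why every dialect branch normalizes or creates it before the key is added. A condensed sketch of that ordering, leaving out the per-dialect details and the run_id backfill the real migration performs:

    # Condensed sketch (not the full migration): the composite FK can only be
    # created once dag_run has a named UNIQUE constraint on the referenced columns.
    import sqlalchemy as sa
    from alembic import op


    def sketch_upgrade():
        with op.batch_alter_table('dag_run') as batch_op:
            batch_op.create_unique_constraint('dag_run_dag_id_run_id_key', ['dag_id', 'run_id'])

        with op.batch_alter_table('task_instance') as batch_op:
            batch_op.add_column(sa.Column('run_id', sa.String(250), nullable=True))
            # ...backfill run_id from dag_run and make it non-nullable here...
            batch_op.create_foreign_key(
                'task_instance_dag_run_fkey',
                'dag_run',
                ['dag_id', 'run_id'],
                ['dag_id', 'run_id'],
                ondelete='CASCADE',
            )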
+ with op.batch_alter_table('dag_run') as batch_op: + batch_op.create_unique_constraint( + 'dag_run_dag_id_execution_date_key', ['dag_id', 'execution_date'] + ) + batch_op.create_unique_constraint('dag_run_dag_id_run_id_key', ['dag_id', 'run_id']) + + # First create column nullable + op.add_column('task_instance', sa.Column('run_id', type_=run_id_col_type, nullable=True)) + op.add_column('task_reschedule', sa.Column('run_id', type_=run_id_col_type, nullable=True)) + + # Then update the new column by selecting the right value from DagRun + update_query = _multi_table_update(dialect_name, task_instance, task_instance.c.run_id) + op.execute(update_query) + + # + # TaskReschedule has a FK to TaskInstance, so we have to update that before + # we can drop the TI.execution_date column + + update_query = _multi_table_update(dialect_name, task_reschedule, task_reschedule.c.run_id) + op.execute(update_query) + + with op.batch_alter_table('task_reschedule', schema=None) as batch_op: + batch_op.alter_column('run_id', existing_type=run_id_col_type, existing_nullable=True, nullable=False) + + batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', 'foreignkey') + if dialect_name == "mysql": + # Mysql creates an index and a constraint -- we have to drop both + batch_op.drop_index('task_reschedule_dag_task_date_fkey') + batch_op.drop_index('idx_task_reschedule_dag_task_date') + + with op.batch_alter_table('task_instance', schema=None) as batch_op: + # Then make it non-nullable + batch_op.alter_column('run_id', existing_type=run_id_col_type, existing_nullable=True, nullable=False) + + # TODO: Is this right for non-postgres? + if dialect_name == 'mssql': + constraints = get_table_constraints(conn, "task_instance") + pk, _ = constraints['PRIMARY KEY'].popitem() + batch_op.drop_constraint(pk, type_='primary') + elif dialect_name not in ('sqlite'): + batch_op.drop_constraint('task_instance_pkey', type_='primary') + batch_op.create_primary_key('task_instance_pkey', ['dag_id', 'task_id', 'run_id']) + + batch_op.drop_index('ti_dag_date') + batch_op.drop_index('ti_state_lkp') + batch_op.drop_column('execution_date') + batch_op.create_foreign_key( + 'task_instance_dag_run_fkey', + 'dag_run', + ['dag_id', 'run_id'], + ['dag_id', 'run_id'], + ondelete='CASCADE', + ) + + batch_op.create_index('ti_dag_run', ['dag_id', 'run_id']) + batch_op.create_index('ti_state_lkp', ['dag_id', 'task_id', 'run_id', 'state']) + + with op.batch_alter_table('task_reschedule', schema=None) as batch_op: + batch_op.drop_column('execution_date') + batch_op.create_index( + 'idx_task_reschedule_dag_task_run', + ['dag_id', 'task_id', 'run_id'], + unique=False, + ) + # _Now_ there is a unique constraint on the columns in TI we can re-create the FK from TaskReschedule + batch_op.create_foreign_key( + 'task_reschedule_ti_fkey', + 'task_instance', + ['dag_id', 'task_id', 'run_id'], + ['dag_id', 'task_id', 'run_id'], + ondelete='CASCADE', + ) + + # https://docs.microsoft.com/en-us/sql/relational-databases/errors-events/mssqlserver-1785-database-engine-error?view=sql-server-ver15 + ondelete = 'CASCADE' if dialect_name != 'mssql' else 'NO ACTION' + batch_op.create_foreign_key( + 'task_reschedule_dr_fkey', + 'dag_run', + ['dag_id', 'run_id'], + ['dag_id', 'run_id'], + ondelete=ondelete, + ) + + +def downgrade(): + """Unapply TaskInstance keyed to DagRun""" + dialect_name = op.get_bind().dialect.name + + if dialect_name == "mssql": + col_type = _mssql_datetime() + else: + col_type = sa.TIMESTAMP(timezone=True) + + op.add_column('task_instance', 
sa.Column('execution_date', col_type, nullable=True)) + op.add_column('task_reschedule', sa.Column('execution_date', col_type, nullable=True)) + + update_query = _multi_table_update(dialect_name, task_instance, task_instance.c.execution_date) + op.execute(update_query) + + update_query = _multi_table_update(dialect_name, task_reschedule, task_reschedule.c.execution_date) + op.execute(update_query) + + with op.batch_alter_table('task_reschedule', schema=None) as batch_op: + batch_op.alter_column( + 'execution_date', existing_type=col_type, existing_nullable=True, nullable=False + ) + + # Can't drop PK index while there is a FK referencing it + batch_op.drop_constraint('task_reschedule_ti_fkey') + batch_op.drop_constraint('task_reschedule_dr_fkey') + batch_op.drop_index('idx_task_reschedule_dag_task_run') + + with op.batch_alter_table('task_instance', schema=None) as batch_op: + batch_op.alter_column( + 'execution_date', existing_type=col_type, existing_nullable=True, nullable=False + ) + + batch_op.drop_constraint('task_instance_pkey', type_='primary') + batch_op.create_primary_key('task_instance_pkey', ['dag_id', 'task_id', 'execution_date']) + + batch_op.drop_constraint('task_instance_dag_run_fkey', type_='foreignkey') + batch_op.drop_index('ti_dag_run') + batch_op.drop_index('ti_state_lkp') + batch_op.create_index('ti_state_lkp', ['dag_id', 'task_id', 'execution_date', 'state']) + batch_op.create_index('ti_dag_date', ['dag_id', 'execution_date'], unique=False) + + batch_op.drop_column('run_id') + + with op.batch_alter_table('task_reschedule', schema=None) as batch_op: + batch_op.drop_column('run_id') + batch_op.create_index( + 'idx_task_reschedule_dag_task_date', + ['dag_id', 'task_id', 'execution_date'], + unique=False, + ) + # Can only create FK once there is an index on these columns + batch_op.create_foreign_key( + 'task_reschedule_dag_task_date_fkey', + 'task_instance', + ['dag_id', 'task_id', 'execution_date'], + ['dag_id', 'task_id', 'execution_date'], + ondelete='CASCADE', + ) + + +def _multi_table_update(dialect_name, target, column): + condition = dag_run.c.dag_id == target.c.dag_id + if column == target.c.run_id: + condition = and_(condition, dag_run.c.execution_date == target.c.execution_date) + else: + condition = and_(condition, dag_run.c.run_id == target.c.run_id) + + if dialect_name == "sqlite": + # Most SQLite versions don't support multi table update (and SQLA doesn't know about it anyway), so we + # need to do a Correlated subquery update + sub_q = select([dag_run.c[column.name]]).where(condition) + + return target.update().values({column: sub_q}) + else: + return target.update().where(condition).values({column: dag_run.c[column.name]}) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 60d6621fe583b..96dcbd883b6ff 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -48,6 +48,7 @@ import jinja2 from dateutil.relativedelta import relativedelta from sqlalchemy.orm import Session +from sqlalchemy.orm.exc import NoResultFound import airflow.templates from airflow.compat.functools import cached_property @@ -1284,6 +1285,7 @@ def get_flat_relatives(self, upstream: bool = False): dag: DAG = self._dag return list(map(lambda task_id: dag.task_dict[task_id], self.get_flat_relative_ids(upstream))) + @provide_session def run( self, start_date: Optional[datetime] = None, @@ -1291,17 +1293,48 @@ def run( ignore_first_depends_on_past: bool = True, ignore_ti_state: bool = False, mark_success: bool = False, + test_mode: 
bool = False, + session: Session = None, ) -> None: """Run a set of task instances for a date range.""" + from airflow.models import DagRun + from airflow.utils.types import DagRunType + start_date = start_date or self.start_date end_date = end_date or self.end_date or timezone.utcnow() for info in self.dag.iter_dagrun_infos_between(start_date, end_date, align=False): ignore_depends_on_past = info.logical_date == start_date and ignore_first_depends_on_past - TaskInstance(self, info.logical_date).run( + try: + dag_run = ( + session.query(DagRun) + .filter( + DagRun.dag_id == self.dag_id, + DagRun.execution_date == info.logical_date, + ) + .one() + ) + ti = TaskInstance(self, run_id=dag_run.run_id) + except NoResultFound: + # This is _mostly_ only used in tests + dr = DagRun( + dag_id=self.dag_id, + run_id=DagRun.generate_run_id(DagRunType.MANUAL, info.logical_date), + run_type=DagRunType.MANUAL, + execution_date=info.logical_date, + data_interval=info.data_interval, + ) + ti = TaskInstance(self, run_id=None) + ti.dag_run = dr + session.add(dr) + session.flush() + + ti.run( mark_success=mark_success, ignore_depends_on_past=ignore_depends_on_past, ignore_ti_state=ignore_ti_state, + test_mode=test_mode, + session=session, ) def dry_run(self) -> None: diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 21f3d4eaf3a88..7743399cc56b3 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -949,7 +949,7 @@ def handle_callback(self, dagrun, success=True, reason=None, session=None): callback = self.on_success_callback if success else self.on_failure_callback if callback: self.log.info('Executing dag callback function: %s', callback) - tis = dagrun.get_task_instances() + tis = dagrun.get_task_instances(session=session) ti = tis[-1] # get first TaskInstance of DagRun ti.task = self.get_task(ti.task_id) context = ti.get_template_context(session=session) @@ -1163,6 +1163,7 @@ def get_task_instances( task_ids=None, start_date=start_date, end_date=end_date, + run_id=None, state=state, include_subdags=False, include_parentdag=False, @@ -1171,7 +1172,8 @@ def get_task_instances( as_pk_tuple=False, session=session, ) - .order_by(TaskInstance.execution_date) + .join(TaskInstance.dag_run) + .order_by(DagRun.execution_date) .all() ) @@ -1182,6 +1184,7 @@ def _get_task_instances( task_ids, start_date: Optional[datetime], end_date: Optional[datetime], + run_id: None, state: Union[str, List[str]], include_subdags: bool, include_parentdag: bool, @@ -1203,6 +1206,7 @@ def _get_task_instances( task_ids, start_date: Optional[datetime], end_date: Optional[datetime], + run_id: Optional[str], state: Union[str, List[str]], include_subdags: bool, include_parentdag: bool, @@ -1223,6 +1227,7 @@ def _get_task_instances( task_ids, start_date: Optional[datetime], end_date: Optional[datetime], + run_id: Optional[str], state: Union[str, List[str]], include_subdags: bool, include_parentdag: bool, @@ -1247,9 +1252,10 @@ def _get_task_instances( # Do we want full objects, or just the primary columns? 
if as_pk_tuple: - tis = session.query(TI.dag_id, TI.task_id, TI.execution_date) + tis = session.query(TI.dag_id, TI.task_id, TI.run_id) else: tis = session.query(TaskInstance) + tis = tis.join(TaskInstance.dag_run) if include_subdags: # Crafting the right filter for dag_id and task_ids combo @@ -1261,15 +1267,17 @@ def _get_task_instances( tis = tis.filter(or_(*conditions)) else: tis = tis.filter(TaskInstance.dag_id == self.dag_id, TaskInstance.task_id.in_(self.task_ids)) + if run_id: + tis = tis.filter(TaskInstance.run_id == run_id) if start_date: - tis = tis.filter(TaskInstance.execution_date >= start_date) + tis = tis.filter(DagRun.execution_date >= start_date) if task_ids: tis = tis.filter(TaskInstance.task_id.in_(task_ids)) # This allows allow_trigger_in_future config to take affect, rather than mandating exec_date <= UTC if end_date or not self.allow_future_exec_dates: end_date = end_date or timezone.utcnow() - tis = tis.filter(TaskInstance.execution_date <= end_date) + tis = tis.filter(DagRun.execution_date <= end_date) if state: if isinstance(state, str): @@ -1301,6 +1309,7 @@ def _get_task_instances( task_ids=task_ids, start_date=start_date, end_date=end_date, + run_id=None, state=state, include_subdags=include_subdags, include_parentdag=False, @@ -1353,10 +1362,14 @@ def _get_task_instances( ) ) ti.render_templates() - external_tis = session.query(TI).filter( - TI.dag_id == task.external_dag_id, - TI.task_id == task.external_task_id, - TI.execution_date == pendulum.parse(task.execution_date), + external_tis = ( + session.query(TI) + .join(TI.dag_run) + .filter( + TI.dag_id == task.external_dag_id, + TI.task_id == task.external_task_id, + DagRun.execution_date == pendulum.parse(task.execution_date), + ) ) for tii in external_tis: @@ -1373,8 +1386,9 @@ def _get_task_instances( result.update( downstream._get_task_instances( task_ids=None, - start_date=tii.execution_date, - end_date=tii.execution_date, + run_id=tii.run_id, + start_date=None, + end_date=None, state=state, include_subdags=include_subdags, include_dependent_dags=include_dependent_dags, @@ -1408,7 +1422,7 @@ def _get_task_instances( return result elif result: # We've been asked for objects, lets combine it all back in to a result set - tis = tis.with_entities(TI.dag_id, TI.task_id, TI.execution_date) + tis = tis.with_entities(TI.dag_id, TI.task_id, TI.run_id) tis = session.query(TI).filter(TI.filter_for_tis(result)) elif exclude_task_ids: @@ -1667,6 +1681,7 @@ def clear( task_ids=task_ids, start_date=start_date, end_date=end_date, + run_id=None, state=state, include_subdags=include_subdags, include_parentdag=include_parentdag, @@ -2267,7 +2282,7 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None): orm_dag.tags.append(dag_tag_orm) session.add(dag_tag_orm) - DagCode.bulk_sync_to_db(filelocs) + DagCode.bulk_sync_to_db(filelocs, session=session) # Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller # decide when to commit diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 797943441d42c..e8f5f9835d25c 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -15,20 +15,21 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import warnings from datetime import datetime from typing import TYPE_CHECKING, Any, Iterable, List, NamedTuple, Optional, Tuple, Union from sqlalchemy import Boolean, Column, Index, Integer, PickleType, String, UniqueConstraint, and_, func, or_ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.declarative import declared_attr -from sqlalchemy.orm import backref, relationship, synonym +from sqlalchemy.orm import joinedload, relationship, synonym from sqlalchemy.orm.session import Session from sqlalchemy.sql import expression from airflow import settings from airflow.configuration import conf as airflow_conf from airflow.exceptions import AirflowException, TaskNotFound -from airflow.models.base import ID_LEN, Base +from airflow.models.base import COLLATION_ARGS, ID_LEN, Base from airflow.models.taskinstance import TaskInstance as TI from airflow.stats import Stats from airflow.ti_deps.dep_context import DepContext @@ -65,13 +66,13 @@ class DagRun(Base, LoggingMixin): __NO_VALUE = object() id = Column(Integer, primary_key=True) - dag_id = Column(String(ID_LEN)) + dag_id = Column(String(ID_LEN, **COLLATION_ARGS)) queued_at = Column(UtcDateTime) execution_date = Column(UtcDateTime, default=timezone.utcnow) start_date = Column(UtcDateTime) end_date = Column(UtcDateTime) _state = Column('state', String(50), default=State.QUEUED) - run_id = Column(String(ID_LEN)) + run_id = Column(String(ID_LEN, **COLLATION_ARGS)) creating_job_id = Column(Integer) external_trigger = Column(Boolean, default=True) run_type = Column(String(50), nullable=False) @@ -87,17 +88,12 @@ class DagRun(Base, LoggingMixin): __table_args__ = ( Index('dag_id_state', dag_id, _state), - UniqueConstraint('dag_id', 'execution_date'), - UniqueConstraint('dag_id', 'run_id'), + UniqueConstraint('dag_id', 'execution_date', name='dag_run_dag_id_execution_date_key'), + UniqueConstraint('dag_id', 'run_id', name='dag_run_dag_id_run_id_key'), Index('idx_last_scheduling_decision', last_scheduling_decision), ) - task_instances = relationship( - TI, - primaryjoin=and_(TI.dag_id == dag_id, TI.execution_date == execution_date), - foreign_keys=(dag_id, execution_date), - backref=backref('dag_run', uselist=False), - ) + task_instances = relationship(TI, back_populates="dag_run") DEFAULT_DAGRUNS_TO_EXAMINE = airflow_conf.getint( 'scheduler', @@ -303,9 +299,13 @@ def get_task_instances( self, state: Optional[Iterable[TaskInstanceState]] = None, session=None ) -> Iterable[TI]: """Returns the task instances for this dag run""" - tis = session.query(TI).filter( - TI.dag_id == self.dag_id, - TI.execution_date == self.execution_date, + tis = ( + session.query(TI) + .options(joinedload(TI.dag_run)) + .filter( + TI.dag_id == self.dag_id, + TI.run_id == self.run_id, + ) ) if state: @@ -338,8 +338,8 @@ def get_task_instance(self, task_id: str, session: Session = None) -> Optional[T """ return ( session.query(TI) - .filter(TI.dag_id == self.dag_id, TI.execution_date == self.execution_date, TI.task_id == task_id) - .first() + .filter(TI.dag_id == self.dag_id, TI.run_id == self.run_id, TI.task_id == task_id) + .one_or_none() ) def get_dag(self) -> "DAG": @@ -436,7 +436,7 @@ def update_state( callback = callback_requests.DagCallbackRequest( full_filepath=dag.fileloc, dag_id=self.dag_id, - execution_date=self.execution_date, + run_id=self.run_id, is_failure_callback=True, msg='task_failure', ) @@ -451,7 +451,7 @@ def update_state( callback = callback_requests.DagCallbackRequest( full_filepath=dag.fileloc, dag_id=self.dag_id, - 
execution_date=self.execution_date, + run_id=self.run_id, is_failure_callback=False, msg='success', ) @@ -472,7 +472,7 @@ def update_state( callback = callback_requests.DagCallbackRequest( full_filepath=dag.fileloc, dag_id=self.dag_id, - execution_date=self.execution_date, + run_id=self.run_id, is_failure_callback=True, msg='all_tasks_deadlocked', ) @@ -675,7 +675,7 @@ def verify_integrity(self, session: Session = None): if task.task_id not in task_ids: Stats.incr(f"task_instance_created-{task.task_type}", 1, 1) - ti = TI(task, self.execution_date) + ti = TI(task, execution_date=None, run_id=self.run_id) task_instance_mutation_hook(ti) session.add(ti) @@ -683,9 +683,7 @@ def verify_integrity(self, session: Session = None): session.flush() except IntegrityError as err: self.log.info(str(err)) - self.log.info( - 'Hit IntegrityError while creating the TIs for ' f'{dag.dag_id} - {self.execution_date}.' - ) + self.log.info('Hit IntegrityError while creating the TIs for %s- %s', dag.dag_id, self.run_id) self.log.info('Doing session rollback.') # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive. session.rollback() @@ -695,6 +693,7 @@ def get_run(session: Session, dag_id: str, execution_date: datetime) -> Optional """ Get a single DAG Run + :meta private: :param session: Sqlalchemy ORM Session :type session: Session :param dag_id: DAG ID @@ -705,6 +704,11 @@ def get_run(session: Session, dag_id: str, execution_date: datetime) -> Optional if one exists. None otherwise. :rtype: airflow.models.DagRun """ + warnings.warn( + "This method is deprecated. Please use SQLAlchemy directly", + DeprecationWarning, + stacklevel=2, + ) return ( session.query(DagRun) .filter( @@ -770,7 +774,7 @@ def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = None) - session.query(TI) .filter( TI.dag_id == self.dag_id, - TI.execution_date == self.execution_date, + TI.run_id == self.run_id, TI.task_id.in_(schedulable_ti_ids), ) .update({TI.state: State.SCHEDULED}, synchronize_session=False) @@ -782,7 +786,7 @@ def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = None) - session.query(TI) .filter( TI.dag_id == self.dag_id, - TI.execution_date == self.execution_date, + TI.run_id == self.run_id, TI.task_id.in_(dummy_ti_ids), ) .update( diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index 489da524ebe04..5cd50a3165b7e 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -16,7 +16,8 @@ # specific language governing permissions and limitations # under the License. -from typing import Iterable, Union +import warnings +from typing import TYPE_CHECKING, Iterable, Union from airflow.models.taskinstance import TaskInstance from airflow.utils import timezone @@ -24,6 +25,12 @@ from airflow.utils.session import create_session, provide_session from airflow.utils.state import State +if TYPE_CHECKING: + from sqlalchemy import Session + + from airflow.models import DagRun + from airflow.models.baseoperator import BaseOperator + # The key used by SkipMixin to store XCom data. 
XCOM_SKIPMIXIN_KEY = "skipmixin_key" @@ -37,44 +44,31 @@ class SkipMixin(LoggingMixin): """A Mixin to skip Tasks Instances""" - def _set_state_to_skipped(self, dag_run, execution_date, tasks, session): + def _set_state_to_skipped(self, dag_run: "DagRun", tasks: "Iterable[BaseOperator]", session: "Session"): """Used internally to set state of task instances to skipped from the same dag run.""" task_ids = [d.task_id for d in tasks] now = timezone.utcnow() - if dag_run: - session.query(TaskInstance).filter( - TaskInstance.dag_id == dag_run.dag_id, - TaskInstance.execution_date == dag_run.execution_date, - TaskInstance.task_id.in_(task_ids), - ).update( - { - TaskInstance.state: State.SKIPPED, - TaskInstance.start_date: now, - TaskInstance.end_date: now, - }, - synchronize_session=False, - ) - else: - if execution_date is None: - raise ValueError("Execution date is None and no dag run") - - self.log.warning("No DAG RUN present this should not happen") - # this is defensive against dag runs that are not complete - for task in tasks: - ti = TaskInstance(task, execution_date=execution_date) - ti.state = State.SKIPPED - ti.start_date = now - ti.end_date = now - session.merge(ti) + session.query(TaskInstance).filter( + TaskInstance.dag_id == dag_run.dag_id, + TaskInstance.run_id == dag_run.run_id, + TaskInstance.task_id.in_(task_ids), + ).update( + { + TaskInstance.state: State.SKIPPED, + TaskInstance.start_date: now, + TaskInstance.end_date: now, + }, + synchronize_session=False, + ) @provide_session def skip( self, - dag_run, - execution_date, - tasks, - session=None, + dag_run: "DagRun", + execution_date: "timezone.DateTime", + tasks: "Iterable[BaseOperator]", + session: "Session" = None, ): """ Sets tasks instances to skipped from the same dag run. @@ -91,7 +85,32 @@ def skip( if not tasks: return - self._set_state_to_skipped(dag_run, execution_date, tasks, session) + if execution_date and not dag_run: + from airflow.models.dagrun import DagRun + + warnings.warn( + "Passing an execution_date to `skip()` is deprecated in favour of passing a dag_run", + DeprecationWarning, + stacklevel=2, + ) + + dag_run = ( + session.query(DagRun) + .filter( + DagRun.dag_id == tasks[0].dag_id, + DagRun.execution_date == execution_date, + ) + .one() + ) + elif execution_date and dag_run and execution_date != dag_run.execution_date: + raise ValueError( + "execution_date has a different value to dag_run.execution_date -- please only pass dag_run" + ) + + if dag_run is None: + raise ValueError("dag_run is required") + + self._set_state_to_skipped(dag_run, tasks, session) session.commit() # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available. @@ -154,7 +173,7 @@ def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[str, Iterable self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks]) with create_session() as session: - self._set_state_to_skipped(dag_run, ti.execution_date, skip_tasks, session=session) + self._set_state_to_skipped(dag_run, skip_tasks, session=session) # For some reason, session.commit() needs to happen before xcom_push. # Otherwise the session is not committed. 
session.commit() diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 56546e9c7a5c4..e7aaea1af0d73 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -23,7 +23,7 @@ import pickle import signal import warnings -from collections import defaultdict, namedtuple +from collections import defaultdict from datetime import datetime, timedelta from functools import partial from tempfile import NamedTemporaryFile @@ -47,12 +47,13 @@ func, inspect, or_, + tuple_, ) +from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import reconstructor, relationship -from sqlalchemy.orm.attributes import NO_VALUE +from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value from sqlalchemy.orm.session import Session from sqlalchemy.sql.elements import BooleanClauseList -from sqlalchemy.sql.expression import tuple_ from sqlalchemy.sql.sqltypes import BigInteger from airflow import settings @@ -67,6 +68,7 @@ AirflowSkipException, AirflowSmartSensorException, AirflowTaskTimeout, + DagRunNotFound, TaskDeferralError, TaskDeferred, ) @@ -90,7 +92,7 @@ from airflow.utils.net import get_hostname from airflow.utils.operator_helpers import context_to_airflow_vars from airflow.utils.platform import getuser -from airflow.utils.session import provide_session +from airflow.utils.session import create_session, provide_session from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime from airflow.utils.state import DagRunState, State from airflow.utils.timeout import timeout @@ -111,7 +113,7 @@ if TYPE_CHECKING: - from airflow.models.dag import DAG, DagModel + from airflow.models.dag import DAG, DagModel, DagRun @contextlib.contextmanager @@ -202,14 +204,14 @@ def clear_task_instances( ti.external_executor_id = None session.merge(ti) - task_id_by_key[ti.dag_id][ti.execution_date][ti.try_number].add(ti.task_id) + task_id_by_key[ti.dag_id][ti.run_id][ti.try_number].add(ti.task_id) if task_id_by_key: # Clear all reschedules related to the ti to clear # This is an optimization for the common case where all tis are for a small number - # of dag_id, execution_date and try_number. Use a nested dict of dag_id, - # execution_date, try_number and task_id to construct the where clause in a + # of dag_id, run_id and try_number. Use a nested dict of dag_id, + # run_id, try_number and task_id to construct the where clause in a # hierarchical manner. This speeds up the delete statement by more than 40x for # large number of tis (50k+). 
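# Illustrative note (editorial, not part of the patch): for a mapping such as
#   task_id_by_key = {"my_dag": {"manual__2021-01-01": {1: {"t1", "t2"}}}}
# the expression built below collapses to roughly
#   dag_id = 'my_dag' AND run_id = 'manual__2021-01-01'
#   AND try_number = 1 AND task_id IN ('t1', 't2')
# rather than one OR-ed term per task instance.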
conditions = or_( @@ -217,16 +219,16 @@ def clear_task_instances( TR.dag_id == dag_id, or_( and_( - TR.execution_date == execution_date, + TR.run_id == run_id, or_( and_(TR.try_number == try_number, TR.task_id.in_(task_ids)) for try_number, task_ids in task_tries.items() ), ) - for execution_date, task_tries in dates.items() + for run_id, task_tries in run_ids.items() ), ) - for dag_id, dates in task_id_by_key.items() + for dag_id, run_ids in task_id_by_key.items() ) delete_qry = TR.__table__.delete().where(conditions) @@ -251,16 +253,16 @@ def clear_task_instances( if dag_run_state is not False and tis: from airflow.models.dagrun import DagRun # Avoid circular import - dates_by_dag_id = defaultdict(set) + run_ids_by_dag_id = defaultdict(set) for instance in tis: - dates_by_dag_id[instance.dag_id].add(instance.execution_date) + run_ids_by_dag_id[instance.dag_id].add(instance.run_id) drs = ( session.query(DagRun) .filter( or_( - and_(DagRun.dag_id == dag_id, DagRun.execution_date.in_(dates)) - for dag_id, dates in dates_by_dag_id.items() + and_(DagRun.dag_id == dag_id, DagRun.run_id.in_(run_ids)) + for dag_id, run_ids in run_ids_by_dag_id.items() ) ) .all() @@ -277,22 +279,22 @@ class TaskInstanceKey(NamedTuple): dag_id: str task_id: str - execution_date: datetime + run_id: str try_number: int = 1 @property - def primary(self) -> Tuple[str, str, datetime]: + def primary(self) -> Tuple[str, str, str]: """Return task instance primary key part of the key""" - return self.dag_id, self.task_id, self.execution_date + return self.dag_id, self.task_id, self.run_id @property def reduced(self) -> 'TaskInstanceKey': """Remake the key by subtracting 1 from try number to match in memory information""" - return TaskInstanceKey(self.dag_id, self.task_id, self.execution_date, max(1, self.try_number - 1)) + return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, max(1, self.try_number - 1)) def with_try_number(self, try_number: int) -> 'TaskInstanceKey': """Returns TaskInstanceKey with provided ``try_number``""" - return TaskInstanceKey(self.dag_id, self.task_id, self.execution_date, try_number) + return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, try_number) @property def key(self) -> "TaskInstanceKey": @@ -321,7 +323,7 @@ class TaskInstance(Base, LoggingMixin): task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True) dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True) - execution_date = Column(UtcDateTime, primary_key=True) + run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True) start_date = Column(UtcDateTime) end_date = Column(UtcDateTime) duration = Column(Float) @@ -359,9 +361,9 @@ class TaskInstance(Base, LoggingMixin): __table_args__ = ( Index('ti_dag_state', dag_id, state), - Index('ti_dag_date', dag_id, execution_date), + Index('ti_dag_run', dag_id, run_id), Index('ti_state', state), - Index('ti_state_lkp', dag_id, task_id, execution_date, state), + Index('ti_state_lkp', dag_id, task_id, run_id, state), Index('ti_pool', pool, state, priority_weight), Index('ti_job_id', job_id), Index('ti_trigger_id', trigger_id), @@ -371,6 +373,12 @@ class TaskInstance(Base, LoggingMixin): name='task_instance_trigger_id_fkey', ondelete='CASCADE', ), + ForeignKeyConstraint( + [dag_id, run_id], + ["dag_run.dag_id", "dag_run.run_id"], + name='task_instance_dag_run_fkey', + ondelete="CASCADE", + ), ) dag_model = relationship( @@ -379,6 +387,7 @@ class TaskInstance(Base, LoggingMixin): foreign_keys=dag_id, uselist=False, innerjoin=True, + viewonly=True, ) 
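# Editorial note (not part of the patch): the composite ForeignKeyConstraint in
# __table_args__ above ties (dag_id, run_id) on task_instance to dag_run with
# ondelete="CASCADE", so deleting a DagRun row also deletes its TaskInstance
# rows at the database level.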
trigger = relationship( @@ -389,27 +398,52 @@ class TaskInstance(Base, LoggingMixin): innerjoin=True, ) - def __init__(self, task, execution_date: datetime, state: Optional[str] = None): + dag_run = relationship("DagRun", back_populates="task_instances") + + execution_date = association_proxy("dag_run", "execution_date") + + def __init__( + self, task, execution_date: Optional[datetime] = None, run_id: str = None, state: Optional[str] = None + ): super().__init__() self.dag_id = task.dag_id self.task_id = task.task_id self.refresh_from_task(task) self._log = logging.getLogger("airflow.task") - # make sure we have a localized execution_date stored in UTC - if execution_date and not timezone.is_localized(execution_date): - self.log.warning( - "execution date %s has no timezone information. Using default from dag or system", - execution_date, - ) - if self.task.has_dag(): - execution_date = timezone.make_aware(execution_date, self.task.dag.timezone) - else: - execution_date = timezone.make_aware(execution_date) + if run_id is None and execution_date is not None: + from airflow.models.dagrun import DagRun # Avoid circular import - execution_date = timezone.convert_to_utc(execution_date) + warnings.warn( + "Passing an execution_date to `TaskInstance()` is deprecated in favour of passing a run_id", + DeprecationWarning, + # Stack level is 4 because SQLA adds some wrappers around the constructor + stacklevel=4, + ) + # make sure we have a localized execution_date stored in UTC + if execution_date and not timezone.is_localized(execution_date): + self.log.warning( + "execution date %s has no timezone information. Using default from dag or system", + execution_date, + ) + if self.task.has_dag(): + execution_date = timezone.make_aware(execution_date, self.task.dag.timezone) + else: + execution_date = timezone.make_aware(execution_date) + + execution_date = timezone.convert_to_utc(execution_date) + with create_session() as session: + run_id = ( + session.query(DagRun.run_id) + .filter_by(dag_id=self.dag_id, execution_date=execution_date) + .scalar() + ) + if run_id is None: + raise DagRunNotFound( + f"DagRun for {self.dag_id!r} with date {execution_date} not found" + ) from None - self.execution_date = execution_date + self.run_id = run_id self.try_number = 0 self.unixname = getuser() @@ -466,22 +500,6 @@ def next_try_number(self): """Setting Next Try Number""" return self._try_number + 1 - @property - def run_id(self): - """Fetches the run_id from the associated DagRun""" - # TODO: Remove this once run_id is added as a column in TaskInstance - - # IF we have pre-loaded it, just use that - info = inspect(self) - if info.attrs.dag_run.loaded_value is not NO_VALUE: - return self.dag_un.run_id - # _Don't_ use provide/create_session here, as we do not want to commit on this session (as this is - # called from the scheduler critical section)! 
- dag_run = self.get_dagrun(session=settings.Session()) - - if dag_run: - return dag_run.run_id - def command_as_list( self, mark_success=False, @@ -525,7 +543,6 @@ def command_as_list( self.dag_id, self.task_id, run_id=self.run_id, - execution_date=self.execution_date, mark_success=mark_success, ignore_all_deps=ignore_all_deps, ignore_task_deps=ignore_task_deps, @@ -545,7 +562,6 @@ def generate_command( dag_id: str, task_id: str, run_id: str = None, - execution_date: datetime = None, mark_success: bool = False, ignore_all_deps: bool = False, ignore_depends_on_past: bool = False, @@ -562,8 +578,6 @@ def generate_command( """ Generates the shell command required to execute this task instance. - One of run_id or execution_date must be passed - :param dag_id: DAG ID :type dag_id: str :param task_id: Task ID @@ -601,13 +615,7 @@ def generate_command( :return: shell command that can be used to run the task instance :rtype: list[str] """ - cmd = ["airflow", "tasks", "run", dag_id, task_id] - if run_id: - cmd.append(run_id) - elif execution_date: - cmd.append(execution_date.isoformat()) - else: - raise ValueError("One of run_id and execution_date must be provided") + cmd = ["airflow", "tasks", "run", dag_id, task_id, run_id] if mark_success: cmd.extend(["--mark-success"]) if pickle_id: @@ -671,7 +679,7 @@ def current_state(self, session=None) -> str: .filter( TaskInstance.dag_id == self.dag_id, TaskInstance.task_id == self.task_id, - TaskInstance.execution_date == self.execution_date, + TaskInstance.run_id == self.run_id, ) .all() ) @@ -711,11 +719,11 @@ def refresh_from_db(self, session=None, lock_for_update=False) -> None: qry = session.query(TaskInstance).filter( TaskInstance.dag_id == self.dag_id, TaskInstance.task_id == self.task_id, - TaskInstance.execution_date == self.execution_date, + TaskInstance.run_id == self.run_id, ) if lock_for_update: - ti = qry.with_for_update().first() + ti: Optional[TaskInstance] = qry.with_for_update().first() else: ti = qry.first() if ti: @@ -788,7 +796,7 @@ def clear_xcom_data(self, session=None): @property def key(self) -> TaskInstanceKey: """Returns a tuple that identifies the task instance uniquely""" - return TaskInstanceKey(self.dag_id, self.task_id, self.execution_date, self.try_number) + return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number) @provide_session def set_state(self, state: str, session=None): @@ -839,7 +847,7 @@ def are_dependents_done(self, session=None): ti = session.query(func.count(TaskInstance.task_id)).filter( TaskInstance.dag_id == self.dag_id, TaskInstance.task_id.in_(task.downstream_task_ids), - TaskInstance.execution_date == self.execution_date, + TaskInstance.run_id == self.run_id, TaskInstance.state.in_([State.SKIPPED, State.SUCCESS]), ) count = ti[0][0] @@ -1043,7 +1051,7 @@ def get_failed_dep_statuses(self, dep_context=None, session=None): yield dep_status def __repr__(self): - return f"" + return f"" def next_retry_datetime(self): """ @@ -1093,13 +1101,16 @@ def get_dagrun(self, session: Session = None): :param session: SQLAlchemy ORM Session :return: DagRun """ + info = inspect(self) + if info.attrs.dag_run.loaded_value is not NO_VALUE: + return self.dag_run + from airflow.models.dagrun import DagRun # Avoid circular import - dr = ( - session.query(DagRun) - .filter(DagRun.dag_id == self.dag_id, DagRun.execution_date == self.execution_date) - .first() - ) + dr = session.query(DagRun).filter(DagRun.dag_id == self.dag_id, DagRun.run_id == self.run_id).one() + + # Record it in the instance for next 
time. This means that `self.execution_date` will work correctly + set_committed_value(self, 'dag_run', dr) return dr @@ -1287,20 +1298,22 @@ def _run_raw_task( self.job_id = job_id self.hostname = get_hostname() self.pid = os.getpid() - session.merge(self) - session.commit() + if not test_mode: + session.merge(self) + session.commit() actual_start_date = timezone.utcnow() Stats.incr(f'ti.start.{task.dag_id}.{task.task_id}') try: if not mark_success: context = self.get_template_context() self._prepare_and_execute_task_with_callbacks(context, task) - self.refresh_from_db(lock_for_update=True) + if not test_mode: + self.refresh_from_db(lock_for_update=True, session=session) self.state = State.SUCCESS except TaskDeferred as defer: # The task has signalled it wants to defer execution based on # a trigger. - self._defer_task(defer=defer) + self._defer_task(defer=defer, session=session) self.log.info( 'Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s', self.dag_id, @@ -1311,7 +1324,7 @@ def _run_raw_task( if not test_mode: session.add(Log(self.state, self)) session.merge(self) - session.commit() + session.commit() return except AirflowSmartSensorException as e: self.log.info(e) @@ -1321,29 +1334,31 @@ def _run_raw_task( # log only if exception has any arguments to prevent log flooding if e.args: self.log.info(e) - self.refresh_from_db(lock_for_update=True) + if not test_mode: + self.refresh_from_db(lock_for_update=True, session=session) self.state = State.SKIPPED except AirflowRescheduleException as reschedule_exception: - self.refresh_from_db() - self._handle_reschedule(actual_start_date, reschedule_exception, test_mode) + self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session) + session.commit() return except (AirflowFailException, AirflowSensorTimeout) as e: # If AirflowFailException is raised, task should not retry. # If a sensor in reschedule mode reaches timeout, task should not retry. 
- self.refresh_from_db() - self.handle_failure(e, test_mode, force_fail=True, error_file=error_file) + self.handle_failure(e, test_mode, force_fail=True, error_file=error_file, session=session) + session.commit() raise except AirflowException as e: - self.refresh_from_db() # for case when task is marked as success/failed externally # current behavior doesn't hit the success callback if self.state in {State.SUCCESS, State.FAILED}: return else: - self.handle_failure(e, test_mode, error_file=error_file) + self.handle_failure(e, test_mode, error_file=error_file, session=session) + session.commit() raise except (Exception, KeyboardInterrupt) as e: - self.handle_failure(e, test_mode, error_file=error_file) + self.handle_failure(e, test_mode, error_file=error_file, session=session) + session.commit() raise finally: Stats.incr(f'ti.finish.{task.dag_id}.{task.task_id}.{self.state}') @@ -1356,7 +1371,7 @@ def _run_raw_task( session.add(Log(self.state, self)) session.merge(self) - session.commit() + session.commit() def _prepare_and_execute_task_with_callbacks(self, context, task): """Prepare Task for Execution""" @@ -1613,6 +1628,7 @@ def _handle_reschedule(self, actual_start_date, reschedule_exception, test_mode= # Don't record reschedule request in test mode if test_mode: return + self.refresh_from_db(session) self.end_date = timezone.utcnow() self.set_duration() @@ -1621,7 +1637,7 @@ def _handle_reschedule(self, actual_start_date, reschedule_exception, test_mode= session.add( TaskReschedule( self.task, - self.execution_date, + self.run_id, self._try_number, actual_start_date, self.end_date, @@ -1662,6 +1678,8 @@ def handle_failure( # can send its runtime errors for access by failure callback if error_file: set_error_file(error_file, error) + if not test_mode: + self.refresh_from_db(session) task = self.task self.end_date = timezone.utcnow() @@ -1671,8 +1689,8 @@ def handle_failure( if not test_mode: session.add(Log(State.FAILED, self)) - # Log failure duration - session.add(TaskFail(task, self.execution_date, self.start_date, self.end_date)) + # Log failure duration + session.add(TaskFail(task, self.execution_date, self.start_date, self.end_date)) # Set state correctly and figure out how to log it and decide whether # to email @@ -1702,7 +1720,7 @@ def handle_failure( if not test_mode: session.merge(self) - session.commit() + session.flush() @provide_session def handle_failure_with_callback( @@ -1724,30 +1742,18 @@ def is_eligible_to_retry(self): return self.task.retries and self.try_number <= self.max_tries - @provide_session - def get_template_context(self, session=None) -> Context: + def get_template_context(self, session: Session = None) -> Context: """Return TI Context""" + # Do not use provide_session here -- it expunges everything on exit! + if not session: + session = settings.Session() task = self.task from airflow import macros integrate_macros_plugins() - dag_run = self.get_dagrun() - - # FIXME: Many tests don't create a DagRun. We should fix the tests. - if dag_run is None: - FakeDagRun = namedtuple( - "FakeDagRun", - # A minimal set of attributes to keep things working. 
- "conf data_interval_start data_interval_end external_trigger run_id", - ) - dag_run = FakeDagRun( - conf=None, - data_interval_start=None, - data_interval_end=None, - external_trigger=False, - run_id="", - ) + # Ensure that the dag_run is loaded -- otherwise `self.execution_date` may not work + dag_run = self.get_dagrun(session) params = {} # type: Dict[str, Any] with contextlib.suppress(AttributeError): @@ -1985,7 +1991,7 @@ def get_prev_ds_nodash() -> Optional[str]: replacement='prev_data_interval_start_success', ), 'prev_start_date_success': lazy_object_proxy.Proxy(get_prev_start_date_success), - 'run_id': dag_run.run_id, + 'run_id': self.run_id, 'task': task, 'task_instance': self, 'task_instance_key_str': f"{task.dag_id}__{task.task_id}__{ds_nodash}", @@ -2005,11 +2011,12 @@ def get_prev_ds_nodash() -> Optional[str]: 'yesterday_ds_nodash': deprecated_proxy(get_yesterday_ds_nodash, key='yesterday_ds_nodash'), } - def get_rendered_template_fields(self): + @provide_session + def get_rendered_template_fields(self, session=None): """Fetch rendered template fields from DB""" from airflow.models.renderedtifields import RenderedTaskInstanceFields - rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self) + rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session) if rendered_task_instance_fields: for field_name, rendered_value in rendered_task_instance_fields.items(): setattr(self.task, field_name, rendered_value) @@ -2184,10 +2191,11 @@ def xcom_push( :param session: Sqlalchemy ORM Session :type session: Session """ - if execution_date and execution_date < self.execution_date: + self_execution_date = self.get_dagrun(session).execution_date + if execution_date and execution_date < self_execution_date: raise ValueError( 'execution_date can not be in the past (current ' - 'execution_date is {}; received {})'.format(self.execution_date, execution_date) + 'execution_date is {}; received {})'.format(self_execution_date, execution_date) ) XCom.set( @@ -2195,7 +2203,7 @@ def xcom_push( value=value, task_id=self.task_id, dag_id=self.dag_id, - execution_date=execution_date or self.execution_date, + execution_date=execution_date or self_execution_date, session=session, ) @@ -2242,8 +2250,10 @@ def xcom_pull( if dag_id is None: dag_id = self.dag_id + execution_date = self.get_dagrun(session).execution_date + query = XCom.get_many( - execution_date=self.execution_date, + execution_date=execution_date, key=key, dag_ids=dag_id, task_ids=task_ids, @@ -2299,20 +2309,20 @@ def filter_for_tis(tis: Iterable[Union["TaskInstance", TaskInstanceKey]]) -> Opt first = tis[0] dag_id = first.dag_id - execution_date = first.execution_date + run_id = first.run_id first_task_id = first.task_id - # Common path optimisations: when all TIs are for the same dag_id and execution_date, or same dag_id + # Common path optimisations: when all TIs are for the same dag_id and run_id, or same dag_id # and task_id -- this can be over 150x for huge numbers of TIs (20k+) - if all(t.dag_id == dag_id and t.execution_date == execution_date for t in tis): + if all(t.dag_id == dag_id and t.run_id == run_id for t in tis): return and_( TaskInstance.dag_id == dag_id, - TaskInstance.execution_date == execution_date, + TaskInstance.run_id == run_id, TaskInstance.task_id.in_(t.task_id for t in tis), ) if all(t.dag_id == dag_id and t.task_id == first_task_id for t in tis): return and_( TaskInstance.dag_id == dag_id, - TaskInstance.execution_date.in_(t.execution_date for t 
in tis), + TaskInstance.run_id.in_(t.run_id for t in tis), TaskInstance.task_id == first_task_id, ) @@ -2321,12 +2331,12 @@ def filter_for_tis(tis: Iterable[Union["TaskInstance", TaskInstanceKey]]) -> Opt and_( TaskInstance.dag_id == ti.dag_id, TaskInstance.task_id == ti.task_id, - TaskInstance.execution_date == ti.execution_date, + TaskInstance.run_id == ti.run_id, ) for ti in tis ) else: - return tuple_(TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.execution_date).in_( + return tuple_(TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.run_id).in_( [ti.key.primary for ti in tis] ) @@ -2346,7 +2356,7 @@ class SimpleTaskInstance: def __init__(self, ti: TaskInstance): self._dag_id: str = ti.dag_id self._task_id: str = ti.task_id - self._execution_date: datetime = ti.execution_date + self._run_id: str = ti.run_id self._start_date: datetime = ti.start_date self._end_date: datetime = ti.end_date self._try_number: int = ti.try_number @@ -2371,8 +2381,8 @@ def task_id(self) -> str: return self._task_id @property - def execution_date(self) -> datetime: - return self._execution_date + def run_id(self) -> str: + return self._run_id @property def start_date(self) -> datetime: @@ -2410,36 +2420,10 @@ def key(self) -> TaskInstanceKey: def executor_config(self): return self._executor_config - @provide_session - def construct_task_instance(self, session=None, lock_for_update=False) -> TaskInstance: - """ - Construct a TaskInstance from the database based on the primary key - - :param session: DB session. - :param lock_for_update: if True, indicates that the database should - lock the TaskInstance (issuing a FOR UPDATE clause) until the - session is committed. - :return: the task instance constructed - """ - qry = session.query(TaskInstance).filter( - TaskInstance.dag_id == self._dag_id, - TaskInstance.task_id == self._task_id, - TaskInstance.execution_date == self._execution_date, - ) - - if lock_for_update: - ti = qry.with_for_update().first() - else: - ti = qry.first() - return ti - STATICA_HACK = True globals()['kcah_acitats'[::-1].upper()] = False if STATICA_HACK: # pragma: no cover + from airflow.jobs.base_job import BaseJob - from airflow.jobs.base_job import BaseJob - from airflow.models.dagrun import DagRun - - TaskInstance.dag_run = relationship(DagRun) TaskInstance.queued_by_job = relationship(BaseJob) diff --git a/airflow/models/taskreschedule.py b/airflow/models/taskreschedule.py index 293021cfca061..55ef754da27ff 100644 --- a/airflow/models/taskreschedule.py +++ b/airflow/models/taskreschedule.py @@ -17,6 +17,8 @@ # under the License.
"""TaskReschedule tracks rescheduled task instances.""" from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, asc, desc +from sqlalchemy.ext.associationproxy import association_proxy +from sqlalchemy.orm import relationship from airflow.models.base import COLLATION_ARGS, ID_LEN, Base from airflow.utils.session import provide_session @@ -31,7 +33,7 @@ class TaskReschedule(Base): id = Column(Integer, primary_key=True) task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) - execution_date = Column(UtcDateTime, nullable=False) + run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) try_number = Column(Integer, nullable=False) start_date = Column(UtcDateTime, nullable=False) end_date = Column(UtcDateTime, nullable=False) @@ -39,19 +41,27 @@ class TaskReschedule(Base): reschedule_date = Column(UtcDateTime, nullable=False) __table_args__ = ( - Index('idx_task_reschedule_dag_task_date', dag_id, task_id, execution_date, unique=False), + Index('idx_task_reschedule_dag_task_run', dag_id, task_id, run_id, unique=False), ForeignKeyConstraint( - [task_id, dag_id, execution_date], - ['task_instance.task_id', 'task_instance.dag_id', 'task_instance.execution_date'], - name='task_reschedule_dag_task_date_fkey', + [dag_id, task_id, run_id], + ['task_instance.dag_id', 'task_instance.task_id', 'task_instance.run_id'], + name='task_reschedule_ti_fkey', + ondelete='CASCADE', + ), + ForeignKeyConstraint( + [dag_id, run_id], + ['dag_run.dag_id', 'dag_run.run_id'], + name='task_reschedule_dr_fkey', ondelete='CASCADE', ), ) + dag_run = relationship("DagRun") + execution_date = association_proxy("dag_run", "execution_date") - def __init__(self, task, execution_date, try_number, start_date, end_date, reschedule_date): + def __init__(self, task, run_id, try_number, start_date, end_date, reschedule_date): self.dag_id = task.dag_id self.task_id = task.task_id - self.execution_date = execution_date + self.run_id = run_id self.try_number = try_number self.start_date = start_date self.end_date = end_date @@ -81,7 +91,7 @@ def query_for_task_instance(task_instance, descending=False, session=None, try_n qry = session.query(TR).filter( TR.dag_id == task_instance.dag_id, TR.task_id == task_instance.task_id, - TR.execution_date == task_instance.execution_date, + TR.run_id == task_instance.run_id, TR.try_number == try_number, ) if descending: diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index e89d6a861343e..05c13422f2838 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -34,6 +34,7 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator, BaseOperatorLink from airflow.models.taskinstance import TaskInstance +from airflow.models.xcom import XCom from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url @@ -60,8 +61,12 @@ class BigQueryConsoleLink(BaseOperatorLink): name = 'BigQuery Console' def get_link(self, operator, dttm): - ti = TaskInstance(task=operator, execution_date=dttm) - job_id = ti.xcom_pull(task_ids=operator.task_id, key='job_id') + job_id = XCom.get_one( + dag_id=operator.dag.dag_id, + task_id=operator.task_id, + 
execution_date=dttm, + key='job_id', + ) return BIGQUERY_JOB_DETAILS_LINK_FMT.format(job_id=job_id) if job_id else '' diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py index f0fd0e03896c0..0019c41046c42 100644 --- a/airflow/sensors/base.py +++ b/airflow/sensors/base.py @@ -125,7 +125,7 @@ def _validate_input_values(self) -> None: "The mode must be one of {valid_modes}," "'{d}.{t}'; received '{m}'.".format( valid_modes=self.valid_modes, - d=self.dag.dag_id if self.dag else "", + d=self.dag.dag_id if self.has_dag() else "", t=self.task_id, m=self.mode, ) diff --git a/airflow/sensors/smart_sensor.py b/airflow/sensors/smart_sensor.py index ec6acefe2c733..1e8c827e87d9c 100644 --- a/airflow/sensors/smart_sensor.py +++ b/airflow/sensors/smart_sensor.py @@ -15,8 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - - import datetime import json import logging @@ -27,7 +25,7 @@ from sqlalchemy import and_, or_, tuple_ from airflow.exceptions import AirflowException, AirflowTaskTimeout -from airflow.models import BaseOperator, SensorInstance, SkipMixin, TaskInstance +from airflow.models import BaseOperator, DagRun, SensorInstance, SkipMixin, TaskInstance from airflow.settings import LOGGING_CLASS_PATH from airflow.stats import Stats from airflow.utils import helpers, timezone @@ -390,6 +388,7 @@ def _update_ti_hostname(self, sensor_works, session=None): :param sensor_works: Smart sensor internal object for a sensor task. :param session: The sqlalchemy session. """ + DR = DagRun TI = TaskInstance def update_ti_hostname_with_count(count, sensor_works): @@ -399,18 +398,17 @@ def update_ti_hostname_with_count(count, sensor_works): and_( TI.dag_id == ti_key.dag_id, TI.task_id == ti_key.task_id, - TI.execution_date == ti_key.execution_date, + DR.execution_date == ti_key.execution_date, ) for ti_key in sensor_works ) else: ti_keys = [(x.dag_id, x.task_id, x.execution_date) for x in sensor_works] ti_filter = or_( - tuple_(TI.dag_id, TI.task_id, TI.execution_date) == ti_key for ti_key in ti_keys + tuple_(TI.dag_id, TI.task_id, DR.execution_date) == ti_key for ti_key in ti_keys ) - tis = session.query(TI).filter(ti_filter).all() - for ti in tis: + for ti in session.query(TI).join(TI.dag_run).filter(ti_filter): ti.hostname = self.hostname session.commit() diff --git a/airflow/sentry.py b/airflow/sentry.py index 51fe26f2c31dc..340b660934b45 100644 --- a/airflow/sentry.py +++ b/airflow/sentry.py @@ -130,13 +130,9 @@ def add_breadcrumbs(self, task_instance, session=None): """Function to add breadcrumbs inside of a task_instance.""" if session is None: return - execution_date = task_instance.execution_date - task = task_instance.task - dag = task.dag - task_instances = dag.get_task_instances( + dr = task_instance.get_dagrun(session) + task_instances = dr.get_task_instances( state={State.SUCCESS, State.FAILED}, - end_date=execution_date, - start_date=execution_date, session=session, ) diff --git a/airflow/ti_deps/dep_context.py b/airflow/ti_deps/dep_context.py index b8b1f3aaa5bf8..6c747c7656e2d 100644 --- a/airflow/ti_deps/dep_context.py +++ b/airflow/ti_deps/dep_context.py @@ -16,11 +16,16 @@ # specific language governing permissions and limitations # under the License. 
-import pendulum +from typing import TYPE_CHECKING, List + from sqlalchemy.orm.session import Session from airflow.utils.state import State +if TYPE_CHECKING: + from airflow.models.dagrun import DagRun + from airflow.models.taskinstance import TaskInstance + class DepContext: """ @@ -85,23 +90,16 @@ def __init__( self.ignore_ti_state = ignore_ti_state self.finished_tasks = finished_tasks - def ensure_finished_tasks(self, dag, execution_date: pendulum.DateTime, session: Session): + def ensure_finished_tasks(self, dag_run: "DagRun", session: Session) -> "List[TaskInstance]": """ This method makes sure finished_tasks is populated if it's currently None. This is for the strange feature of running tasks without dag_run. - :param dag: The DAG for which to find finished tasks - :type dag: airflow.models.DAG - :param execution_date: The execution_date to look for - :param session: Database session to use + :param dag_run: The DagRun for which to find finished tasks + :type dag_run: airflow.models.DagRun :return: A list of all the finished tasks of this DAG and execution_date :rtype: list[airflow.models.TaskInstance] """ if self.finished_tasks is None: - self.finished_tasks = dag.get_task_instances( - start_date=execution_date, - end_date=execution_date, - state=State.finished, - session=session, - ) + self.finished_tasks = dag_run.get_task_instances(state=State.finished, session=session) return self.finished_tasks diff --git a/airflow/ti_deps/deps/dagrun_exists_dep.py b/airflow/ti_deps/deps/dagrun_exists_dep.py index 6e00b830cb290..0aa21e8c8f295 100644 --- a/airflow/ti_deps/deps/dagrun_exists_dep.py +++ b/airflow/ti_deps/deps/dagrun_exists_dep.py @@ -29,27 +29,9 @@ class DagrunRunningDep(BaseTIDep): @provide_session def _get_dep_statuses(self, ti, session, dep_context): - dag = ti.task.dag - dagrun = ti.get_dagrun(session) - if not dagrun: - # The import is needed here to avoid a circular dependency - from airflow.models.dagrun import DagRun - - running_dagruns = DagRun.find( - dag_id=dag.dag_id, state=State.RUNNING, external_trigger=False, session=session + dr = ti.get_dagrun(session) + if dr.state != State.RUNNING: + yield self._failing_status( + reason="Task instance's dagrun was not in the 'running' state but in " + "the state '{}'.".format(dr.state) ) - - if len(running_dagruns) >= dag.max_active_runs: - reason = ( - "The maximum number of active dag runs ({}) for this task " - "instance's DAG '{}' has been reached.".format(dag.max_active_runs, ti.dag_id) - ) - else: - reason = "Unknown reason" - yield self._failing_status(reason=f"Task instance's dagrun did not exist: {reason}.") - else: - if dagrun.state != State.RUNNING: - yield self._failing_status( - reason="Task instance's dagrun was not in the 'running' state but in " - "the state '{}'.".format(dagrun.state) - ) diff --git a/airflow/ti_deps/deps/dagrun_id_dep.py b/airflow/ti_deps/deps/dagrun_id_dep.py index 186ab7cfd6787..a60951414da90 100644 --- a/airflow/ti_deps/deps/dagrun_id_dep.py +++ b/airflow/ti_deps/deps/dagrun_id_dep.py @@ -32,7 +32,7 @@ class DagrunIdDep(BaseTIDep): @provide_session def _get_dep_statuses(self, ti, session, dep_context=None): """ - Determines if the DagRun ID is valid for scheduling from scheduler. + Determines if the DagRun is valid for scheduling from scheduler. 
:param ti: the task instance to get the dependency status for :type ti: airflow.models.TaskInstance @@ -44,12 +44,7 @@ def _get_dep_statuses(self, ti, session, dep_context=None): """ dagrun = ti.get_dagrun(session) - if not dagrun or not dagrun.run_id or dagrun.run_type != DagRunType.BACKFILL_JOB: - yield self._passing_status( - reason=f"Task's DagRun doesn't exist or run_id is either NULL " - f"or run_type is not {DagRunType.BACKFILL_JOB}" - ) - else: + if dagrun.run_type == DagRunType.BACKFILL_JOB: yield self._failing_status( - reason=f"Task's DagRun run_id is not NULL " f"and run type is {DagRunType.BACKFILL_JOB}" + reason=f"Task's DagRun run_type is {dagrun.run_type} and cannot be run by the scheduler" ) diff --git a/airflow/ti_deps/deps/not_previously_skipped_dep.py b/airflow/ti_deps/deps/not_previously_skipped_dep.py index 3d1bde949ece7..e9df0edd222a6 100644 --- a/airflow/ti_deps/deps/not_previously_skipped_dep.py +++ b/airflow/ti_deps/deps/not_previously_skipped_dep.py @@ -39,7 +39,7 @@ def _get_dep_statuses(self, ti, session, dep_context): upstream = ti.task.get_direct_relatives(upstream=True) - finished_tasks = dep_context.ensure_finished_tasks(ti.task.dag, ti.execution_date, session) + finished_tasks = dep_context.ensure_finished_tasks(ti.get_dagrun(session), session) finished_task_ids = {t.task_id for t in finished_tasks} diff --git a/airflow/ti_deps/deps/runnable_exec_date_dep.py b/airflow/ti_deps/deps/runnable_exec_date_dep.py index 3986ef14b67ba..0607c11355f30 100644 --- a/airflow/ti_deps/deps/runnable_exec_date_dep.py +++ b/airflow/ti_deps/deps/runnable_exec_date_dep.py @@ -33,20 +33,21 @@ def _get_dep_statuses(self, ti, session, dep_context): # don't consider runs that are executed in the future unless # specified by config and schedule_interval is None - if ti.execution_date > cur_date and not ti.task.dag.allow_future_exec_dates: + logical_date = ti.get_dagrun(session).execution_date + if logical_date > cur_date and not ti.task.dag.allow_future_exec_dates: yield self._failing_status( reason="Execution date {} is in the future (the current " - "date is {}).".format(ti.execution_date.isoformat(), cur_date.isoformat()) + "date is {}).".format(logical_date.isoformat(), cur_date.isoformat()) ) - if ti.task.end_date and ti.execution_date > ti.task.end_date: + if ti.task.end_date and logical_date > ti.task.end_date: yield self._failing_status( reason="The execution date is {} but this is after the task's end date " - "{}.".format(ti.execution_date.isoformat(), ti.task.end_date.isoformat()) + "{}.".format(logical_date.isoformat(), ti.task.end_date.isoformat()) ) - if ti.task.dag and ti.task.dag.end_date and ti.execution_date > ti.task.dag.end_date: + if ti.task.dag and ti.task.dag.end_date and logical_date > ti.task.dag.end_date: yield self._failing_status( reason="The execution date is {} but this is after the task's DAG's " - "end date {}.".format(ti.execution_date.isoformat(), ti.task.dag.end_date.isoformat()) + "end date {}.".format(logical_date.isoformat(), ti.task.dag.end_date.isoformat()) ) diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py index f04cf3123691e..5d72410053c52 100644 --- a/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow/ti_deps/deps/trigger_rule_dep.py @@ -66,7 +66,7 @@ def _get_dep_statuses(self, ti, session, dep_context): return # see if the task name is in the task upstream for our task successes, skipped, failed, upstream_failed, done = self._get_states_count_upstream_ti( - ti=ti, 
finished_tasks=dep_context.ensure_finished_tasks(ti.task.dag, ti.execution_date, session) + ti=ti, finished_tasks=dep_context.ensure_finished_tasks(ti.get_dagrun(session), session) ) yield from self._evaluate_trigger_rule( diff --git a/airflow/utils/callback_requests.py b/airflow/utils/callback_requests.py index 89ffe52e3fdd5..8ed587be3772d 100644 --- a/airflow/utils/callback_requests.py +++ b/airflow/utils/callback_requests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from datetime import datetime from typing import Optional from airflow.models.taskinstance import SimpleTaskInstance @@ -71,7 +70,7 @@ class DagCallbackRequest(CallbackRequest): :param full_filepath: File Path to use to run the callback :param dag_id: DAG ID - :param execution_date: Execution Date for the DagRun + :param run_id: Run ID for the DagRun :param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback :param msg: Additional Message that can be used for logging """ @@ -80,13 +79,13 @@ def __init__( self, full_filepath: str, dag_id: str, - execution_date: datetime, + run_id: str, is_failure_callback: Optional[bool] = True, msg: Optional[str] = None, ): super().__init__(full_filepath=full_filepath, msg=msg) self.dag_id = dag_id - self.execution_date = execution_date + self.run_id = run_id self.is_failure_callback = is_failure_callback diff --git a/airflow/www/auth.py b/airflow/www/auth.py index 0b1d9ed166058..cd439283d6011 100644 --- a/airflow/www/auth.py +++ b/airflow/www/auth.py @@ -32,6 +32,8 @@ def has_access(permissions: Optional[Sequence[Tuple[str, str]]] = None) -> Calla def requires_access_decorator(func: T): @wraps(func) def decorated(*args, **kwargs): + __tracebackhide__ = True # Hide from pytest traceback. + appbuilder = current_app.appbuilder if not g.user.is_anonymous and not appbuilder.sm.current_user_has_permissions(): return ( diff --git a/airflow/www/decorators.py b/airflow/www/decorators.py index 8500f0d7e38c1..f6f2ed0b2a3d5 100644 --- a/airflow/www/decorators.py +++ b/airflow/www/decorators.py @@ -39,6 +39,7 @@ def action_logging(f: T) -> T: @functools.wraps(f) def wrapper(*args, **kwargs): + __tracebackhide__ = True # Hide from pytest traceback. 
with create_session() as session: if g.user.is_anonymous diff --git a/airflow/www/utils.py index db86783a29a78..cc1729529e7df 100644 --- a/airflow/www/utils.py +++ b/airflow/www/utils.py @@ -18,6 +18,7 @@ import json import textwrap import time +from typing import Any from urllib.parse import urlencode import markdown @@ -29,6 +30,7 @@ from flask_appbuilder.models.sqla.interface import SQLAInterface from pygments import highlight, lexers from pygments.formatters import HtmlFormatter +from sqlalchemy.ext.associationproxy import AssociationProxy from airflow.models import errors from airflow.utils import timezone @@ -225,8 +227,8 @@ def task_instance_link(attr): """Generates a URL to the Graph view for a TaskInstance.""" dag_id = attr.get('dag_id') task_id = attr.get('task_id') - execution_date = attr.get('execution_date') - url = url_for('Airflow.task', dag_id=dag_id, task_id=task_id, execution_date=execution_date.isoformat()) + execution_date = attr.get('dag_run.execution_date') or attr.get('execution_date') or timezone.utcnow() + url = url_for('Airflow.task', dag_id=dag_id, task_id=task_id) url_root = url_for( 'Airflow.graph', dag_id=dag_id, root=task_id, execution_date=execution_date.isoformat() ) @@ -311,7 +313,7 @@ def dag_run_link(attr): """Generates a URL to the Graph view for a DagRun.""" dag_id = attr.get('dag_id') run_id = attr.get('run_id') - execution_date = attr.get('execution_date') + execution_date = attr.get('dag_run.execution_date') or attr.get('execution_date') url = url_for('Airflow.graph', dag_id=dag_id, run_id=run_id, execution_date=execution_date) return Markup('<a href="{url}">{run_id}</a>').format(url=url, run_id=run_id) @@ -456,6 +458,13 @@ def clean_column_names(): self.list_columns = {k.lstrip('_'): v for k, v in self.list_columns.items()} clean_column_names() + # Support for AssociationProxy in search and list columns + for desc in self.obj.__mapper__.all_orm_descriptors: + if not isinstance(desc, AssociationProxy): + continue + proxy_instance = getattr(self.obj, desc.value_attr) + self.list_columns[desc.value_attr] = proxy_instance.remote_attr.prop.columns[0] + self.list_properties[desc.value_attr] = proxy_instance.remote_attr.prop def is_utcdatetime(self, col_name): """Check if the datetime is a UTC one.""" @@ -483,6 +492,12 @@ def is_extendedjson(self, col_name): ) return False + def get_col_default(self, col_name: str) -> Any: + if col_name not in self.list_columns: + # Handle AssociationProxy etc, or anything that isn't a "real" column + return None + return super().get_col_default(col_name) + filter_converter_class = AirflowFilterConverter diff --git a/airflow/www/views.py index 4790c303e9020..7b4d42a82cb9e 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -18,7 +18,6 @@ # import collections import copy -import itertools import json import logging import math @@ -807,7 +806,7 @@ def task_stats(self, session=None): filter_dag_ids = allowed_dag_ids running_dag_run_query_result = ( - session.query(DagRun.dag_id, DagRun.execution_date) + session.query(DagRun.dag_id, DagRun.run_id) .join(DagModel, DagModel.dag_id == DagRun.dag_id) .filter(DagRun.state == State.RUNNING, DagModel.is_active) ) @@ -823,7 +822,7 @@ def task_stats(self, session=None): running_dag_run_query_result, and_( running_dag_run_query_result.c.dag_id == TaskInstance.dag_id, - running_dag_run_query_result.c.execution_date == TaskInstance.execution_date, + running_dag_run_query_result.c.run_id == TaskInstance.run_id, ), ) @@ -841,14 +840,16 @@ def
task_stats(self, session=None): # Select all task_instances from active dag_runs. # If no dag_run is active, return task instances from most recent dag_run. - last_task_instance_query_result = session.query( - TaskInstance.dag_id.label('dag_id'), TaskInstance.state.label('state') - ).join( - last_dag_run, - and_( - last_dag_run.c.dag_id == TaskInstance.dag_id, - last_dag_run.c.execution_date == TaskInstance.execution_date, - ), + last_task_instance_query_result = ( + session.query(TaskInstance.dag_id.label('dag_id'), TaskInstance.state.label('state')) + .join(TaskInstance.dag_run) + .join( + last_dag_run, + and_( + last_dag_run.c.dag_id == TaskInstance.dag_id, + last_dag_run.c.execution_date == DagRun.execution_date, + ), + ) ) final_task_instance_query_result = union_all( @@ -1036,7 +1037,8 @@ def dag_details(self, session=None): ] ) @action_logging - def rendered_templates(self): + @provide_session + def rendered_templates(self, session): """Get rendered Dag.""" dag_id = request.args.get('dag_id') task_id = request.args.get('task_id') @@ -1047,11 +1049,13 @@ def rendered_templates(self): logging.info("Retrieving rendered templates.") dag = current_app.dag_bag.get_dag(dag_id) + dag_run = dag.get_dagrun(execution_date=dttm, session=session) task = copy.copy(dag.get_task(task_id)) - ti = models.TaskInstance(task=task, execution_date=dttm) + ti = dag_run.get_task_instance(task_id=task.task_id, session=session) + ti.refresh_from_task(task) try: - ti.get_rendered_template_fields() + ti.get_rendered_template_fields(session=session) except AirflowException as e: msg = "Error rendering template: " + escape(e) if e.__cause__: @@ -1125,7 +1129,8 @@ def rendered_k8s(self): logging.info("Retrieving rendered templates.") dag = current_app.dag_bag.get_dag(dag_id) task = dag.get_task(task_id) - ti = models.TaskInstance(task=task, execution_date=dttm) + dag_run = dag.get_dagrun(execution_date=dttm) + ti = dag_run.get_task_instance(task_id=task.task_id) pod_spec = None try: @@ -1342,7 +1347,8 @@ def redirect_to_external_log(self, session=None): ] ) @action_logging - def task(self): + @provide_session + def task(self, session): """Retrieve task.""" dag_id = request.args.get('dag_id') task_id = request.args.get('task_id') @@ -1357,30 +1363,52 @@ def task(self): return redirect(url_for('Airflow.index')) task = copy.copy(dag.get_task(task_id)) task.resolve_template_files() - ti = TaskInstance(task=task, execution_date=dttm) - ti.refresh_from_db() - - ti_attrs = [] - for attr_name in dir(ti): - if not attr_name.startswith('_'): - attr = getattr(ti, attr_name) - if type(attr) != type(self.task): # noqa - ti_attrs.append((attr_name, str(attr))) - task_attrs = [] - for attr_name in dir(task): - if not attr_name.startswith('_'): - attr = getattr(task, attr_name) + ti = ( + session.query(TaskInstance) + .options( + # HACK: Eager-load relationships. This is needed because + # multiple properties mis-use provide_session() that destroys + # the session object ti is bounded to. 
+ joinedload(TaskInstance.queued_by_job, innerjoin=False), + joinedload(TaskInstance.trigger, innerjoin=False), + ) + .join(TaskInstance.dag_run) + .filter( + DagRun.execution_date == dttm, + TaskInstance.dag_id == dag_id, + TaskInstance.task_id == task_id, + ) + .one() + ) + ti.refresh_from_task(task) - if type(attr) != type(self.task) and attr_name not in wwwutils.get_attr_renderer(): # noqa - task_attrs.append((attr_name, str(attr))) + ti_attrs = [ + (attr_name, attr) + for attr_name, attr in ( + (attr_name, getattr(ti, attr_name)) for attr_name in dir(ti) if not attr_name.startswith("_") + ) + if not callable(attr) + ] + ti_attrs = sorted(ti_attrs) + + attr_renderers = wwwutils.get_attr_renderer() + task_attrs = [ + (attr_name, attr) + for attr_name, attr in ( + (attr_name, getattr(task, attr_name)) + for attr_name in dir(task) + if not attr_name.startswith("_") and attr_name not in attr_renderers + ) + if not callable(attr) + ] # Color coding the special attributes that are code - special_attrs_rendered = {} - for attr_name in wwwutils.get_attr_renderer(): - if getattr(task, attr_name, None) is not None: - source = getattr(task, attr_name) - special_attrs_rendered[attr_name] = wwwutils.get_attr_renderer()[attr_name](source) + special_attrs_rendered = { + attr_name: renderer(getattr(task, attr_name)) + for attr_name, renderer in attr_renderers.items() + if hasattr(task, attr_name) + } no_failed_deps_result = [ ( @@ -1514,8 +1542,9 @@ def run(self): flash("Only works with the Celery or Kubernetes executors, sorry", "error") return redirect(origin) - ti = models.TaskInstance(task=task, execution_date=execution_date) - ti.refresh_from_db() + dag_run = dag.get_dagrun(execution_date=execution_date) + ti = dag_run.get_task_instance(task_id=task.task_id) + ti.refresh_from_task(task) # Make sure the task instance can be run dep_context = DepContext( @@ -2089,14 +2118,20 @@ def success(self): State.SUCCESS, ) - def _get_tree_data(self, dag_runs: Iterable[DagRun], dag: DAG, base_date: DateTime): + def _get_tree_data( + self, + dag_runs: Iterable[DagRun], + dag: DAG, + base_date: DateTime, + session: settings.Session, + ): """Returns formatted dag_runs for Tree view""" dates = sorted(dag_runs.keys()) min_date = min(dag_runs, default=None) task_instances = { (ti.task_id, ti.execution_date): ti - for ti in dag.get_task_instances(start_date=min_date, end_date=base_date) + for ti in dag.get_task_instances(start_date=min_date, end_date=base_date, session=session) } expanded = set() @@ -2239,7 +2274,7 @@ def tree(self, session=None): else: external_log_name = None - data = self._get_tree_data(dag_runs, dag, base_date) + data = self._get_tree_data(dag_runs, dag, base_date, session=session) # avoid spaces to reduce payload size data = htmlsafe_json_dumps(data, separators=(',', ':')) @@ -2766,23 +2801,22 @@ def gantt(self, session=None): form = DateTimeWithNumRunsWithDagRunsForm(data=dt_nr_dr_data) form.execution_date.choices = dt_nr_dr_data['dr_choices'] - tis = [ti for ti in dag.get_task_instances(dttm, dttm) if ti.start_date and ti.state] - tis = sorted(tis, key=lambda ti: ti.start_date) - ti_fails = list( - itertools.chain( - *( - ( - session.query(TaskFail) - .filter( - TaskFail.dag_id == ti.dag_id, - TaskFail.task_id == ti.task_id, - TaskFail.execution_date == ti.execution_date, - ) - .all() - ) - for ti in tis - ) + tis = ( + session.query(TaskInstance) + .join(TaskInstance.dag_run) + .filter( + DagRun.execution_date == dttm, + TaskInstance.dag_id == dag_id, + TaskInstance.start_date.isnot(None), + 
TaskInstance.state.isnot(None), ) + .order_by(TaskInstance.start_date) + ) + + ti_fails = ( + session.query(TaskFail) + .join(DagRun, DagRun.execution_date == TaskFail.execution_date) + .filter(DagRun.execution_date == dttm, TaskFail.dag_id == dag_id) ) tasks = [] @@ -2818,10 +2852,11 @@ def gantt(self, session=None): task_dict['extraLinks'] = task.extra_links tasks.append(task_dict) + task_names = [ti.task_id for ti in tis] data = { - 'taskNames': [ti.task_id for ti in tis], + 'taskNames': task_names, 'tasks': tasks, - 'height': len(tis) * 25 + 25, + 'height': len(task_names) * 25 + 25, } session.commit() @@ -2959,9 +2994,8 @@ def tree_data(self): .limit(num_runs) .all() ) - dag_runs = {dr.execution_date: alchemy_to_dict(dr) for dr in dag_runs} - - tree_data = self._get_tree_data(dag_runs, dag, base_date) + dag_runs = {dr.execution_date: alchemy_to_dict(dr) for dr in dag_runs} + tree_data = self._get_tree_data(dag_runs, dag, base_date, session=session) # avoid spaces to reduce payload size return htmlsafe_json_dumps(tree_data, separators=(',', ':')) @@ -3949,8 +3983,9 @@ class TaskRescheduleModelView(AirflowModelView): list_columns = [ 'id', 'dag_id', + 'run_id', + 'dag_run.execution_date', 'task_id', - 'execution_date', 'try_number', 'start_date', 'end_date', @@ -3958,7 +3993,19 @@ class TaskRescheduleModelView(AirflowModelView): 'reschedule_date', ] - search_columns = ['dag_id', 'task_id', 'execution_date', 'start_date', 'end_date', 'reschedule_date'] + label_columns = { + 'dag_run.execution_date': 'Execution Date', + } + + search_columns = [ + 'dag_id', + 'task_id', + 'run_id', + 'execution_date', + 'start_date', + 'end_date', + 'reschedule_date', + ] base_order = ('id', 'desc') @@ -3977,7 +4024,7 @@ def duration_f(self): 'task_id': wwwutils.task_instance_link, 'start_date': wwwutils.datetime_f('start_date'), 'end_date': wwwutils.datetime_f('end_date'), - 'execution_date': wwwutils.datetime_f('execution_date'), + 'dag_run.execution_date': wwwutils.datetime_f('dag_run.execution_date'), 'reschedule_date': wwwutils.datetime_f('reschedule_date'), 'duration': duration_f, } @@ -4013,7 +4060,8 @@ class TaskInstanceModelView(AirflowModelView): 'state', 'dag_id', 'task_id', - 'execution_date', + 'run_id', + 'dag_run.execution_date', 'operator', 'start_date', 'end_date', @@ -4035,10 +4083,15 @@ class TaskInstanceModelView(AirflowModelView): item for item in list_columns if item not in ['try_number', 'log_url', 'external_executor_id'] ] + label_columns = { + 'dag_run.execution_date': 'Execution Date', + } + search_columns = [ 'state', 'dag_id', 'task_id', + 'run_id', 'execution_date', 'hostname', 'queue', @@ -4051,9 +4104,6 @@ class TaskInstanceModelView(AirflowModelView): edit_columns = [ 'state', - 'dag_id', - 'task_id', - 'execution_date', 'start_date', 'end_date', ] @@ -4084,9 +4134,10 @@ def duration_f(self): formatters_columns = { 'log_url': log_url_formatter, 'task_id': wwwutils.task_instance_link, + 'run_id': wwwutils.dag_run_link, 'hostname': wwwutils.nobr_f('hostname'), 'state': wwwutils.state_f, - 'execution_date': wwwutils.datetime_f('execution_date'), + 'dag_run.execution_date': wwwutils.datetime_f('dag_run.execution_date'), 'start_date': wwwutils.datetime_f('start_date'), 'end_date': wwwutils.datetime_f('end_date'), 'queued_dttm': wwwutils.datetime_f('queued_dttm'), diff --git a/docs/apache-airflow/concepts/scheduler.rst b/docs/apache-airflow/concepts/scheduler.rst index 0a1079e4a290c..3a54b0269746c 100644 --- a/docs/apache-airflow/concepts/scheduler.rst +++ 
b/docs/apache-airflow/concepts/scheduler.rst @@ -174,17 +174,6 @@ The following config settings can be used to control aspects of the Scheduler HA this, so this should be set to match the same period as your statsd roll-up period. -- :ref:`config:scheduler__clean_tis_without_dagrun_interval` - - How often should each scheduler run a check to "clean up" TaskInstance rows - that are found to no longer have a matching DagRun row. - - In normal operation the scheduler won't do this, it is only possible to do - this by deleting rows via the UI, or directly in the DB. You can set this - lower if this check is not important to you -- tasks will be left in what - ever state they are until the cleanup happens, at which point they will be - set to failed. - - :ref:`config:scheduler__orphaned_tasks_check_interval` How often (in seconds) should the scheduler check for orphaned tasks or dead diff --git a/docs/apache-airflow/logging-monitoring/metrics.rst b/docs/apache-airflow/logging-monitoring/metrics.rst index 8e5d46b4b0a64..c1bc249494594 100644 --- a/docs/apache-airflow/logging-monitoring/metrics.rst +++ b/docs/apache-airflow/logging-monitoring/metrics.rst @@ -123,7 +123,6 @@ Name Description ``dag_processing.total_parse_time`` Seconds taken to scan and import all DAG files once ``dag_processing.last_run.seconds_ago.`` Seconds since ```` was last processed ``dag_processing.processor_timeouts`` Number of file processors that have been killed due to taking too long -``scheduler.tasks.without_dagrun`` Number of tasks without DagRuns or with DagRuns not in Running state ``scheduler.tasks.running`` Number of tasks running in executor ``scheduler.tasks.starving`` Number of tasks that cannot be scheduled because of no open slot in pool ``scheduler.tasks.executable`` Number of tasks that are ready for execution (set to queued) diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst index 1f387b6244ce7..052d61a08f590 100644 --- a/docs/apache-airflow/migrations-ref.rst +++ b/docs/apache-airflow/migrations-ref.rst @@ -23,7 +23,9 @@ Here's the list of all the Database Migrations that are executed via when you ru +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+ -| ``142555e44c17`` (head) | ``54bebd308c5f`` | | Add ``data_interval_start`` and ``data_interval_end`` to ``DagRun`` | +| ``7b2661a43ba3`` (head) | ``142555e44c17`` | | Change TaskInstance and TaskReschedule tables from execution_date to run_id. 
| ++--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+ +| ``142555e44c17`` | ``54bebd308c5f`` | | Add ``data_interval_start`` and ``data_interval_end`` to ``DagRun`` | +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+ | ``54bebd308c5f`` | ``30867afad44a`` | | Adds ``trigger`` table and deferrable operator columns to task instance | +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+ diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py index 604e5f2da33e2..52d7f3ab7bec7 100644 --- a/kubernetes_tests/test_kubernetes_pod_operator.py +++ b/kubernetes_tests/test_kubernetes_pod_operator.py @@ -34,7 +34,7 @@ from airflow.exceptions import AirflowException from airflow.kubernetes import kube_client from airflow.kubernetes.secret import Secret -from airflow.models import DAG, TaskInstance +from airflow.models import DAG, DagRun, TaskInstance from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.utils.pod_launcher import PodLauncher from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults @@ -47,7 +47,9 @@ def create_context(task): dag = DAG(dag_id="dag") tzinfo = pendulum.timezone("Europe/Amsterdam") execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo) - task_instance = TaskInstance(task=task, execution_date=execution_date) + dag_run = DagRun(dag_id=dag.dag_id, execution_date=execution_date) + task_instance = TaskInstance(task=task) + task_instance.dag_run = dag_run task_instance.xcom_push = mock.Mock() return { "dag": dag, diff --git a/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py b/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py index 0b0eabebf8f3a..683400886d5fa 100644 --- a/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py +++ b/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py @@ -34,7 +34,7 @@ from airflow.kubernetes.secret import Secret from airflow.kubernetes.volume import Volume from airflow.kubernetes.volume_mount import VolumeMount -from airflow.models import DAG, TaskInstance +from airflow.models import DAG, DagRun, TaskInstance from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.utils.pod_launcher import PodLauncher from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults @@ -50,7 +50,9 @@ def create_context(task): dag = DAG(dag_id="dag") tzinfo = pendulum.timezone("Europe/Amsterdam") execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo) - task_instance = TaskInstance(task=task, execution_date=execution_date) + dag_run = DagRun(dag_id=dag.dag_id, execution_date=execution_date) + task_instance = TaskInstance(task=task) + task_instance.dag_run = dag_run task_instance.xcom_push = mock.Mock() return { "dag": dag, diff --git a/tests/api/common/experimental/test_delete_dag.py b/tests/api/common/experimental/test_delete_dag.py index 58bcd37d9299d..5984cd2b14f0f 100644 --- a/tests/api/common/experimental/test_delete_dag.py +++ b/tests/api/common/experimental/test_delete_dag.py @@ -16,15 +16,13 @@ # specific language governing permissions and 
limitations # under the License. -import unittest import pytest -from airflow import models, settings +from airflow import models from airflow.api.common.experimental.delete_dag import delete_dag from airflow.exceptions import AirflowException, DagNotFound from airflow.operators.dummy import DummyOperator -from airflow.utils import timezone from airflow.utils.dates import days_ago from airflow.utils.session import create_session from airflow.utils.state import State @@ -40,15 +38,7 @@ IE = models.ImportError -class TestDeleteDAGCatchError(unittest.TestCase): - def setUp(self): - self.dagbag = models.DagBag(include_examples=True) - self.dag_id = 'example_bash_operator' - self.dag = self.dagbag.dags[self.dag_id] - - def tearDown(self): - self.dag.clear() - +class TestDeleteDAGCatchError: def test_delete_dag_non_existent_dag(self): with pytest.raises(DagNotFound): delete_dag("non-existent DAG") @@ -63,21 +53,17 @@ def teardown_method(self): clear_db_dags() clear_db_runs() - def test_delete_dag_running_taskinstances(self, create_dummy_dag): + def test_delete_dag_running_taskinstances(self, session, create_task_instance): dag_id = 'test-dag' - _, task = create_dummy_dag(dag_id) + ti = create_task_instance(dag_id=dag_id, session=session) - ti = TI(task, execution_date=timezone.utcnow()) - ti.refresh_from_db() - session = settings.Session() ti.state = State.RUNNING - session.merge(ti) session.commit() with pytest.raises(AirflowException): delete_dag(dag_id) -class TestDeleteDAGSuccessfulDelete(unittest.TestCase): +class TestDeleteDAGSuccessfulDelete: dag_file_path = "/usr/local/airflow/dags/test_dag_8.py" key = "test_dag_id" @@ -94,8 +80,10 @@ def setup_dag_models(self, for_sub_dag=False): test_date = days_ago(1) with create_session() as session: session.add(DM(dag_id=self.key, fileloc=self.dag_file_path, is_subdag=for_sub_dag)) - session.add(DR(dag_id=self.key, run_type=DagRunType.MANUAL)) - session.add(TI(task=task, execution_date=test_date, state=State.SUCCESS)) + dr = DR(dag_id=self.key, run_type=DagRunType.MANUAL, run_id="test", execution_date=test_date) + ti = TI(task=task, state=State.SUCCESS) + ti.dag_run = dr + session.add_all((dr, ti)) # flush to ensure task instance if written before # task reschedule because of FK constraint session.flush() @@ -111,8 +99,8 @@ def setup_dag_models(self, for_sub_dag=False): session.add(TF(task=task, execution_date=test_date, start_date=test_date, end_date=test_date)) session.add( TR( - task=task, - execution_date=test_date, + task=ti.task, + run_id=ti.run_id, start_date=test_date, end_date=test_date, try_number=1, @@ -127,7 +115,7 @@ def setup_dag_models(self, for_sub_dag=False): ) ) - def tearDown(self): + def teardown_method(self): with create_session() as session: session.query(TR).filter(TR.dag_id == self.key).delete() session.query(TF).filter(TF.dag_id == self.key).delete() diff --git a/tests/api/common/experimental/test_mark_tasks.py b/tests/api/common/experimental/test_mark_tasks.py index 49008d386ba20..e43ac4a3302bf 100644 --- a/tests/api/common/experimental/test_mark_tasks.py +++ b/tests/api/common/experimental/test_mark_tasks.py @@ -20,6 +20,7 @@ from datetime import timedelta import pytest +from sqlalchemy.orm import eagerload from airflow import models from airflow.api.common.experimental.mark_tasks import ( @@ -40,11 +41,20 @@ DEV_NULL = "/dev/null" -class TestMarkTasks(unittest.TestCase): +@pytest.fixture(scope="module") +def dagbag(): + from airflow.models.dagbag import DagBag + + # Ensure the DAGs we are looking at from the DB are 
up-to-date + non_serialized_dagbag = DagBag(read_dags_from_db=False, include_examples=False) + non_serialized_dagbag.sync_to_db() + return DagBag(read_dags_from_db=True) + + +class TestMarkTasks: + @pytest.fixture(scope="class", autouse=True, name="create_dags") @classmethod - def setUpClass(cls): - models.DagBag(include_examples=True, read_dags_from_db=False).sync_to_db() - dagbag = models.DagBag(include_examples=False, read_dags_from_db=True) + def create_dags(cls, dagbag): cls.dag1 = dagbag.get_dag('miscellaneous_test_dag') cls.dag2 = dagbag.get_dag('example_subdag_operator') cls.dag3 = dagbag.get_dag('example_trigger_target_dag') @@ -56,7 +66,9 @@ def setUpClass(cls): start_date3 + timedelta(days=2), ] - def setUp(self): + @pytest.fixture(autouse=True) + def setup(self): + clear_db_runs() drs = _create_dagruns( self.dag1, self.execution_dates, state=State.RUNNING, run_type=DagRunType.SCHEDULED @@ -77,24 +89,35 @@ def setUp(self): for dr in drs: dr.dag = self.dag3 - def tearDown(self): + yield + clear_db_runs() @staticmethod def snapshot_state(dag, execution_dates): TI = models.TaskInstance + DR = models.DagRun with create_session() as session: return ( session.query(TI) - .filter(TI.dag_id == dag.dag_id, TI.execution_date.in_(execution_dates)) + .join(TI.dag_run) + .options(eagerload(TI.dag_run)) + .filter(TI.dag_id == dag.dag_id, DR.execution_date.in_(execution_dates)) .all() ) @provide_session def verify_state(self, dag, task_ids, execution_dates, state, old_tis, session=None): TI = models.TaskInstance - - tis = session.query(TI).filter(TI.dag_id == dag.dag_id, TI.execution_date.in_(execution_dates)).all() + DR = models.DagRun + + tis = ( + session.query(TI) + .join(TI.dag_run) + .options(eagerload(TI.dag_run)) + .filter(TI.dag_id == dag.dag_id, DR.execution_date.in_(execution_dates)) + .all() + ) assert len(tis) > 0 diff --git a/tests/api_connexion/conftest.py b/tests/api_connexion/conftest.py index ead538bd80e6a..cc92733642d46 100644 --- a/tests/api_connexion/conftest.py +++ b/tests/api_connexion/conftest.py @@ -40,3 +40,11 @@ def session(): with create_session() as session: yield session + + +@pytest.fixture(scope="session") +def dagbag(): + from airflow.models import DagBag + + DagBag(include_examples=True, read_dags_from_db=False).sync_to_db() + return DagBag(include_examples=True, read_dags_from_db=True) diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index 9c5f593a09b30..17a2553ed82e9 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -122,6 +122,7 @@ def _create_dag_models(self, count, session=None): fileloc=f"/tmp/dag_{num}.py", schedule_interval="2 2 * * *", is_active=True, + is_paused=False, ) session.add(dag_model) @@ -162,6 +163,7 @@ def test_should_respond_200_with_schedule_interval_none(self, session): dag_id="TEST_DAG_1", fileloc="/tmp/dag_1.py", schedule_interval=None, + is_paused=False, ) session.add(dag_model) session.commit() diff --git a/tests/api_connexion/endpoints/test_event_log_endpoint.py b/tests/api_connexion/endpoints/test_event_log_endpoint.py index f1025e64cfaa5..65daae85f84d7 100644 --- a/tests/api_connexion/endpoints/test_event_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_event_log_endpoint.py @@ -16,15 +16,11 @@ # under the License. 
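# Illustrative sketch, not part of the patch above: the query pattern the updated code and
# tests rely on. With execution_date moved off TaskInstance, instances for a logical date are
# selected by joining through DagRun; contains_eager marks the joined DagRun as already loaded.
# The helper name and its arguments are assumptions for illustration, not Airflow API.
from sqlalchemy.orm import contains_eager

from airflow.models import DagRun, TaskInstance


def tis_for_execution_date(session, dag_id, execution_date):
    return (
        session.query(TaskInstance)
        .join(TaskInstance.dag_run)
        .options(contains_eager(TaskInstance.dag_run))
        .filter(TaskInstance.dag_id == dag_id, DagRun.execution_date == execution_date)
        .all()
    )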
import pytest -from parameterized import parameterized -from airflow import DAG from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP -from airflow.models import Log, TaskInstance -from airflow.operators.dummy import DummyOperator +from airflow.models import Log from airflow.security import permissions from airflow.utils import timezone -from airflow.utils.session import provide_session from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_logs @@ -47,43 +43,56 @@ def configured_app(minimal_app_for_api): delete_user(app, username="test_no_permissions") # type: ignore +@pytest.fixture +def task_instance(session, create_task_instance, request): + return create_task_instance( + session=session, + dag_id="TEST_DAG_ID", + task_id="TEST_TASK_ID", + execution_date=request.instance.default_time, + ) + + +@pytest.fixture() +def log_model(create_log_model, request): + return create_log_model( + event="TEST_EVENT", + when=request.instance.default_time, + ) + + +@pytest.fixture +def create_log_model(create_task_instance, task_instance, session, request): + def maker(event, when, **kwargs): + log_model = Log( + event=event, + task_instance=task_instance, + **kwargs, + ) + log_model.dttm = when + + session.add(log_model) + session.flush() + return log_model + + return maker + + class TestEventLogEndpoint: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: self.app = configured_app self.client = self.app.test_client() # type:ignore clear_db_logs() - self.default_time = "2020-06-10T20:00:00+00:00" - self.default_time_2 = '2020-06-11T07:00:00+00:00' + self.default_time = timezone.parse("2020-06-10T20:00:00+00:00") + self.default_time_2 = timezone.parse('2020-06-11T07:00:00+00:00') def teardown_method(self) -> None: clear_db_logs() - def _create_task_instance(self): - dag = DAG( - 'TEST_DAG_ID', - start_date=timezone.parse(self.default_time), - end_date=timezone.parse(self.default_time), - ) - op1 = DummyOperator( - task_id="TEST_TASK_ID", - owner="airflow", - ) - dag.add_task(op1) - ti = TaskInstance(task=op1, execution_date=timezone.parse(self.default_time)) - return ti - class TestGetEventLog(TestEventLogEndpoint): - @provide_session - def test_should_respond_200(self, session): - log_model = Log( - event='TEST_EVENT', - task_instance=self._create_task_instance(), - ) - log_model.dttm = timezone.parse(self.default_time) - session.add(log_model) - session.commit() + def test_should_respond_200(self, log_model): event_log_id = log_model.id response = self.client.get( f"/api/v1/eventLogs/{event_log_id}", environ_overrides={'REMOTE_USER': "test"} @@ -94,9 +103,9 @@ def test_should_respond_200(self, session): "event": "TEST_EVENT", "dag_id": "TEST_DAG_ID", "task_id": "TEST_TASK_ID", - "execution_date": self.default_time, + "execution_date": self.default_time.isoformat(), "owner": 'airflow', - "when": self.default_time, + "when": self.default_time.isoformat(), "extra": None, } @@ -110,15 +119,7 @@ def test_should_respond_404(self): 'type': EXCEPTIONS_LINK_MAP[404], } == response.json - @provide_session - def test_should_raises_401_unauthenticated(self, session): - log_model = Log( - event='TEST_EVENT', - task_instance=self._create_task_instance(), - ) - log_model.dttm = timezone.parse(self.default_time) - session.add(log_model) - session.commit() + def test_should_raises_401_unauthenticated(self, log_model): event_log_id = log_model.id response = 
self.client.get(f"/api/v1/eventLogs/{event_log_id}") @@ -133,21 +134,14 @@ def test_should_raise_403_forbidden(self): class TestGetEventLogs(TestEventLogEndpoint): - def test_should_respond_200(self, session): - log_model_1 = Log( - event='TEST_EVENT_1', - task_instance=self._create_task_instance(), - ) - log_model_2 = Log( - event='TEST_EVENT_2', - task_instance=self._create_task_instance(), - ) + def test_should_respond_200(self, session, create_log_model): + log_model_1 = create_log_model(event='TEST_EVENT_1', when=self.default_time) + log_model_2 = create_log_model(event='TEST_EVENT_2', when=self.default_time_2) log_model_3 = Log(event="cli_scheduler", owner='root', extra='{"host_name": "e24b454f002a"}') - log_model_1.dttm = timezone.parse(self.default_time) - log_model_2.dttm = timezone.parse(self.default_time_2) - log_model_3.dttm = timezone.parse(self.default_time_2) - session.add_all([log_model_1, log_model_2, log_model_3]) - session.commit() + log_model_3.dttm = self.default_time_2 + + session.add(log_model_3) + session.flush() response = self.client.get("/api/v1/eventLogs", environ_overrides={'REMOTE_USER': "test"}) assert response.status_code == 200 assert response.json == { @@ -157,9 +151,9 @@ def test_should_respond_200(self, session): "event": "TEST_EVENT_1", "dag_id": "TEST_DAG_ID", "task_id": "TEST_TASK_ID", - "execution_date": self.default_time, + "execution_date": self.default_time.isoformat(), "owner": 'airflow', - "when": self.default_time, + "when": self.default_time.isoformat(), "extra": None, }, { @@ -167,9 +161,9 @@ def test_should_respond_200(self, session): "event": "TEST_EVENT_2", "dag_id": "TEST_DAG_ID", "task_id": "TEST_TASK_ID", - "execution_date": self.default_time, + "execution_date": self.default_time.isoformat(), "owner": 'airflow', - "when": self.default_time_2, + "when": self.default_time_2.isoformat(), "extra": None, }, { @@ -179,25 +173,20 @@ def test_should_respond_200(self, session): "task_id": None, "execution_date": None, "owner": 'root', - "when": self.default_time_2, + "when": self.default_time_2.isoformat(), "extra": '{"host_name": "e24b454f002a"}', }, ], "total_entries": 3, } - def test_order_eventlogs_by_owner(self, session): - log_model_1 = Log( - event='TEST_EVENT_1', - task_instance=self._create_task_instance(), - ) - log_model_2 = Log(event='TEST_EVENT_2', task_instance=self._create_task_instance(), owner="zsh") + def test_order_eventlogs_by_owner(self, create_log_model, session): + log_model_1 = create_log_model(event="TEST_EVENT_1", when=self.default_time) + log_model_2 = create_log_model(event="TEST_EVENT_2", when=self.default_time_2, owner='zsh') log_model_3 = Log(event="cli_scheduler", owner='root', extra='{"host_name": "e24b454f002a"}') - log_model_1.dttm = timezone.parse(self.default_time) - log_model_2.dttm = timezone.parse(self.default_time_2) - log_model_3.dttm = timezone.parse(self.default_time_2) - session.add_all([log_model_1, log_model_2, log_model_3]) - session.commit() + log_model_3.dttm = self.default_time_2 + session.add(log_model_3) + session.flush() response = self.client.get( "/api/v1/eventLogs?order_by=-owner", environ_overrides={'REMOTE_USER': "test"} ) @@ -209,9 +198,9 @@ def test_order_eventlogs_by_owner(self, session): "event": "TEST_EVENT_2", "dag_id": "TEST_DAG_ID", "task_id": "TEST_TASK_ID", - "execution_date": self.default_time, + "execution_date": self.default_time.isoformat(), "owner": 'zsh', # Order by name, sort order is descending(-) - "when": self.default_time_2, + "when": 
self.default_time_2.isoformat(), "extra": None, }, { @@ -221,7 +210,7 @@ def test_order_eventlogs_by_owner(self, session): "task_id": None, "execution_date": None, "owner": 'root', - "when": self.default_time_2, + "when": self.default_time_2.isoformat(), "extra": '{"host_name": "e24b454f002a"}', }, { @@ -229,37 +218,24 @@ def test_order_eventlogs_by_owner(self, session): "event": "TEST_EVENT_1", "dag_id": "TEST_DAG_ID", "task_id": "TEST_TASK_ID", - "execution_date": self.default_time, + "execution_date": self.default_time.isoformat(), "owner": 'airflow', - "when": self.default_time, + "when": self.default_time.isoformat(), "extra": None, }, ], "total_entries": 3, } - @provide_session - def test_should_raises_401_unauthenticated(self, session): - log_model_1 = Log( - event='TEST_EVENT_1', - task_instance=self._create_task_instance(), - ) - log_model_2 = Log( - event='TEST_EVENT_2', - task_instance=self._create_task_instance(), - ) - log_model_1.dttm = timezone.parse(self.default_time) - log_model_2.dttm = timezone.parse(self.default_time_2) - session.add_all([log_model_1, log_model_2]) - session.commit() - + def test_should_raises_401_unauthenticated(self, log_model): response = self.client.get("/api/v1/eventLogs") assert_401(response) class TestGetEventLogPagination(TestEventLogEndpoint): - @parameterized.expand( + @pytest.mark.parametrize( + ("url", "expected_events"), [ ("api/v1/eventLogs?limit=1", ["TEST_EVENT_1"]), ("api/v1/eventLogs?limit=2", ["TEST_EVENT_1", "TEST_EVENT_2"]), @@ -294,11 +270,10 @@ class TestGetEventLogPagination(TestEventLogEndpoint): "api/v1/eventLogs?limit=2&offset=2", ["TEST_EVENT_3", "TEST_EVENT_4"], ), - ] + ], ) - @provide_session - def test_handle_limit_and_offset(self, url, expected_events, session): - log_models = self._create_event_logs(10) + def test_handle_limit_and_offset(self, url, expected_events, task_instance, session): + log_models = self._create_event_logs(task_instance, 10) session.add_all(log_models) session.commit() @@ -309,11 +284,10 @@ def test_handle_limit_and_offset(self, url, expected_events, session): events = [event_log["event"] for event_log in response.json["event_logs"]] assert events == expected_events - @provide_session - def test_should_respect_page_size_limit_default(self, session): - log_models = self._create_event_logs(200) + def test_should_respect_page_size_limit_default(self, task_instance, session): + log_models = self._create_event_logs(task_instance, 200) session.add_all(log_models) - session.commit() + session.flush() response = self.client.get("/api/v1/eventLogs", environ_overrides={'REMOTE_USER': "test"}) assert response.status_code == 200 @@ -321,10 +295,10 @@ def test_should_respect_page_size_limit_default(self, session): assert response.json["total_entries"] == 200 assert len(response.json["event_logs"]) == 100 # default 100 - def test_should_raise_400_for_invalid_order_by_name(self, session): - log_models = self._create_event_logs(200) + def test_should_raise_400_for_invalid_order_by_name(self, task_instance, session): + log_models = self._create_event_logs(task_instance, 200) session.add_all(log_models) - session.commit() + session.flush() response = self.client.get( "/api/v1/eventLogs?order_by=invalid", environ_overrides={'REMOTE_USER': "test"} @@ -333,19 +307,15 @@ def test_should_raise_400_for_invalid_order_by_name(self, session): msg = "Ordering with 'invalid' is disallowed or the attribute does not exist on the model" assert response.json['detail'] == msg - @provide_session @conf_vars({("api", 
"maximum_page_limit"): "150"}) - def test_should_return_conf_max_if_req_max_above_conf(self, session): - log_models = self._create_event_logs(200) + def test_should_return_conf_max_if_req_max_above_conf(self, task_instance, session): + log_models = self._create_event_logs(task_instance, 200) session.add_all(log_models) - session.commit() + session.flush() response = self.client.get("/api/v1/eventLogs?limit=180", environ_overrides={'REMOTE_USER': "test"}) assert response.status_code == 200 assert len(response.json['event_logs']) == 150 - def _create_event_logs(self, count): - return [ - Log(event="TEST_EVENT_" + str(i), task_instance=self._create_task_instance()) - for i in range(1, count + 1) - ] + def _create_event_logs(self, task_instance, count): + return [Log(event="TEST_EVENT_" + str(i), task_instance=task_instance) for i in range(1, count + 1)] diff --git a/tests/api_connexion/endpoints/test_log_endpoint.py b/tests/api_connexion/endpoints/test_log_endpoint.py index 15caf1995cd89..87963c0b1ea6a 100644 --- a/tests/api_connexion/endpoints/test_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_log_endpoint.py @@ -26,11 +26,9 @@ from airflow import DAG from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG -from airflow.models import DagRun, TaskInstance from airflow.operators.dummy import DummyOperator from airflow.security import permissions from airflow.utils import timezone -from airflow.utils.session import create_session from airflow.utils.types import DagRunType from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.db import clear_db_runs @@ -66,24 +64,25 @@ class TestGetLog: default_time = "2020-06-10T20:00:00+00:00" @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app, configure_loggers) -> None: + def setup_attrs(self, configured_app, configure_loggers, dag_maker, session) -> None: self.app = configured_app self.client = self.app.test_client() # Make sure that the configure_logging is not cached self.old_modules = dict(sys.modules) - self._prepare_db() - def _create_dagrun(self, session): - dagrun_model = DagRun( - dag_id=self.DAG_ID, + with dag_maker(self.DAG_ID, start_date=timezone.parse(self.default_time), session=session) as dag: + DummyOperator(task_id=self.TASK_ID) + dr = dag_maker.create_dagrun( run_id='TEST_DAG_RUN_ID', run_type=DagRunType.MANUAL, execution_date=timezone.parse(self.default_time), start_date=timezone.parse(self.default_time), - external_trigger=True, ) - session.add(dagrun_model) - session.commit() + + configured_app.dag_bag.bag_dag(dag, root_dag=dag) + + self.ti = dr.task_instances[0] + self.ti.try_number = 1 @pytest.fixture def configure_loggers(self, tmp_path): @@ -109,25 +108,10 @@ def configure_loggers(self, tmp_path): logging.config.dictConfig(DEFAULT_LOGGING_CONFIG) - def _prepare_db(self): - dagbag = self.app.dag_bag - dag = DAG(self.DAG_ID, start_date=timezone.parse(self.default_time)) - dag.sync_to_db() - dagbag.dags.pop(self.DAG_ID, None) - dagbag.bag_dag(dag=dag, root_dag=dag) - with create_session() as session: - self.ti = TaskInstance( - task=DummyOperator(task_id=self.TASK_ID, dag=dag), - execution_date=timezone.parse(self.default_time), - ) - self.ti.try_number = 1 - session.merge(self.ti) - def teardown_method(self): clear_db_runs() def test_should_respond_200_json(self, session): - self._create_dagrun(session) key = self.app.config["SECRET_KEY"] serializer = 
URLSafeSerializer(key) token = serializer.dumps({"download_logs": False}) @@ -149,7 +133,6 @@ def test_should_respond_200_json(self, session): assert 200 == response.status_code def test_should_respond_200_text_plain(self, session): - self._create_dagrun(session) key = self.app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": True}) @@ -170,8 +153,6 @@ def test_should_respond_200_text_plain(self, session): ) def test_get_logs_of_removed_task(self, session): - self._create_dagrun(session) - # Recreate DAG without tasks dagbag = self.app.dag_bag dag = DAG(self.DAG_ID, start_date=timezone.parse(self.default_time)) @@ -198,7 +179,6 @@ def test_get_logs_of_removed_task(self, session): ) def test_get_logs_response_with_ti_equal_to_none(self, session): - self._create_dagrun(session) key = self.app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": True}) @@ -208,11 +188,15 @@ def test_get_logs_response_with_ti_equal_to_none(self, session): f"taskInstances/Invalid-Task-ID/logs/1?token={token}", environ_overrides={'REMOTE_USER': "test"}, ) - assert response.status_code == 400 - assert response.json['detail'] == "Task instance did not exist in the DB" + assert response.status_code == 404 + assert response.json == { + 'detail': None, + 'status': 404, + 'title': "TaskInstance not found", + 'type': EXCEPTIONS_LINK_MAP[404], + } def test_get_logs_with_metadata_as_download_large_file(self, session): - self._create_dagrun(session) with mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read") as read_mock: first_return = ([[('', '1st line')]], [{}]) second_return = ([[('', '2nd line')]], [{'end_of_log': False}]) @@ -251,7 +235,6 @@ def test_get_logs_for_handler_without_read_method(self, mock_log_reader): assert 'Task log handler does not support read logs.' 
in response.data.decode('utf-8') def test_bad_signature_raises(self, session): - self._create_dagrun(session) token = {"download_logs": False} response = self.client.get( @@ -269,15 +252,16 @@ def test_bad_signature_raises(self, session): def test_raises_404_for_invalid_dag_run_id(self): response = self.client.get( - f"api/v1/dags/{self.DAG_ID}/dagRuns/TEST_DAG_RUN/" # invalid dagrun_id + f"api/v1/dags/{self.DAG_ID}/dagRuns/NO_DAG_RUN/" # invalid dagrun_id f"taskInstances/{self.TASK_ID}/logs/1?", headers={'Accept': 'application/json'}, environ_overrides={'REMOTE_USER': "test"}, ) + assert response.status_code == 404 assert response.json == { 'detail': None, 'status': 404, - 'title': "DAG Run not found", + 'title': "TaskInstance not found", 'type': EXCEPTIONS_LINK_MAP[404], } diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index e359cd4131ce2..7696b76efe430 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -19,8 +19,9 @@ import pytest from parameterized import parameterized +from sqlalchemy.orm import contains_eager -from airflow.models import DagBag, DagRun, SlaMiss, TaskInstance +from airflow.models import DagRun, SlaMiss, TaskInstance from airflow.security import permissions from airflow.utils.platform import getuser from airflow.utils.session import provide_session @@ -58,7 +59,7 @@ def configured_app(minimal_app_for_api): class TestTaskInstanceEndpoint: @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app) -> None: + def setup_attrs(self, configured_app, dagbag) -> None: self.default_time = DEFAULT_DATETIME_1 self.ti_init = { "execution_date": self.default_time, @@ -77,15 +78,13 @@ def setup_attrs(self, configured_app) -> None: self.client = self.app.test_client() # type:ignore clear_db_runs() clear_db_sla_miss() - DagBag(include_examples=True, read_dags_from_db=False).sync_to_db() - self.dagbag = DagBag(include_examples=True, read_dags_from_db=True) + self.dagbag = dagbag def create_task_instances( self, session, dag_id: str = "example_python_operator", update_extras: bool = True, - single_dag_run: bool = True, task_instances=None, dag_run_state=State.RUNNING, ): @@ -97,6 +96,10 @@ def create_task_instances( if task_instances is not None: counter = min(len(task_instances), counter) + run_id = "TEST_DAG_RUN_ID" + execution_date = self.ti_init.pop("execution_date", self.default_time) + dr = None + for i in range(counter): if task_instances is None: pass @@ -104,31 +107,28 @@ def create_task_instances( self.ti_extras.update(task_instances[i]) else: self.ti_init.update(task_instances[i]) - ti = TaskInstance(task=tasks[i], **self.ti_init) - for key, value in self.ti_extras.items(): - setattr(ti, key, value) - session.add(ti) + if "execution_date" in self.ti_init: + run_id = f"TEST_DAG_RUN_ID_{i}" + execution_date = self.ti_init.pop("execution_date") + dr = None - if single_dag_run is False: + if not dr: dr = DagRun( + run_id=run_id, dag_id=dag_id, - run_id=f"TEST_DAG_RUN_ID_{i}", - execution_date=self.ti_init["execution_date"], - run_type=DagRunType.MANUAL.value, + execution_date=execution_date, + run_type=DagRunType.MANUAL, state=dag_run_state, ) session.add(dr) + ti = TaskInstance(task=tasks[i], **self.ti_init) + ti.dag_run = dr + + for key, value in self.ti_extras.items(): + setattr(ti, key, value) + session.add(ti) - if single_dag_run: - dr = DagRun( - dag_id=dag_id, - 
run_id="TEST_DAG_RUN_ID", - execution_date=self.default_time, - run_type=DagRunType.MANUAL.value, - state=dag_run_state, - ) - session.add(dr) session.commit() @@ -274,7 +274,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): ], False, ( - "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/" + "/api/v1/dags/example_python_operator/dagRuns/~/" f"taskInstances?execution_date_lte={DEFAULT_DATETIME_STR_1}" ), 1, @@ -288,25 +288,11 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): ], True, ( - "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances" + "/api/v1/dags/example_python_operator/dagRuns/~/taskInstances" f"?start_date_gte={DEFAULT_DATETIME_STR_1}&start_date_lte={DEFAULT_DATETIME_STR_2}" ), 2, ), - ( - "test start date filter with ~", - [ - {"start_date": DEFAULT_DATETIME_1}, - {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, - {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, - ], - True, - ( - "/api/v1/dags/~/dagRuns/~/taskInstances?start_date_gte" - f"={DEFAULT_DATETIME_STR_1}&start_date_lte={DEFAULT_DATETIME_STR_2}" - ), - 2, - ), ( "test end date filter", [ @@ -316,25 +302,11 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): ], True, ( - "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances?" + "/api/v1/dags/example_python_operator/dagRuns/~/taskInstances?" f"end_date_gte={DEFAULT_DATETIME_STR_1}&end_date_lte={DEFAULT_DATETIME_STR_2}" ), 2, ), - ( - "test end date filter ~", - [ - {"end_date": DEFAULT_DATETIME_1}, - {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, - {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, - ], - True, - ( - "/api/v1/dags/~/dagRuns/~/taskInstances?end_date_gte" - f"={DEFAULT_DATETIME_STR_1}&end_date_lte={DEFAULT_DATETIME_STR_2}" - ), - 2, - ), ( "test duration filter", [ @@ -428,6 +400,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): ) @provide_session def test_should_respond_200(self, _, task_instances, update_extras, url, expected_ti, session): + self.create_task_instances( session, update_extras=update_extras, @@ -476,7 +449,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): {"queue": "test_queue_3"}, ], True, - True, {"queue": ["test_queue_1", "test_queue_2"]}, 2, ), @@ -488,7 +460,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): {"pool": "test_pool_3"}, ], True, - True, {"pool": ["test_pool_1", "test_pool_2"]}, 2, ), @@ -500,7 +471,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): {"state": State.SUCCESS}, ], False, - True, {"state": ["running", "queued"]}, 2, ), @@ -512,7 +482,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): {"duration": 200}, ], True, - True, {"duration_gte": 100, "duration_lte": 200}, 3, ), @@ -524,7 +493,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, ], True, - True, { "end_date_gte": DEFAULT_DATETIME_STR_1, "end_date_lte": DEFAULT_DATETIME_STR_2, @@ -539,7 +507,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, ], True, - True, { "start_date_gte": DEFAULT_DATETIME_STR_1, "start_date_lte": DEFAULT_DATETIME_STR_2, @@ -557,7 +524,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): {"execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=5)}, ], False, - True, { "execution_date_gte": DEFAULT_DATETIME_1, "execution_date_lte": (DEFAULT_DATETIME_1 + dt.timedelta(days=2)), @@ -567,14 +533,11 @@ class 
TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): ] ) @provide_session - def test_should_respond_200( - self, _, task_instances, update_extras, single_dag_run, payload, expected_ti_count, session - ): + def test_should_respond_200(self, _, task_instances, update_extras, payload, expected_ti_count, session): self.create_task_instances( session, update_extras=update_extras, task_instances=task_instances, - single_dag_run=single_dag_run, ) response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", @@ -593,7 +556,6 @@ def test_should_respond_200( {"task": "test_1"}, {"task": "test_2"}, ], - True, {"dag_ids": ["latest_only"]}, 2, ), @@ -601,7 +563,7 @@ def test_should_respond_200( ) @provide_session def test_should_respond_200_when_task_instance_properties_are_none( - self, _, task_instances, single_dag_run, payload, expected_ti_count, session + self, _, task_instances, payload, expected_ti_count, session ): self.ti_extras.update( { @@ -614,7 +576,6 @@ def test_should_respond_200_when_task_instance_properties_are_none( session, dag_id="latest_only", task_instances=task_instances, - single_dag_run=single_dag_run, ) response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", @@ -878,7 +839,6 @@ def test_should_respond_200( dag_id=main_dag, task_instances=task_instances, update_extras=False, - single_dag_run=False, ) self.app.dag_bag.sync_to_db() response = self.client.post( @@ -921,7 +881,6 @@ def test_should_respond_200_with_reset_dag_run(self, session): self.create_task_instances( session, dag_id=dag_id, - single_dag_run=False, task_instances=task_instances, update_extras=False, dag_run_state=State.FAILED, @@ -1020,7 +979,6 @@ def test_should_raise_400_for_naive_and_bad_datetime(self, payload, expected, se dag_id="example_python_operator", task_instances=task_instances, update_extras=False, - single_dag_run=False, ) self.app.dag_bag.sync_to_db() response = self.client.post( @@ -1037,7 +995,11 @@ class TestPostSetTaskInstanceState(TestTaskInstanceEndpoint): def test_should_assert_call_mocked_api(self, mock_set_task_instance_state, session): self.create_task_instances(session) mock_set_task_instance_state.return_value = ( - session.query(TaskInstance).filter(TaskInstance.task_id == "print_the_context").all() + session.query(TaskInstance) + .join(TaskInstance.dag_run) + .options(contains_eager(TaskInstance.dag_run)) + .filter(TaskInstance.task_id == "print_the_context") + .all() ) response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", @@ -1074,6 +1036,7 @@ def test_should_assert_call_mocked_api(self, mock_set_task_instance_state, sessi state='failed', task_id='print_the_context', upstream=True, + session=session, ) def test_should_raises_401_unauthenticated(self): diff --git a/tests/api_connexion/schemas/test_dag_run_schema.py b/tests/api_connexion/schemas/test_dag_run_schema.py index b5333f1f1f0c7..ba5acae75f70c 100644 --- a/tests/api_connexion/schemas/test_dag_run_schema.py +++ b/tests/api_connexion/schemas/test_dag_run_schema.py @@ -34,11 +34,14 @@ DEFAULT_TIME = "2020-06-09T13:59:56.336000+00:00" +SECOND_TIME = "2020-06-10T13:59:56.336000+00:00" + class TestDAGRunBase(unittest.TestCase): def setUp(self) -> None: clear_db_runs() self.default_time = DEFAULT_TIME + self.second_time = SECOND_TIME def tearDown(self) -> None: clear_db_runs() @@ -135,7 +138,7 @@ def test_serialize(self, session): dagrun_model_2 = DagRun( run_id="my-dag-run-2", state='running', - execution_date=timezone.parse(self.default_time), + 
execution_date=timezone.parse(self.second_time), start_date=timezone.parse(self.default_time), run_type=DagRunType.MANUAL.value, ) @@ -162,8 +165,8 @@ def test_serialize(self, session): "dag_run_id": "my-dag-run-2", "end_date": None, "state": "running", - "execution_date": self.default_time, - "logical_date": self.default_time, + "execution_date": self.second_time, + "logical_date": self.second_time, "external_trigger": True, "start_date": self.default_time, "conf": {}, diff --git a/tests/api_connexion/schemas/test_event_log_schema.py b/tests/api_connexion/schemas/test_event_log_schema.py index 597ecc71b61f2..4517ecb759af6 100644 --- a/tests/api_connexion/schemas/test_event_log_schema.py +++ b/tests/api_connexion/schemas/test_event_log_schema.py @@ -15,72 +15,58 @@ # specific language governing permissions and limitations # under the License. -import unittest +import pytest -from airflow import DAG from airflow.api_connexion.schemas.event_log_schema import ( EventLogCollection, event_log_collection_schema, event_log_schema, ) -from airflow.models import Log, TaskInstance -from airflow.operators.dummy import DummyOperator +from airflow.models import Log from airflow.utils import timezone -from airflow.utils.session import create_session, provide_session -class TestEventLogSchemaBase(unittest.TestCase): - def setUp(self) -> None: - with create_session() as session: - session.query(Log).delete() - self.default_time = "2020-06-09T13:00:00+00:00" - self.default_time2 = '2020-06-11T07:00:00+00:00' +@pytest.fixture +def task_instance(session, create_task_instance, request): + return create_task_instance( + session=session, + dag_id="TEST_DAG_ID", + task_id="TEST_TASK_ID", + execution_date=request.instance.default_time, + ) - def tearDown(self) -> None: - with create_session() as session: - session.query(Log).delete() - def _create_task_instance(self): - with DAG( - 'TEST_DAG_ID', - start_date=timezone.parse(self.default_time), - end_date=timezone.parse(self.default_time), - ): - op1 = DummyOperator(task_id="TEST_TASK_ID", owner="airflow") - return TaskInstance(task=op1, execution_date=timezone.parse(self.default_time)) +class TestEventLogSchemaBase: + @pytest.fixture(autouse=True) + def set_attrs(self): + self.default_time = timezone.parse("2020-06-09T13:00:00+00:00") + self.default_time2 = timezone.parse('2020-06-11T07:00:00+00:00') class TestEventLogSchema(TestEventLogSchemaBase): - @provide_session - def test_serialize(self, session): - event_log_model = Log(event="TEST_EVENT", task_instance=self._create_task_instance()) - session.add(event_log_model) - session.commit() - event_log_model.dttm = timezone.parse(self.default_time) - log_model = session.query(Log).first() - deserialized_log = event_log_schema.dump(log_model) + def test_serialize(self, task_instance): + event_log_model = Log(event="TEST_EVENT", task_instance=task_instance) + event_log_model.dttm = self.default_time + deserialized_log = event_log_schema.dump(event_log_model) assert deserialized_log == { "event_log_id": event_log_model.id, "event": "TEST_EVENT", "dag_id": "TEST_DAG_ID", "task_id": "TEST_TASK_ID", - "execution_date": self.default_time, + "execution_date": self.default_time.isoformat(), "owner": 'airflow', - "when": self.default_time, + "when": self.default_time.isoformat(), "extra": None, } class TestEventLogCollection(TestEventLogSchemaBase): - @provide_session - def test_serialize(self, session): - event_log_model_1 = Log(event="TEST_EVENT_1", task_instance=self._create_task_instance()) - event_log_model_2 = 
Log(event="TEST_EVENT_2", task_instance=self._create_task_instance()) + def test_serialize(self, task_instance): + event_log_model_1 = Log(event="TEST_EVENT_1", task_instance=task_instance) + event_log_model_2 = Log(event="TEST_EVENT_2", task_instance=task_instance) event_logs = [event_log_model_1, event_log_model_2] - session.add_all(event_logs) - session.commit() - event_log_model_1.dttm = timezone.parse(self.default_time) - event_log_model_2.dttm = timezone.parse(self.default_time2) + event_log_model_1.dttm = self.default_time + event_log_model_2.dttm = self.default_time2 instance = EventLogCollection(event_logs=event_logs, total_entries=2) deserialized_event_logs = event_log_collection_schema.dump(instance) assert deserialized_event_logs == { @@ -90,9 +76,9 @@ def test_serialize(self, session): "event": "TEST_EVENT_1", "dag_id": "TEST_DAG_ID", "task_id": "TEST_TASK_ID", - "execution_date": self.default_time, + "execution_date": self.default_time.isoformat(), "owner": 'airflow', - "when": self.default_time, + "when": self.default_time.isoformat(), "extra": None, }, { @@ -100,9 +86,9 @@ def test_serialize(self, session): "event": "TEST_EVENT_2", "dag_id": "TEST_DAG_ID", "task_id": "TEST_TASK_ID", - "execution_date": self.default_time, + "execution_date": self.default_time.isoformat(), "owner": 'airflow', - "when": self.default_time2, + "when": self.default_time2.isoformat(), "extra": None, }, ], diff --git a/tests/api_connexion/schemas/test_task_instance_schema.py b/tests/api_connexion/schemas/test_task_instance_schema.py index 73895ae0f95fd..883d9367a204b 100644 --- a/tests/api_connexion/schemas/test_task_instance_schema.py +++ b/tests/api_connexion/schemas/test_task_instance_schema.py @@ -27,25 +27,29 @@ set_task_instance_state_form, task_instance_schema, ) -from airflow.models import DAG, SlaMiss, TaskInstance as TI +from airflow.models import SlaMiss, TaskInstance as TI from airflow.operators.dummy import DummyOperator from airflow.utils.platform import getuser -from airflow.utils.session import create_session, provide_session from airflow.utils.state import State from airflow.utils.timezone import datetime -class TestTaskInstanceSchema(unittest.TestCase): - def setUp(self): +class TestTaskInstanceSchema: + @pytest.fixture(autouse=True) + def set_attrs(self, session, dag_maker): self.default_time = datetime(2020, 1, 1) - with DAG(dag_id="TEST_DAG_ID"): + with dag_maker(dag_id="TEST_DAG_ID", session=session): self.task = DummyOperator(task_id="TEST_TASK_ID", start_date=self.default_time) + self.dr = dag_maker.create_dagrun(execution_date=self.default_time) + session.flush() + self.default_ti_init = { - "execution_date": self.default_time, + "run_id": None, "state": State.RUNNING, } self.default_ti_extras = { + "dag_run": self.dr, "start_date": self.default_time + dt.timedelta(days=1), "end_date": self.default_time + dt.timedelta(days=2), "pid": 100, @@ -54,18 +58,14 @@ def setUp(self): "queue": "default_queue", } - def tearDown(self): - with create_session() as session: - session.query(TI).delete() - session.query(SlaMiss).delete() + yield + + session.rollback() - @provide_session def test_task_instance_schema_without_sla(self, session): ti = TI(task=self.task, **self.default_ti_init) for key, value in self.default_ti_extras.items(): setattr(ti, key, value) - session.add(ti) - session.commit() serialized_ti = task_instance_schema.dump((ti, None)) expected_json = { "dag_id": "TEST_DAG_ID", @@ -91,19 +91,17 @@ def test_task_instance_schema_without_sla(self, session): } assert 
serialized_ti == expected_json - @provide_session def test_task_instance_schema_with_sla(self, session): - ti = TI(task=self.task, **self.default_ti_init) - for key, value in self.default_ti_extras.items(): - setattr(ti, key, value) sla_miss = SlaMiss( task_id="TEST_TASK_ID", dag_id="TEST_DAG_ID", execution_date=self.default_time, ) - session.add(ti) session.add(sla_miss) - session.commit() + session.flush() + ti = TI(task=self.task, **self.default_ti_init) + for key, value in self.default_ti_extras.items(): + setattr(ti, key, value) serialized_ti = task_instance_schema.dump((ti, sla_miss)) expected_json = { "dag_id": "TEST_DAG_ID", @@ -177,19 +175,17 @@ def test_validation_error(self, payload): clear_task_instance_form.load(payload) -class TestSetTaskInstanceStateFormSchema(unittest.TestCase): - def setUp(self) -> None: - super().setUp() - self.current_input = { - "dry_run": True, - "task_id": "print_the_context", - "execution_date": "2020-01-01T00:00:00+00:00", - "include_upstream": True, - "include_downstream": True, - "include_future": True, - "include_past": True, - "new_state": "failed", - } +class TestSetTaskInstanceStateFormSchema: + current_input = { + "dry_run": True, + "task_id": "print_the_context", + "execution_date": "2020-01-01T00:00:00+00:00", + "include_upstream": True, + "include_downstream": True, + "include_future": True, + "include_past": True, + "new_state": "failed", + } def test_success(self): result = set_task_instance_state_form.load(self.current_input) diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py index fd8404a0598b8..4bd7013f93e9a 100644 --- a/tests/cli/commands/test_dag_command.py +++ b/tests/cli/commands/test_dag_command.py @@ -39,7 +39,7 @@ dag_folder_path = '/'.join(os.path.realpath(__file__).split('/')[:-1]) -DEFAULT_DATE = timezone.make_aware(datetime(2015, 1, 1)) +DEFAULT_DATE = timezone.make_aware(datetime(2015, 1, 1), timezone=timezone.utc) TEST_DAG_FOLDER = os.path.join(os.path.dirname(dag_folder_path), 'dags') TEST_DAG_ID = 'unit_tests' @@ -357,7 +357,7 @@ def test_cli_list_dags(self): assert "airflow" in out assert "paused" in out assert "airflow/example_dags/example_complex.py" in out - assert "False" in out + assert "- dag_id:" in out def test_cli_list_dag_runs(self): dag_command.dag_trigger( diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py index 9231e62f39c69..983416e0d36c0 100644 --- a/tests/cli/commands/test_task_command.py +++ b/tests/cli/commands/test_task_command.py @@ -23,7 +23,7 @@ import re import unittest from contextlib import redirect_stdout -from datetime import datetime, timedelta +from datetime import datetime from unittest import mock import pytest @@ -32,17 +32,17 @@ from airflow.cli import cli_parser from airflow.cli.commands import task_command from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, DagRunNotFound from airflow.models import DagBag, DagRun, TaskInstance from airflow.utils import timezone -from airflow.utils.cli import get_dag +from airflow.utils.dates import days_ago from airflow.utils.session import create_session from airflow.utils.state import State from airflow.utils.types import DagRunType from tests.test_utils.config import conf_vars -from tests.test_utils.db import clear_db_pools, clear_db_runs +from tests.test_utils.db import clear_db_runs -DEFAULT_DATE = timezone.make_aware(datetime(2016, 1, 1)) +DEFAULT_DATE = days_ago(1) 
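# Illustrative sketch, not part of the patch above: the construction pattern used throughout
# the updated tests. A TaskInstance is no longer built with an execution_date; a DagRun is
# created first and the instance is attached to it (the run_id foreign key is filled from the
# relationship when the objects are flushed). The helper name and run_id value are assumptions.
from airflow.models import DagRun, TaskInstance
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.types import DagRunType


def make_dagrun_and_ti(task, execution_date=None):
    execution_date = execution_date or timezone.utcnow()
    dag_run = DagRun(
        dag_id=task.dag_id,
        run_id="test",
        run_type=DagRunType.MANUAL,
        execution_date=execution_date,
        state=State.RUNNING,
    )
    ti = TaskInstance(task=task)
    ti.dag_run = dag_run
    return dag_run, ti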
ROOT_FOLDER = os.path.realpath( os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir) ) @@ -58,13 +58,22 @@ def reset(dag_id): # TODO: Check if tests needs side effects - locally there's missing DAG class TestCliTasks(unittest.TestCase): + run_id = 'TEST_RUN_ID' + dag_id = 'example_python_operator' + @classmethod def setUpClass(cls): cls.dagbag = DagBag(include_examples=True) cls.parser = cli_parser.get_parser() clear_db_runs() - def tearDown(self) -> None: + cls.dag = cls.dagbag.get_dag(cls.dag_id) + cls.dag_run = cls.dag.create_dagrun( + state=State.NONE, run_id=cls.run_id, run_type=DagRunType.MANUAL, execution_date=DEFAULT_DATE + ) + + @classmethod + def tearDownClass(cls) -> None: clear_db_runs() def test_cli_list_tasks(self): @@ -89,76 +98,33 @@ def test_test(self): def test_test_with_existing_dag_run(self): """Test the `airflow test` command""" - dag_id = 'example_python_operator' - run_id = 'TEST_RUN_ID' task_id = 'print_the_context' - dag = self.dagbag.get_dag(dag_id) - - dag.create_dagrun(state=State.NONE, run_id=run_id, run_type=DagRunType.MANUAL, external_trigger=True) - args = self.parser.parse_args(["tasks", "test", dag_id, task_id, run_id]) + args = self.parser.parse_args(["tasks", "test", self.dag_id, task_id, DEFAULT_DATE.isoformat()]) with redirect_stdout(io.StringIO()) as stdout: task_command.task_test(args) # Check that prints, and log messages, are shown - assert f"Marking task as SUCCESS. dag_id={dag_id}, task_id={task_id}" in stdout.getvalue() - - @mock.patch("airflow.cli.commands.task_command.LocalTaskJob") - def test_run_naive_taskinstance(self, mock_local_job): - """ - Test that we can run naive (non-localized) task instances - """ - naive_date = datetime(2016, 1, 1) - dag_id = 'test_run_ignores_all_dependencies' - - dag = self.dagbag.get_dag('test_run_ignores_all_dependencies') - - task0_id = 'test_run_dependent_task' - args0 = [ - 'tasks', - 'run', - '--ignore-all-dependencies', - '--local', - dag_id, - task0_id, - naive_date.isoformat(), - ] - - task_command.task_run(self.parser.parse_args(args0), dag=dag) - mock_local_job.assert_called_once_with( - task_instance=mock.ANY, - mark_success=False, - ignore_all_deps=True, - ignore_depends_on_past=False, - ignore_task_deps=False, - ignore_ti_state=False, - pickle_id=None, - pool=None, - ) + assert f"Marking task as SUCCESS. 
dag_id={self.dag_id}, task_id={task_id}" in stdout.getvalue() @mock.patch("airflow.cli.commands.task_command.LocalTaskJob") def test_run_with_existing_dag_run_id(self, mock_local_job): """ Test that we can run with existing dag_run_id """ - dag_id = 'test_run_ignores_all_dependencies' - - dag = self.dagbag.get_dag(dag_id) - task0_id = 'test_run_dependent_task' - run_id = 'TEST_RUN_ID' - dag.create_dagrun(state=State.NONE, run_id=run_id, run_type=DagRunType.MANUAL, external_trigger=True) + task0_id = self.dag.task_ids[0] args0 = [ 'tasks', 'run', '--ignore-all-dependencies', '--local', - dag_id, + self.dag_id, task0_id, - run_id, + self.run_id, ] - task_command.task_run(self.parser.parse_args(args0), dag=dag) + task_command.task_run(self.parser.parse_args(args0), dag=self.dag) mock_local_job.assert_called_once_with( task_instance=mock.ANY, mark_success=False, @@ -188,21 +154,8 @@ def test_run_raises_when_theres_no_dagrun(self, mock_local_job): task0_id, run_id, ] - with self.assertRaises(AirflowException) as err: + with self.assertRaises(DagRunNotFound): task_command.task_run(self.parser.parse_args(args0), dag=dag) - assert str(err.exception) == f"DagRun with run_id: {run_id} not found" - - def test_cli_test(self): - task_command.task_test( - self.parser.parse_args( - ['tasks', 'test', 'example_bash_operator', 'runme_0', DEFAULT_DATE.isoformat()] - ) - ) - task_command.task_test( - self.parser.parse_args( - ['tasks', 'test', 'example_bash_operator', 'runme_0', '--dry-run', DEFAULT_DATE.isoformat()] - ) - ) def test_cli_test_with_params(self): task_command.task_test( @@ -251,13 +204,6 @@ def test_cli_test_with_env_vars(self): assert 'foo=bar' in output assert 'AIRFLOW_TEST_MODE=True' in output - def test_cli_run(self): - task_command.task_run( - self.parser.parse_args( - ['tasks', 'run', 'example_bash_operator', 'runme_0', '--local', DEFAULT_DATE.isoformat()] - ) - ) - @parameterized.expand( [ ("--ignore-all-dependencies",), @@ -307,7 +253,7 @@ def test_task_render(self): """ with redirect_stdout(io.StringIO()) as stdout: task_command.task_render( - self.parser.parse_args(['tasks', 'render', 'tutorial', 'templated', DEFAULT_DATE.isoformat()]) + self.parser.parse_args(['tasks', 'render', 'tutorial', 'templated', '2016-01-01']) ) output = stdout.getvalue() @@ -326,7 +272,6 @@ def test_cli_run_when_pickle_and_dag_cli_method_selected(self): AirflowException, match=re.escape("You cannot use the --pickle option when using DAG.cli() method."), ): - dag = self.dagbag.get_dag('test_run_ignores_all_dependencies') task_command.task_run( self.parser.parse_args( [ @@ -339,13 +284,13 @@ def test_cli_run_when_pickle_and_dag_cli_method_selected(self): pickle_id, ] ), - dag, + self.dag, ) def test_task_state(self): task_command.task_state( self.parser.parse_args( - ['tasks', 'state', 'example_bash_operator', 'runme_0', DEFAULT_DATE.isoformat()] + ['tasks', 'state', self.dag_id, 'print_the_context', DEFAULT_DATE.isoformat()] ) ) @@ -395,7 +340,7 @@ def test_task_states_for_dag_run_when_dag_run_not_exists(self): """ task_states_for_dag_run should return an AirflowException when invalid dag id is passed """ - with pytest.raises(AirflowException, match="DagRun does not exist."): + with pytest.raises(DagRunNotFound): default_date2 = timezone.make_aware(datetime(2016, 1, 9)) task_command.task_states_for_dag_run( self.parser.parse_args( @@ -426,31 +371,6 @@ def test_parentdag_downstream_clear(self): ) task_command.task_clear(args) - @pytest.mark.quarantined - def test_local_run(self): - args = 
self.parser.parse_args( - [ - 'tasks', - 'run', - 'example_python_operator', - 'print_the_context', - '2018-04-27T08:39:51.298439+00:00', - '--interactive', - '--subdir', - '/root/dags/example_python_operator.py', - ] - ) - - dag = get_dag(args.subdir, args.dag_id) - reset(dag.dag_id) - - task_command.task_run(args) - task = dag.get_task(task_id=args.task_id) - ti = TaskInstance(task, args.execution_date) - ti.refresh_from_db() - state = ti.current_state() - assert state == State.SUCCESS - # For this test memory spins out of control on Python 3.6. TODO(potiuk): FIXME") @pytest.mark.quarantined @@ -650,66 +570,3 @@ def task_inner(*args, **kwargs): assert captured.output == ["WARNING:foo.bar:not redirected"] settings.DONOT_MODIFY_HANDLERS = old_value - - -class TestCliTaskBackfill(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.dagbag = DagBag(include_examples=True) - - def setUp(self): - clear_db_runs() - clear_db_pools() - - self.parser = cli_parser.get_parser() - - def test_run_ignores_all_dependencies(self): - """ - Test that run respects ignore_all_dependencies - """ - dag_id = 'test_run_ignores_all_dependencies' - - dag = self.dagbag.get_dag('test_run_ignores_all_dependencies') - dag.clear() - - task0_id = 'test_run_dependent_task' - args0 = ['tasks', 'run', '--ignore-all-dependencies', dag_id, task0_id, DEFAULT_DATE.isoformat()] - task_command.task_run(self.parser.parse_args(args0)) - ti_dependent0 = TaskInstance(task=dag.get_task(task0_id), execution_date=DEFAULT_DATE) - - ti_dependent0.refresh_from_db() - assert ti_dependent0.state == State.FAILED - - task1_id = 'test_run_dependency_task' - args1 = [ - 'tasks', - 'run', - '--ignore-all-dependencies', - dag_id, - task1_id, - (DEFAULT_DATE + timedelta(days=1)).isoformat(), - ] - task_command.task_run(self.parser.parse_args(args1)) - - ti_dependency = TaskInstance( - task=dag.get_task(task1_id), execution_date=DEFAULT_DATE + timedelta(days=1) - ) - ti_dependency.refresh_from_db() - assert ti_dependency.state == State.FAILED - - task2_id = 'test_run_dependent_task' - args2 = [ - 'tasks', - 'run', - '--ignore-all-dependencies', - dag_id, - task2_id, - (DEFAULT_DATE + timedelta(days=1)).isoformat(), - ] - task_command.task_run(self.parser.parse_args(args2)) - - ti_dependent = TaskInstance( - task=dag.get_task(task2_id), execution_date=DEFAULT_DATE + timedelta(days=1) - ) - ti_dependent.refresh_from_db() - assert ti_dependent.state == State.SUCCESS diff --git a/tests/conftest.py b/tests/conftest.py index 745ec9f4178b3..b18a472dc03c6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import json import os import subprocess import sys -from contextlib import ExitStack +from contextlib import ExitStack, suppress from datetime import datetime, timedelta import freezegun @@ -466,8 +467,8 @@ def dag_maker(request): want_serialized = False - # Allow changing default serialized behaviour with `@ptest.mark.need_serialized_dag` or - # `@ptest.mark.need_serialized_dag(False)` + # Allow changing default serialized behaviour with `@pytest.mark.need_serialized_dag` or + # `@pytest.mark.need_serialized_dag(False)` serialized_marker = request.node.get_closest_marker("need_serialized_dag") if serialized_marker: (want_serialized,) = serialized_marker.args or (True,) @@ -488,6 +489,15 @@ def __enter__(self): def _serialized_dag(self): return self.serialized_model.dag + def get_serialized_data(self): + try: + data = self.serialized_model.data + except AttributeError: + raise RuntimeError("DAG serialization not requested") + if isinstance(data, str): + return json.loads(data) + return data + def __exit__(self, type, value, traceback): from airflow.models import DagModel from airflow.models.serialized_dag import SerializedDagModel @@ -497,7 +507,7 @@ def __exit__(self, type, value, traceback): if type is not None: return - dag.clear() + dag.clear(session=self.session) dag.sync_to_db(self.session) self.dag_model = self.session.query(DagModel).get(dag.dag_id) @@ -511,6 +521,7 @@ def __exit__(self, type, value, traceback): self.dagbag.bag_dag(self.dag, self.dag) def create_dagrun(self, **kwargs): + from airflow.timetables.base import DataInterval from airflow.utils.state import State dag = self.dag @@ -525,7 +536,13 @@ def create_dagrun(self, **kwargs): # explicitly, or pass run_type for inference in dag.create_dagrun(). if "run_id" not in kwargs and "run_type" not in kwargs: kwargs["run_id"] = "test" + # Fill data_interval if not provided. + if not kwargs.get("data_interval"): + kwargs["data_interval"] = DataInterval.exact(kwargs["execution_date"]) + self.dag_run = dag.create_dagrun(**kwargs) + for ti in self.dag_run.task_instances: + ti.refresh_from_task(dag.get_task(ti.task_id)) return self.dag_run def __call__( @@ -587,7 +604,8 @@ def cleanup(self): yield factory finally: factory.cleanup() - del factory.session + with suppress(AttributeError): + del factory.session @pytest.fixture @@ -622,6 +640,7 @@ def create_dag( on_failure_callback=None, on_retry_callback=None, email=None, + with_dagrun=True, **kwargs, ): with dag_maker(dag_id, **kwargs) as dag: @@ -637,7 +656,69 @@ def create_dag( pool=pool, trigger_rule=trigger_rule, ) - dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) + if with_dagrun: + dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) return dag, op return create_dag + + +@pytest.fixture +def create_task_instance(dag_maker, create_dummy_dag): + """ + Create a TaskInstance, and associated DB rows (DagRun, DagModel, etc) + + Uses ``create_dummy_dag`` to create the dag structure.
+ """ + + def maker(execution_date=None, dagrun_state=None, state=None, run_id='test', **kwargs): + if execution_date is None: + from airflow.utils import timezone + + execution_date = timezone.utcnow() + create_dummy_dag(with_dagrun=False, **kwargs) + + dr = dag_maker.create_dagrun(execution_date=execution_date, state=dagrun_state, run_id=run_id) + ti = dr.task_instances[0] + ti.state = state + + return ti + + return maker + + +@pytest.fixture() +def create_task_instance_of_operator(dag_maker): + def _create_task_instance( + operator_class, + *, + dag_id, + execution_date=None, + session=None, + **operator_kwargs, + ): + with dag_maker(dag_id=dag_id, session=session): + operator_class(**operator_kwargs) + (ti,) = dag_maker.create_dagrun(execution_date=execution_date).task_instances + return ti + + return _create_task_instance + + +@pytest.fixture() +def create_task_of_operator(dag_maker): + def _create_task_of_operator(operator_class, *, dag_id, session=None, **operator_kwargs): + with dag_maker(dag_id=dag_id, session=session): + task = operator_class(**operator_kwargs) + return task + + return _create_task_of_operator + + +@pytest.fixture +def session(): + from airflow.utils.session import create_session + + with create_session() as session: + yield session + session.rollback() diff --git a/tests/core/test_core.py b/tests/core/test_core.py index d8a568773a447..9131fbf6217e3 100644 --- a/tests/core/test_core.py +++ b/tests/core/test_core.py @@ -17,7 +17,6 @@ # under the License. import logging -import multiprocessing import os import signal from datetime import timedelta @@ -29,13 +28,13 @@ from airflow import settings from airflow.exceptions import AirflowException, AirflowTaskTimeout from airflow.hooks.base import BaseHook -from airflow.jobs.local_task_job import LocalTaskJob from airflow.models import DagBag, TaskFail, TaskInstance from airflow.models.baseoperator import BaseOperator from airflow.operators.bash import BashOperator from airflow.operators.check_operator import CheckOperator, ValueCheckOperator from airflow.operators.dummy import DummyOperator from airflow.operators.python import PythonOperator +from airflow.utils.dates import days_ago from airflow.utils.state import State from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType @@ -108,10 +107,18 @@ def test_check_operators(self, dag_maker): captain_hook.run("drop table operator_test_table") - def test_clear_api(self): + def test_clear_api(self, session): task = self.dag_bash.tasks[0] - task.clear(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, upstream=True, downstream=True) - ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) + + dr = self.dag_bash.create_dagrun( + run_type=DagRunType.MANUAL, + state=State.RUNNING, + execution_date=days_ago(1), + session=session, + ) + task.clear(start_date=dr.execution_date, end_date=dr.execution_date, upstream=True, downstream=True) + ti = dr.get_task_instance(task.task_id, session=session) + ti.task = task ti.are_dependents_done() def test_illegal_args(self, dag_maker): @@ -268,15 +275,16 @@ def __bool__(self): dag_maker.create_dagrun() op.resolve_template_files() - def test_task_get_template(self): - ti = TaskInstance(task=self.runme_0, execution_date=DEFAULT_DATE) - ti.dag = self.dag_bash - self.dag_bash.create_dagrun( + def test_task_get_template(self, session): + dr = self.dag_bash.create_dagrun( run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE, data_interval=(DEFAULT_DATE, DEFAULT_DATE + timedelta(days=1)), + 
session=session, ) + ti = TaskInstance(task=self.runme_0, run_id=dr.run_id) + ti.dag = self.dag_bash ti.run(ignore_ti_state=True) context = ti.get_template_context() @@ -314,64 +322,12 @@ def test_task_get_template(self): assert value == expected_value assert [str(m.message) for m in recorder] == [message] - def test_local_task_job(self): - TI = TaskInstance - ti = TI(task=self.runme_0, execution_date=DEFAULT_DATE) - job = LocalTaskJob(task_instance=ti, ignore_ti_state=True) - job.run() - - def test_raw_job(self): - TI = TaskInstance - ti = TI(task=self.runme_0, execution_date=DEFAULT_DATE) - ti.dag = self.dag_bash - self.dag_bash.create_dagrun( - run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE - ) - ti.run(ignore_ti_state=True) - def test_bad_trigger_rule(self, dag_maker): with pytest.raises(AirflowException): with dag_maker(): DummyOperator(task_id='test_bad_trigger', trigger_rule="non_existent") dag_maker.create_dagrun() - def test_terminate_task(self): - """If a task instance's db state get deleted, it should fail""" - from airflow.executors.sequential_executor import SequentialExecutor - - TI = TaskInstance - dag = self.dagbag.dags.get('test_utils') - task = dag.task_dict.get('sleeps_forever') - - ti = TI(task=task, execution_date=DEFAULT_DATE) - job = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) - - # Running task instance asynchronously - proc = multiprocessing.Process(target=job.run) - proc.start() - sleep(5) - settings.engine.dispose() - session = settings.Session() - ti.refresh_from_db(session=session) - # making sure it's actually running - assert State.RUNNING == ti.state - ti = ( - session.query(TI) - .filter_by(dag_id=task.dag_id, task_id=task.task_id, execution_date=DEFAULT_DATE) - .one() - ) - - # deleting the instance should result in a failure - session.delete(ti) - session.commit() - # waiting for the async task to finish - proc.join() - - # making sure that the task ended up as failed - ti.refresh_from_db(session=session) - assert State.FAILED == ti.state - session.close() - def test_task_fail_duration(self, dag_maker): """If a task fails, the duration should be recorded in TaskFail""" with dag_maker() as dag: @@ -382,6 +338,7 @@ def test_task_fail_duration(self, dag_maker): execution_timeout=timedelta(seconds=3), retry_delay=timedelta(seconds=0), ) + dag_maker.create_dagrun() session = settings.Session() try: op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) diff --git a/tests/core/test_sentry.py b/tests/core/test_sentry.py index 44a39d9126a68..5a5fa566da5dc 100644 --- a/tests/core/test_sentry.py +++ b/tests/core/test_sentry.py @@ -18,14 +18,12 @@ import datetime import importlib -import unittest -from unittest.mock import MagicMock, Mock +import pytest from freezegun import freeze_time from sentry_sdk import configure_scope -from airflow.models import TaskInstance -from airflow.settings import Session +from airflow.operators.python import PythonOperator from airflow.utils import timezone from airflow.utils.state import State from tests.test_utils.config import conf_vars @@ -33,7 +31,7 @@ EXECUTION_DATE = timezone.utcnow() DAG_ID = "test_dag" TASK_ID = "test_task" -OPERATOR = "test_operator" +OPERATOR = "PythonOperator" TRY_NUMBER = 1 STATE = State.SUCCESS TEST_SCOPE = { @@ -60,46 +58,49 @@ } -class TestSentryHook(unittest.TestCase): - @conf_vars({('sentry', 'sentry_on'): 'True'}) - def setUp(self): - from airflow import sentry +class TestSentryHook: + @pytest.fixture + def 
task_instance(self, dag_maker): + # Mock the Dag + with dag_maker(DAG_ID): + task = PythonOperator(task_id=TASK_ID, python_callable=int) - importlib.reload(sentry) - self.sentry = sentry.ConfiguredSentry() + dr = dag_maker.create_dagrun(execution_date=EXECUTION_DATE) + ti = dr.task_instances[0] + ti.state = STATE + ti.task = task + dag_maker.session.flush() - # Mock the Dag - self.dag = Mock(dag_id=DAG_ID, params=[]) - self.dag.task_ids = [TASK_ID] + yield ti - # Mock the task - self.task = Mock(dag=self.dag, dag_id=DAG_ID, task_id=TASK_ID, params=[], pool_slots=1) - self.task.__class__.__name__ = OPERATOR + dag_maker.session.rollback() - self.ti = TaskInstance(self.task, execution_date=EXECUTION_DATE) - self.ti.operator = OPERATOR - self.ti.state = STATE + @pytest.fixture + def sentry(self): + with conf_vars({('sentry', 'sentry_on'): 'True'}): + from airflow import sentry - self.dag.get_task_instances = MagicMock(return_value=[self.ti]) + importlib.reload(sentry) + yield sentry.Sentry - self.session = Session() + importlib.reload(sentry) - def test_add_tagging(self): + def test_add_tagging(self, sentry, task_instance): """ Test adding tags. """ - self.sentry.add_tagging(task_instance=self.ti) + sentry.add_tagging(task_instance=task_instance) with configure_scope() as scope: for key, value in scope._tags.items(): assert TEST_SCOPE[key] == value @freeze_time(CRUMB_DATE.isoformat()) - def test_add_breadcrumbs(self): + def test_add_breadcrumbs(self, sentry, task_instance): """ Test adding breadcrumbs. """ - self.sentry.add_tagging(task_instance=self.ti) - self.sentry.add_breadcrumbs(task_instance=self.ti, session=self.session) + sentry.add_tagging(task_instance=task_instance) + sentry.add_breadcrumbs(task_instance=task_instance) with configure_scope() as scope: test_crumb = scope._breadcrumbs.pop() diff --git a/tests/dag_processing/test_manager.py b/tests/dag_processing/test_manager.py index ac2d4dceb3b1a..2c62939fe6f50 100644 --- a/tests/dag_processing/test_manager.py +++ b/tests/dag_processing/test_manager.py @@ -51,7 +51,8 @@ from airflow.utils.callback_requests import CallbackRequest, TaskCallbackRequest from airflow.utils.net import get_hostname from airflow.utils.session import create_session -from airflow.utils.state import State +from airflow.utils.state import DagRunState, State +from airflow.utils.types import DagRunType from tests.core.test_logging_config import SETTINGS_FILE_VALID, settings_context from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags @@ -110,6 +111,9 @@ class TestDagFileProcessorManager: def setup_method(self): clear_db_runs() + def teardown_class(self): + clear_db_runs() + def run_processor_manager_one_loop(self, manager, parent_pipe): if not manager._async_mode: parent_pipe.send(DagParsingSignal.AGENT_RUN_ONCE) @@ -432,16 +436,23 @@ def test_find_zombies(self): dag.sync_to_db() task = dag.get_task(task_id='run_this_first') - ti = TI(task, DEFAULT_DATE, State.RUNNING) + dag_run = dag.create_dagrun( + state=DagRunState.RUNNING, + execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, + ) + + ti = TI(task, run_id=dag_run.run_id, state=State.RUNNING) local_job = LJ(ti) local_job.state = State.SHUTDOWN session.add(local_job) - session.commit() + session.flush() ti.job_id = local_job.id session.add(ti) - session.commit() + session.flush() manager._last_zombie_query_time = timezone.utcnow() - timedelta( seconds=manager._zombie_threshold_secs + 1 @@ -455,7 +466,7 @@ 
def test_find_zombies(self): assert isinstance(requests[0].simple_task_instance, SimpleTaskInstance) assert ti.dag_id == requests[0].simple_task_instance.dag_id assert ti.task_id == requests[0].simple_task_instance.task_id - assert ti.execution_date == requests[0].simple_task_instance.execution_date + assert ti.run_id == requests[0].simple_task_instance.run_id session.query(TI).delete() session.query(LJ).delete() @@ -475,19 +486,26 @@ def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_p session.query(LJ).delete() dag = dagbag.get_dag('test_example_bash_operator') dag.sync_to_db() + + dag_run = dag.create_dagrun( + state=DagRunState.RUNNING, + execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, + ) task = dag.get_task(task_id='run_this_last') - ti = TI(task, DEFAULT_DATE, State.RUNNING) + ti = TI(task, run_id=dag_run.run_id, state=State.RUNNING) local_job = LJ(ti) local_job.state = State.SHUTDOWN session.add(local_job) - session.commit() + session.flush() # TODO: If there was an actual Relationship between TI and Job # we wouldn't need this extra commit session.add(ti) ti.job_id = local_job.id - session.commit() + session.flush() expected_failure_callback_requests = [ TaskCallbackRequest( diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py index 54f0e9e2ef71e..43d4b86dc0c7f 100644 --- a/tests/dag_processing/test_processor.py +++ b/tests/dag_processing/test_processor.py @@ -19,7 +19,6 @@ import datetime import os -from tempfile import NamedTemporaryFile from unittest import mock from unittest.mock import MagicMock, patch from zipfile import ZipFile @@ -37,6 +36,7 @@ from airflow.utils.dates import days_ago from airflow.utils.session import create_session from airflow.utils.state import State +from airflow.utils.types import DagRunType from tests.test_utils.config import conf_vars, env_vars from tests.test_utils.db import ( clear_db_dags, @@ -82,9 +82,10 @@ def clean_db(): clear_db_jobs() clear_db_serialized_dags() - def setup_method(self): + def setup_class(self): self.clean_db() + def setup_method(self): # Speed up some tests by not running the tasks, just look at what we # enqueue! 
self.null_exec = MockExecutor() @@ -329,23 +330,24 @@ def test_execute_on_failure_callbacks(self, mock_ti_handle_failure): with create_session() as session: session.query(TaskInstance).delete() dag = dagbag.get_dag('example_branch_operator') + dagrun = dag.create_dagrun( + state=State.RUNNING, + execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, + ) task = dag.get_task(task_id='run_this_first') - - ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING) - + ti = TaskInstance(task, run_id=dagrun.run_id, state=State.RUNNING) session.add(ti) - session.commit() - requests = [ - TaskCallbackRequest( - full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message" - ) - ] - dag_file_processor.execute_callbacks(dagbag, requests) - mock_ti_handle_failure.assert_called_once_with( - error="Message", - test_mode=conf.getboolean('core', 'unit_test_mode'), - ) + requests = [ + TaskCallbackRequest(full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message") + ] + dag_file_processor.execute_callbacks(dagbag, requests) + mock_ti_handle_failure.assert_called_once_with( + error="Message", + test_mode=conf.getboolean('core', 'unit_test_mode'), + ) def test_failure_callbacks_should_not_drop_hostname(self): dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False) @@ -355,36 +357,46 @@ def test_failure_callbacks_should_not_drop_hostname(self): with create_session() as session: dag = dagbag.get_dag('example_branch_operator') task = dag.get_task(task_id='run_this_first') - - ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING) + dagrun = dag.create_dagrun( + state=State.RUNNING, + execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, + ) + ti = TaskInstance(task, run_id=dagrun.run_id, state=State.RUNNING) ti.hostname = "test_hostname" session.add(ti) + requests = [ + TaskCallbackRequest(full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message") + ] + dag_file_processor.execute_callbacks(dagbag, requests) + with create_session() as session: - requests = [ - TaskCallbackRequest( - full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message" - ) - ] - dag_file_processor.execute_callbacks(dagbag, requests) tis = session.query(TaskInstance) assert tis[0].hostname == "test_hostname" - def test_process_file_should_failure_callback(self): + def test_process_file_should_failure_callback(self, monkeypatch, tmp_path): + callback_file = tmp_path.joinpath("callback.txt") + callback_file.touch() + monkeypatch.setenv("AIRFLOW_CALLBACK_FILE", str(callback_file)) dag_file = os.path.join( os.path.dirname(os.path.realpath(__file__)), '../dags/test_on_failure_callback.py' ) dagbag = DagBag(dag_folder=dag_file, include_examples=False) dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - with create_session() as session, NamedTemporaryFile(delete=False) as callback_file: - session.query(TaskInstance).delete() - dag = dagbag.get_dag('test_om_failure_callback_dag') - task = dag.get_task(task_id='test_om_failure_callback_task') - - ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING) - session.add(ti) - session.commit() + dag = dagbag.get_dag('test_om_failure_callback_dag') + task = dag.get_task(task_id='test_om_failure_callback_task') + with create_session() as session: + dagrun = dag.create_dagrun( + state=State.RUNNING, + execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, + ) + (ti,) = dagrun.task_instances + 
ti.refresh_from_task(task) requests = [ TaskCallbackRequest( @@ -393,14 +405,9 @@ def test_process_file_should_failure_callback(self): msg="Message", ) ] - callback_file.close() - - with mock.patch.dict("os.environ", {"AIRFLOW_CALLBACK_FILE": callback_file.name}): - dag_file_processor.process_file(dag_file, requests) - with open(callback_file.name) as callback_file2: - content = callback_file2.read() - assert "Callback fired" == content - os.remove(callback_file.name) + dag_file_processor.process_file(dag_file, requests, session=session) + + assert "Callback fired" == callback_file.read_text() @conf_vars({("core", "dagbag_import_error_tracebacks"): "False"}) def test_add_unparseable_file_before_sched_start_creates_import_error(self, tmpdir): diff --git a/tests/executors/test_base_executor.py b/tests/executors/test_base_executor.py index 22518783ed671..561d551422fcc 100644 --- a/tests/executors/test_base_executor.py +++ b/tests/executors/test_base_executor.py @@ -15,60 +15,59 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -import unittest -from datetime import datetime, timedelta +from datetime import timedelta from unittest import mock from airflow.executors.base_executor import BaseExecutor from airflow.models.baseoperator import BaseOperator -from airflow.models.dag import DAG -from airflow.models.taskinstance import TaskInstance, TaskInstanceKey +from airflow.models.taskinstance import TaskInstanceKey +from airflow.utils import timezone from airflow.utils.state import State -class TestBaseExecutor(unittest.TestCase): - def test_get_event_buffer(self): - executor = BaseExecutor() +def test_get_event_buffer(): + executor = BaseExecutor() + + date = timezone.utcnow() + try_number = 1 + key1 = TaskInstanceKey("my_dag1", "my_task1", date, try_number) + key2 = TaskInstanceKey("my_dag2", "my_task1", date, try_number) + key3 = TaskInstanceKey("my_dag2", "my_task2", date, try_number) + state = State.SUCCESS + executor.event_buffer[key1] = state, None + executor.event_buffer[key2] = state, None + executor.event_buffer[key3] = state, None + + assert len(executor.get_event_buffer(("my_dag1",))) == 1 + assert len(executor.get_event_buffer()) == 2 + assert len(executor.event_buffer) == 0 + - date = datetime.utcnow() - try_number = 1 - key1 = TaskInstanceKey("my_dag1", "my_task1", date, try_number) - key2 = TaskInstanceKey("my_dag2", "my_task1", date, try_number) - key3 = TaskInstanceKey("my_dag2", "my_task2", date, try_number) - state = State.SUCCESS - executor.event_buffer[key1] = state, None - executor.event_buffer[key2] = state, None - executor.event_buffer[key3] = state, None +@mock.patch('airflow.executors.base_executor.BaseExecutor.sync') +@mock.patch('airflow.executors.base_executor.BaseExecutor.trigger_tasks') +@mock.patch('airflow.executors.base_executor.Stats.gauge') +def test_gauge_executor_metrics(mock_stats_gauge, mock_trigger_tasks, mock_sync): + executor = BaseExecutor() + executor.heartbeat() + calls = [ + mock.call('executor.open_slots', mock.ANY), + mock.call('executor.queued_tasks', mock.ANY), + mock.call('executor.running_tasks', mock.ANY), + ] + mock_stats_gauge.assert_has_calls(calls) - assert len(executor.get_event_buffer(("my_dag1",))) == 1 - assert len(executor.get_event_buffer()) == 2 - assert len(executor.event_buffer) == 0 - @mock.patch('airflow.executors.base_executor.BaseExecutor.sync') - @mock.patch('airflow.executors.base_executor.BaseExecutor.trigger_tasks') - 
@mock.patch('airflow.executors.base_executor.Stats.gauge') - def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync): - executor = BaseExecutor() - executor.heartbeat() - calls = [ - mock.call('executor.open_slots', mock.ANY), - mock.call('executor.queued_tasks', mock.ANY), - mock.call('executor.running_tasks', mock.ANY), - ] - mock_stats_gauge.assert_has_calls(calls) +def test_try_adopt_task_instances(dag_maker): + date = timezone.utcnow() + start_date = date - timedelta(days=2) - def test_try_adopt_task_instances(self): - date = datetime.utcnow() - start_date = datetime.utcnow() - timedelta(days=2) + with dag_maker("test_try_adopt_task_instances"): + BaseOperator(task_id="task_1", start_date=start_date) + BaseOperator(task_id="task_2", start_date=start_date) + BaseOperator(task_id="task_3", start_date=start_date) - with DAG("test_try_adopt_task_instances"): - task_1 = BaseOperator(task_id="task_1", start_date=start_date) - task_2 = BaseOperator(task_id="task_2", start_date=start_date) - task_3 = BaseOperator(task_id="task_3", start_date=start_date) + dagrun = dag_maker.create_dagrun(execution_date=date) + tis = dagrun.task_instances - key1 = TaskInstance(task=task_1, execution_date=date) - key2 = TaskInstance(task=task_2, execution_date=date) - key3 = TaskInstance(task=task_3, execution_date=date) - tis = [key1, key2, key3] - assert BaseExecutor().try_adopt_task_instances(tis) == tis + assert [ti.task_id for ti in tis] == ["task_1", "task_2", "task_3"] + assert BaseExecutor().try_adopt_task_instances(tis) == tis diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py index 23a338bb0e705..498c8ceca2194 100644 --- a/tests/executors/test_celery_executor.py +++ b/tests/executors/test_celery_executor.py @@ -184,7 +184,7 @@ def fake_execute_command(): 'command', 1, None, - SimpleTaskInstance(ti=TaskInstance(task=task, execution_date=datetime.now())), + SimpleTaskInstance(ti=TaskInstance(task=task, run_id=None)), ) key = ('fail', 'fake_simple_ti', when, 0) executor.queued_tasks[key] = value_tuple @@ -217,7 +217,7 @@ def test_retry_on_error_sending_task(self): 'command', 1, None, - SimpleTaskInstance(ti=TaskInstance(task=task, execution_date=datetime.now())), + SimpleTaskInstance(ti=TaskInstance(task=task, run_id=None)), ) key = ('fail', 'fake_simple_ti', when, 0) executor.queued_tasks[key] = value_tuple @@ -300,13 +300,12 @@ def test_command_validation(self, command, expected_exception): @pytest.mark.backend("mysql", "postgres") def test_try_adopt_task_instances_none(self): - date = datetime.utcnow() start_date = datetime.utcnow() - timedelta(days=2) with DAG("test_try_adopt_task_instances_none"): task_1 = BaseOperator(task_id="task_1", start_date=start_date) - key1 = TaskInstance(task=task_1, execution_date=date) + key1 = TaskInstance(task=task_1, run_id=None) tis = [key1] executor = celery_executor.CeleryExecutor() @@ -314,7 +313,6 @@ def test_try_adopt_task_instances_none(self): @pytest.mark.backend("mysql", "postgres") def test_try_adopt_task_instances(self): - exec_date = timezone.utcnow() - timedelta(minutes=2) start_date = timezone.utcnow() - timedelta(days=2) queued_dttm = timezone.utcnow() - timedelta(minutes=1) @@ -324,11 +322,11 @@ def test_try_adopt_task_instances(self): task_1 = BaseOperator(task_id="task_1", start_date=start_date) task_2 = BaseOperator(task_id="task_2", start_date=start_date) - ti1 = TaskInstance(task=task_1, execution_date=exec_date) + ti1 = TaskInstance(task=task_1, run_id=None) 
ti1.external_executor_id = '231' ti1.queued_dttm = queued_dttm ti1.state = State.QUEUED - ti2 = TaskInstance(task=task_2, execution_date=exec_date) + ti2 = TaskInstance(task=task_2, run_id=None) ti2.external_executor_id = '232' ti2.queued_dttm = queued_dttm ti2.state = State.QUEUED @@ -341,8 +339,8 @@ def test_try_adopt_task_instances(self): not_adopted_tis = executor.try_adopt_task_instances(tis) - key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, exec_date, try_number) - key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, exec_date, try_number) + key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, None, try_number) + key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, None, try_number) assert executor.running == {key_1, key_2} assert dict(executor.adopted_task_timeouts) == { key_1: queued_dttm + executor.task_adoption_timeout, @@ -353,7 +351,6 @@ def test_try_adopt_task_instances(self): @pytest.mark.backend("mysql", "postgres") def test_check_for_stalled_adopted_tasks(self): - exec_date = timezone.utcnow() - timedelta(minutes=40) start_date = timezone.utcnow() - timedelta(days=2) queued_dttm = timezone.utcnow() - timedelta(minutes=30) @@ -363,8 +360,8 @@ def test_check_for_stalled_adopted_tasks(self): task_1 = BaseOperator(task_id="task_1", start_date=start_date) task_2 = BaseOperator(task_id="task_2", start_date=start_date) - key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, exec_date, try_number) - key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, exec_date, try_number) + key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, "runid", try_number) + key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, "runid", try_number) executor = celery_executor.CeleryExecutor() executor.adopted_task_timeouts = { diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py index e48d6ce9e3a81..025b95611f5b0 100644 --- a/tests/executors/test_kubernetes_executor.py +++ b/tests/executors/test_kubernetes_executor.py @@ -42,7 +42,7 @@ ) from airflow.kubernetes import pod_generator from airflow.kubernetes.kubernetes_helper_functions import annotations_to_key - from airflow.kubernetes.pod_generator import PodGenerator, datetime_to_label_safe_datestring + from airflow.kubernetes.pod_generator import PodGenerator from airflow.utils.state import State except ImportError: AirflowKubernetesScheduler = None # type: ignore @@ -226,7 +226,7 @@ def test_run_next_exception(self, mock_get_kube_client, mock_kubernetes_job_watc # Execute a task while the Api Throws errors try_number = 1 kubernetes_executor.execute_async( - key=('dag', 'task', datetime.utcnow(), try_number), + key=('dag', 'task', 'run_id', try_number), queue=None, command=['airflow', 'tasks', 'run', 'true', 'some_parameter'], ) @@ -298,10 +298,8 @@ def test_pod_template_file_override_in_executor_config(self, mock_get_kube_clien assert executor.event_buffer == {} assert executor.task_queue.empty() - execution_date = datetime.utcnow() - executor.execute_async( - key=('dag', 'task', execution_date, 1), + key=('dag', 'task', 'run_id', 1), queue=None, command=['airflow', 'tasks', 'run', 'true', 'some_parameter'], executor_config={ @@ -333,7 +331,7 @@ def test_pod_template_file_override_in_executor_config(self, mock_get_kube_clien namespace="default", annotations={ 'dag_id': 'dag', - 'execution_date': execution_date.isoformat(), + 'run_id': 'run_id', 'task_id': 'task', 'try_number': '1', }, @@ -341,7 +339,7 @@ def test_pod_template_file_override_in_executor_config(self, mock_get_kube_clien 'airflow-worker': '5', 
'airflow_version': mock.ANY, 'dag_id': 'dag', - 'execution_date': datetime_to_label_safe_datestring(execution_date), + 'run_id': 'run_id', 'kubernetes_executor': 'True', 'mylabel': 'foo', 'release': 'stable', @@ -370,7 +368,7 @@ def test_pod_template_file_override_in_executor_config(self, mock_get_kube_clien def test_change_state_running(self, mock_get_kube_client, mock_kubernetes_job_watcher): executor = self.kubernetes_executor executor.start() - key = ('dag_id', 'task_id', 'ex_time', 'try_number1') + key = ('dag_id', 'task_id', 'run_id', 'try_number1') executor._change_state(key, State.RUNNING, 'pod_id', 'default') assert executor.event_buffer[key][0] == State.RUNNING @@ -380,8 +378,7 @@ def test_change_state_running(self, mock_get_kube_client, mock_kubernetes_job_wa def test_change_state_success(self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher): executor = self.kubernetes_executor executor.start() - test_time = timezone.utcnow() - key = ('dag_id', 'task_id', test_time, 'try_number2') + key = ('dag_id', 'task_id', 'run_id', 'try_number2') executor._change_state(key, State.SUCCESS, 'pod_id', 'default') assert executor.event_buffer[key][0] == State.SUCCESS mock_delete_pod.assert_called_once_with('pod_id', 'default') @@ -396,8 +393,7 @@ def test_change_state_failed_no_deletion( executor.kube_config.delete_worker_pods = False executor.kube_config.delete_worker_pods_on_failure = False executor.start() - test_time = timezone.utcnow() - key = ('dag_id', 'task_id', test_time, 'try_number3') + key = ('dag_id', 'task_id', 'run_id', 'try_number3') executor._change_state(key, State.FAILED, 'pod_id', 'default') assert executor.event_buffer[key][0] == State.FAILED mock_delete_pod.assert_not_called() @@ -408,13 +404,12 @@ def test_change_state_failed_no_deletion( def test_change_state_skip_pod_deletion( self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher ): - test_time = timezone.utcnow() executor = self.kubernetes_executor executor.kube_config.delete_worker_pods = False executor.kube_config.delete_worker_pods_on_failure = False executor.start() - key = ('dag_id', 'task_id', test_time, 'try_number2') + key = ('dag_id', 'task_id', 'run_id', 'try_number2') executor._change_state(key, State.SUCCESS, 'pod_id', 'default') assert executor.event_buffer[key][0] == State.SUCCESS mock_delete_pod.assert_not_called() @@ -429,7 +424,7 @@ def test_change_state_failed_pod_deletion( executor.kube_config.delete_worker_pods_on_failure = True executor.start() - key = ('dag_id', 'task_id', 'ex_time', 'try_number2') + key = ('dag_id', 'task_id', 'run_id', 'try_number2') executor._change_state(key, State.FAILED, 'pod_id', 'test-namespace') assert executor.event_buffer[key][0] == State.FAILED mock_delete_pod.assert_called_once_with('pod_id', 'test-namespace') @@ -442,7 +437,7 @@ def test_try_adopt_task_instances(self, mock_adopt_completed_pods, mock_adopt_la ti_key = annotations_to_key( { 'dag_id': 'dag', - 'execution_date': datetime.utcnow().isoformat(), + 'run_id': 'run_id', 'task_id': 'task', 'try_number': '1', } @@ -525,7 +520,7 @@ def test_adopt_launched_task(self, mock_kube_client): executor.scheduler_job_id = "modified" annotations = { 'dag_id': 'dag', - 'execution_date': datetime.utcnow().isoformat(), + 'run_id': 'run_id', 'task_id': 'task', 'try_number': '1', } @@ -566,7 +561,7 @@ def test_not_adopt_unassigned_task(self, mock_kube_client): labels={"airflow-worker": "bar"}, annotations={ 'dag_id': 'dag', - 'execution_date': datetime.utcnow().isoformat(), + 'run_id': 
'run_id', 'task_id': 'task', 'try_number': '1', }, @@ -680,8 +675,9 @@ def setUp(self): self.core_annotations = { "dag_id": "dag", "task_id": "task", - "execution_date": "dt", + "run_id": "run_id", "try_number": "1", + "execution_date": None, } self.pod = k8s.V1Pod( metadata=k8s.V1ObjectMeta( diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py index 31826a8205d5f..32ae7941d1a55 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -24,7 +24,6 @@ from unittest.mock import patch import pytest -import sqlalchemy from airflow import settings from airflow.cli import cli_parser @@ -191,7 +190,7 @@ def test_backfill_multi_dates(self): ("run_this_last", end_date), ] assert [ - ((dag.dag_id, task_id, when, 1), (State.SUCCESS, None)) + ((dag.dag_id, task_id, f'backfill__{when.isoformat()}', 1), (State.SUCCESS, None)) for (task_id, when) in expected_execution_order ] == executor.sorted_tasks @@ -268,7 +267,7 @@ def test_backfill_examples(self, dag_id, expected_execution_order): job.run() assert [ - ((dag_id, task_id, DEFAULT_DATE, 1), (State.SUCCESS, None)) + ((dag_id, task_id, f'backfill__{DEFAULT_DATE.isoformat()}', 1), (State.SUCCESS, None)) for task_id in expected_execution_order ] == executor.sorted_tasks @@ -707,13 +706,13 @@ def test_backfill_retry_always_failed_task(self, dag_maker): }, ) as dag: task1 = DummyOperator(task_id="task1") - dag_maker.create_dagrun() + dr = dag_maker.create_dagrun() executor = MockExecutor(parallelism=16) executor.mock_task_results[ - TaskInstanceKey(dag.dag_id, task1.task_id, DEFAULT_DATE, try_number=1) + TaskInstanceKey(dag.dag_id, task1.task_id, dr.run_id, try_number=1) ] = State.UP_FOR_RETRY - executor.mock_task_fail(dag.dag_id, task1.task_id, DEFAULT_DATE, try_number=2) + executor.mock_task_fail(dag.dag_id, task1.task_id, dr.run_id, try_number=2) job = BackfillJob( dag=dag, executor=executor, @@ -739,7 +738,8 @@ def test_backfill_ordered_concurrent_execute(self, dag_maker): op1.set_downstream(op3) op4.set_downstream(op5) op3.set_downstream(op4) - dag_maker.create_dagrun() + runid0 = f'backfill__{DEFAULT_DATE.isoformat()}' + dag_maker.create_dagrun(run_id=runid0) executor = MockExecutor(parallelism=16) job = BackfillJob( @@ -750,25 +750,24 @@ def test_backfill_ordered_concurrent_execute(self, dag_maker): ) job.run() - date0 = DEFAULT_DATE - date1 = date0 + datetime.timedelta(days=1) - date2 = date1 + datetime.timedelta(days=1) + runid1 = f'backfill__{(DEFAULT_DATE + datetime.timedelta(days=1)).isoformat()}' + runid2 = f'backfill__{(DEFAULT_DATE + datetime.timedelta(days=2)).isoformat()}' # test executor history keeps a list history = executor.history assert [sorted(item[-1].key[1:3] for item in batch) for batch in history] == [ [ - ('leave1', date0), - ('leave1', date1), - ('leave1', date2), - ('leave2', date0), - ('leave2', date1), - ('leave2', date2), + ('leave1', runid0), + ('leave1', runid1), + ('leave1', runid2), + ('leave2', runid0), + ('leave2', runid1), + ('leave2', runid2), ], - [('upstream_level_1', date0), ('upstream_level_1', date1), ('upstream_level_1', date2)], - [('upstream_level_2', date0), ('upstream_level_2', date1), ('upstream_level_2', date2)], - [('upstream_level_3', date0), ('upstream_level_3', date1), ('upstream_level_3', date2)], + [('upstream_level_1', runid0), ('upstream_level_1', runid1), ('upstream_level_1', runid2)], + [('upstream_level_2', runid0), ('upstream_level_2', runid1), ('upstream_level_2', runid2)], + [('upstream_level_3', runid0), ('upstream_level_3', runid1), 
('upstream_level_3', runid2)], ] def test_backfill_pooled_tasks(self): @@ -1045,13 +1044,6 @@ def test_sub_set_subdag(self, dag_maker): job = BackfillJob(dag=sub_dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, executor=executor) job.run() - with pytest.raises(sqlalchemy.orm.exc.NoResultFound): - dr.refresh_from_db() - # the run_id should have changed, so a refresh won't work - drs = DagRun.find(dag_id=dag.dag_id, execution_date=DEFAULT_DATE) - dr = drs[0] - - assert DagRun.generate_run_id(DagRunType.BACKFILL_JOB, DEFAULT_DATE) == dr.run_id for ti in dr.get_task_instances(): if ti.task_id == 'leave1' or ti.task_id == 'leave2': assert State.SUCCESS == ti.state @@ -1097,11 +1089,7 @@ def test_backfill_fill_blanks(self, dag_maker): with pytest.raises(AirflowException, match='Some task instances failed'): job.run() - with pytest.raises(sqlalchemy.orm.exc.NoResultFound): - dr.refresh_from_db() - # the run_id should have changed, so a refresh won't work - drs = DagRun.find(dag_id=dag.dag_id, execution_date=DEFAULT_DATE) - dr = drs[0] + dr.refresh_from_db() assert dr.state == State.FAILED @@ -1210,18 +1198,22 @@ def test_backfill_execute_subdag_with_removed_task(self): dag = self.dagbag.get_dag('example_subdag_operator') subdag = dag.get_task('section-1').subdag + session = settings.Session() executor = MockExecutor() job = BackfillJob( dag=subdag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, executor=executor, donot_pickle=True ) + dr = DagRun( + dag_id=subdag.dag_id, execution_date=DEFAULT_DATE, run_id="test", run_type=DagRunType.BACKFILL_JOB + ) + session.add(dr) removed_task_ti = TI( - task=DummyOperator(task_id='removed_task'), execution_date=DEFAULT_DATE, state=State.REMOVED + task=DummyOperator(task_id='removed_task'), run_id=dr.run_id, state=State.REMOVED ) removed_task_ti.dag_id = subdag.dag_id + dr.task_instances.append(removed_task_ti) - session = settings.Session() - session.merge(removed_task_ti) session.commit() with timeout(seconds=30): @@ -1378,8 +1370,9 @@ def test_backfill_run_backwards(self): session = settings.Session() tis = ( session.query(TI) + .join(TI.dag_run) .filter(TI.dag_id == 'test_start_date_scheduling' and TI.task_id == 'dummy') - .order_by(TI.execution_date) + .order_by(DagRun.execution_date) .all() ) @@ -1397,7 +1390,7 @@ def test_reset_orphaned_tasks_with_orphans(self, dag_maker): states_to_reset = [State.QUEUED, State.SCHEDULED, State.NONE] tasks = [] - with dag_maker(dag_id=prefix, start_date=DEFAULT_DATE, schedule_interval="@daily") as dag: + with dag_maker(dag_id=prefix) as dag: for i in range(len(states)): task_id = f"{prefix}_task_{i}" task = DummyOperator(task_id=task_id) @@ -1452,7 +1445,7 @@ def test_reset_orphaned_tasks_with_orphans(self, dag_maker): for state, ti in zip(states, dr2_tis): assert state == ti.state - def test_reset_orphaned_tasks_specified_dagrun(self, dag_maker): + def test_reset_orphaned_tasks_specified_dagrun(self, session, dag_maker): """Try to reset when we specify a dagrun and ensure nothing else is.""" dag_id = 'test_reset_orphaned_tasks_specified_dagrun' task_id = dag_id + '_task' @@ -1460,14 +1453,14 @@ def test_reset_orphaned_tasks_specified_dagrun(self, dag_maker): dag_id=dag_id, start_date=DEFAULT_DATE, schedule_interval='@daily', + session=session, ) as dag: DummyOperator(task_id=task_id, dag=dag) job = BackfillJob(dag=dag) - session = settings.Session() # make two dagruns, only reset for one dr1 = dag_maker.create_dagrun(state=State.SUCCESS) - dr2 = dag.create_dagrun(run_id='test2', state=State.RUNNING) + dr2 = 
dag.create_dagrun(run_id='test2', state=State.RUNNING, session=session) ti1 = dr1.get_task_instances(session=session)[0] ti2 = dr2.get_task_instances(session=session)[0] ti1.state = State.SCHEDULED @@ -1477,7 +1470,7 @@ def test_reset_orphaned_tasks_specified_dagrun(self, dag_maker): session.merge(ti2) session.merge(dr1) session.merge(dr2) - session.commit() + session.flush() num_reset_tis = job.reset_state_for_orphaned_tasks(filter_by_dag_run=dr2, session=session) assert 1 == num_reset_tis diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index 34b8526b75522..1574fd6e4a9b2 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -244,7 +244,7 @@ def test_heartbeat_failed_fast(self): dag = self.dagbag.get_dag(dag_id) task = dag.get_task(task_id) - dag.create_dagrun( + dr = dag.create_dagrun( run_id="test_heartbeat_failed_fast_run", state=State.RUNNING, execution_date=DEFAULT_DATE, @@ -252,9 +252,9 @@ def test_heartbeat_failed_fast(self): session=session, ) - ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) - ti.refresh_from_db() - ti.state = State.RUNNING + ti = dr.task_instances[0] + ti.refresh_from_task(task) + ti.state = State.QUEUED ti.hostname = get_hostname() ti.pid = 1 session.commit() @@ -291,11 +291,12 @@ def task_function(ti): time.sleep(10) with dag_maker('test_mark_success'): - task1 = PythonOperator(task_id="task1", python_callable=task_function) - dag_maker.create_dagrun() + task = PythonOperator(task_id="task1", python_callable=task_function) + dr = dag_maker.create_dagrun() + + ti = dr.task_instances[0] + ti.refresh_from_task(task) - ti = TaskInstance(task=task1, execution_date=DEFAULT_DATE) - ti.refresh_from_db() job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True) def dummy_return_code(*args, **kwargs): @@ -335,7 +336,7 @@ def test_localtaskjob_double_trigger(self): session.merge(ti) session.commit() - ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE) + ti_run = TaskInstance(task=task, run_id=dr.run_id) ti_run.refresh_from_db() job1 = LocalTaskJob(task_instance=ti_run, executor=SequentialExecutor()) with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_method: @@ -671,14 +672,14 @@ def test_fast_follow( dag_run = dag.create_dagrun(run_id='test_dagrun_fast_follow', state=State.RUNNING) - task_instance_a = TaskInstance(task_a, dag_run.execution_date, init_state['A']) + task_instance_a = TaskInstance(task_a, run_id=dag_run.run_id, state=init_state['A']) - task_instance_b = TaskInstance(task_b, dag_run.execution_date, init_state['B']) + task_instance_b = TaskInstance(task_b, run_id=dag_run.run_id, state=init_state['B']) - task_instance_c = TaskInstance(task_c, dag_run.execution_date, init_state['C']) + task_instance_c = TaskInstance(task_c, run_id=dag_run.run_id, state=init_state['C']) if 'D' in init_state: - task_instance_d = TaskInstance(task_d, dag_run.execution_date, init_state['D']) + task_instance_d = TaskInstance(task_d, run_id=dag_run.run_id, state=init_state['D']) session.merge(task_instance_d) session.merge(task_instance_a) @@ -731,8 +732,9 @@ def task_function(ti): retries=1, on_retry_callback=retry_callback, ) - ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) - ti.refresh_from_db() + dr = dag_maker.create_dagrun() + ti = dr.task_instances[0] + ti.refresh_from_task(task) job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) settings.engine.dispose() with timeout(10): @@ -814,20 +816,20 @@ def 
clean_db_helper(): @pytest.mark.usefixtures("clean_db_helper") -class TestLocalTaskJobPerformance: - @pytest.mark.parametrize("return_codes", [[0], 9 * [None] + [0]]) # type: ignore - @mock.patch("airflow.jobs.local_task_job.get_task_runner") - def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes, dag_maker): - unique_prefix = str(uuid.uuid4()) - with dag_maker(dag_id=f'{unique_prefix}_test_number_of_queries'): - task = DummyOperator(task_id='test_state_succeeded1') +@pytest.mark.parametrize("return_codes", [[0], 9 * [None] + [0]]) +@mock.patch("airflow.jobs.local_task_job.get_task_runner") +def test_number_of_queries_single_loop(mock_get_task_runner, return_codes, dag_maker): + mock_get_task_runner.return_value.return_code.side_effects = return_codes - dag_maker.create_dagrun(run_id=unique_prefix, state=State.NONE) + unique_prefix = str(uuid.uuid4()) + with dag_maker(dag_id=f'{unique_prefix}_test_number_of_queries'): + task = DummyOperator(task_id='test_state_succeeded1') - ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) + dr = dag_maker.create_dagrun(run_id=unique_prefix, state=State.NONE) - mock_get_task_runner.return_value.return_code.side_effects = return_codes + ti = dr.task_instances[0] + ti.refresh_from_task(task) - job = LocalTaskJob(task_instance=ti, executor=MockExecutor()) - with assert_queries_count(18): - job.run() + job = LocalTaskJob(task_instance=ti, executor=MockExecutor()) + with assert_queries_count(25): + job.run() diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 0be44789b914c..dfcb67ec441fa 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -184,15 +184,14 @@ def test_no_orphan_process_will_be_left(self): @mock.patch('airflow.jobs.scheduler_job.Stats.incr') def test_process_executor_events(self, mock_stats_incr, mock_task_callback, dag_maker): dag_id = "test_process_executor_events" - dag_id2 = "test_process_executor_events_2" task_id_1 = 'dummy_task' with dag_maker(dag_id=dag_id, fileloc='/test_path1/'): task1 = DummyOperator(task_id=task_id_1) - with dag_maker(dag_id=dag_id2, fileloc='/test_path1/'): - DummyOperator(task_id=task_id_1) + ti1 = dag_maker.create_dagrun().get_task_instance(task1.task_id) mock_stats_incr.reset_mock() + executor = MockExecutor(do_update=False) task_callback = mock.MagicMock() mock_task_callback.return_value = task_callback @@ -201,7 +200,6 @@ def test_process_executor_events(self, mock_stats_incr, mock_task_callback, dag_ session = settings.Session() - ti1 = TaskInstance(task1, DEFAULT_DATE) ti1.state = State.QUEUED session.merge(ti1) session.commit() @@ -215,7 +213,7 @@ def test_process_executor_events(self, mock_stats_incr, mock_task_callback, dag_ full_filepath='/test_path1/', simple_task_instance=mock.ANY, msg='Executor reports task instance ' - ' ' + ' ' 'finished (failed) although the task says its queued. 
(Info: None) ' 'Was the task killed externally?', ) @@ -235,22 +233,23 @@ def test_process_executor_events(self, mock_stats_incr, mock_task_callback, dag_ mock_stats_incr.assert_called_once_with('scheduler.tasks.killed_externally') def test_process_executor_events_uses_inmemory_try_number(self, dag_maker): - execution_date = DEFAULT_DATE dag_id = "dag_id" task_id = "task_id" try_number = 42 + with dag_maker(dag_id=dag_id): + DummyOperator(task_id=task_id) + + dr = dag_maker.create_dagrun() + executor = MagicMock() self.scheduler_job = SchedulerJob(executor=executor) self.scheduler_job.processor_agent = MagicMock() - event_buffer = {TaskInstanceKey(dag_id, task_id, execution_date, try_number): (State.SUCCESS, None)} + event_buffer = {TaskInstanceKey(dag_id, task_id, dr.run_id, try_number): (State.SUCCESS, None)} executor.get_event_buffer.return_value = event_buffer - with dag_maker(dag_id=dag_id): - task = DummyOperator(task_id=task_id) - with create_session() as session: - ti = TaskInstance(task, DEFAULT_DATE) + ti = dr.task_instances[0] ti.state = State.SUCCESS session.merge(ti) @@ -259,55 +258,25 @@ def test_process_executor_events_uses_inmemory_try_number(self, dag_maker): # task instance key assert event_buffer == {} - def test_execute_task_instances_is_paused_wont_execute(self, dag_maker): + def test_execute_task_instances_is_paused_wont_execute(self, session, dag_maker): dag_id = 'SchedulerJobTest.test_execute_task_instances_is_paused_wont_execute' task_id_1 = 'dummy_task' - with dag_maker(dag_id=dag_id) as dag: - task1 = DummyOperator(task_id=task_id_1) + with dag_maker(dag_id=dag_id, session=session) as dag: + DummyOperator(task_id=task_id_1) assert isinstance(dag, SerializedDAG) self.scheduler_job = SchedulerJob(subdir=os.devnull) - session = settings.Session() dr1 = dag_maker.create_dagrun(run_type=DagRunType.BACKFILL_JOB) - ti1 = TaskInstance(task1, DEFAULT_DATE) + (ti1,) = dr1.task_instances ti1.state = State.SCHEDULED - session.merge(ti1) - session.merge(dr1) - session.flush() self.scheduler_job._critical_section_execute_task_instances(session) - session.flush() - ti1.refresh_from_db() + ti1.refresh_from_db(session=session) assert State.SCHEDULED == ti1.state session.rollback() - def test_execute_task_instances_no_dagrun_task_will_execute(self, dag_maker): - """ - Tests that tasks without dagrun still get executed. - """ - dag_id = 'SchedulerJobTest.test_execute_task_instances_no_dagrun_task_will_execute' - task_id_1 = 'dummy_task' - - with dag_maker(dag_id=dag_id): - task1 = DummyOperator(task_id=task_id_1) - - self.scheduler_job = SchedulerJob(subdir=os.devnull) - session = settings.Session() - - ti1 = TaskInstance(task1, DEFAULT_DATE) - ti1.state = State.SCHEDULED - ti1.execution_date = ti1.execution_date + datetime.timedelta(days=1) - session.merge(ti1) - session.flush() - - self.scheduler_job._critical_section_execute_task_instances(session) - session.flush() - ti1.refresh_from_db() - assert State.QUEUED == ti1.state - session.rollback() - def test_execute_task_instances_backfill_tasks_wont_execute(self, dag_maker): """ Tests that backfill tasks won't get executed. 
@@ -323,11 +292,10 @@ def test_execute_task_instances_backfill_tasks_wont_execute(self, dag_maker): dr1 = dag_maker.create_dagrun(run_type=DagRunType.BACKFILL_JOB) - ti1 = TaskInstance(task1, dr1.execution_date) + ti1 = TaskInstance(task1, run_id=dr1.run_id) ti1.refresh_from_db() ti1.state = State.SCHEDULED session.merge(ti1) - session.merge(dr1) session.flush() assert dr1.is_backfill @@ -338,8 +306,8 @@ def test_execute_task_instances_backfill_tasks_wont_execute(self, dag_maker): assert State.SCHEDULED == ti1.state session.rollback() - def test_find_executable_task_instances_backfill_nodagrun(self, dag_maker): - dag_id = 'SchedulerJobTest.test_find_executable_task_instances_backfill_nodagrun' + def test_find_executable_task_instances_backfill(self, dag_maker): + dag_id = 'SchedulerJobTest.test_find_executable_task_instances_backfill' task_id_1 = 'dummy' with dag_maker(dag_id=dag_id, max_active_tasks=16) as dag: task1 = DummyOperator(task_id=task_id_1) @@ -354,24 +322,20 @@ def test_find_executable_task_instances_backfill_nodagrun(self, dag_maker): state=State.RUNNING, ) - ti_no_dagrun = TaskInstance(task1, DEFAULT_DATE - datetime.timedelta(days=1)) - ti_backfill = TaskInstance(task1, dr2.execution_date) - ti_with_dagrun = TaskInstance(task1, dr1.execution_date) + ti_backfill = dr2.get_task_instance(task1.task_id) + ti_with_dagrun = dr1.get_task_instance(task1.task_id) # ti_with_paused - ti_no_dagrun.state = State.SCHEDULED ti_backfill.state = State.SCHEDULED ti_with_dagrun.state = State.SCHEDULED session.merge(dr2) - session.merge(ti_no_dagrun) session.merge(ti_backfill) session.merge(ti_with_dagrun) session.flush() res = self.scheduler_job._executable_task_instances_to_queued(max_tis=32, session=session) - assert 2 == len(res) + assert 1 == len(res) res_keys = map(lambda x: x.key, res) - assert ti_no_dagrun.key in res_keys assert ti_with_dagrun.key in res_keys session.rollback() @@ -379,26 +343,21 @@ def test_find_executable_task_instances_pool(self, dag_maker): dag_id = 'SchedulerJobTest.test_find_executable_task_instances_pool' task_id_1 = 'dummy' task_id_2 = 'dummydummy' - with dag_maker(dag_id=dag_id, max_active_tasks=16) as dag: - task1 = DummyOperator(task_id=task_id_1, pool='a') - task2 = DummyOperator(task_id=task_id_2, pool='b') + session = settings.Session() + with dag_maker(dag_id=dag_id, max_active_tasks=16, session=session) as dag: + DummyOperator(task_id=task_id_1, pool='a') + DummyOperator(task_id=task_id_2, pool='b') self.scheduler_job = SchedulerJob(subdir=os.devnull) - session = settings.Session() dr1 = dag_maker.create_dagrun() - dr2 = dag.create_dagrun( + dr2 = dag_maker.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=dag.following_schedule(dr1.execution_date), state=State.RUNNING, ) - tis = [ - TaskInstance(task1, dr1.execution_date), - TaskInstance(task2, dr1.execution_date), - TaskInstance(task1, dr2.execution_date), - TaskInstance(task2, dr2.execution_date), - ] + tis = dr1.task_instances + dr2.task_instances for ti in tis: ti.state = State.SCHEDULED session.merge(ti) @@ -428,21 +387,20 @@ def test_find_executable_task_instances_order_execution_date(self, dag_maker): dag_id_1 = 'SchedulerJobTest.test_find_executable_task_instances_order_execution_date-a' dag_id_2 = 'SchedulerJobTest.test_find_executable_task_instances_order_execution_date-b' task_id = 'task-a' - with dag_maker(dag_id=dag_id_1, max_active_tasks=16): - dag1_task = DummyOperator(task_id=task_id) + session = settings.Session() + with dag_maker(dag_id=dag_id_1, max_active_tasks=16, 
session=session): + DummyOperator(task_id=task_id) dr1 = dag_maker.create_dagrun(execution_date=DEFAULT_DATE + timedelta(hours=1)) - with dag_maker(dag_id=dag_id_2, max_active_tasks=16): - dag2_task = DummyOperator(task_id=task_id) + with dag_maker(dag_id=dag_id_2, max_active_tasks=16, session=session): + DummyOperator(task_id=task_id) dr2 = dag_maker.create_dagrun() + dr1 = session.merge(dr1, load=False) + self.scheduler_job = SchedulerJob(subdir=os.devnull) - session = settings.Session() - tis = [ - TaskInstance(dag1_task, dr1.execution_date), - TaskInstance(dag2_task, dr2.execution_date), - ] + tis = dr1.task_instances + dr2.task_instances for ti in tis: ti.state = State.SCHEDULED session.merge(ti) @@ -457,21 +415,20 @@ def test_find_executable_task_instances_order_priority(self, dag_maker): dag_id_1 = 'SchedulerJobTest.test_find_executable_task_instances_order_priority-a' dag_id_2 = 'SchedulerJobTest.test_find_executable_task_instances_order_priority-b' task_id = 'task-a' - with dag_maker(dag_id=dag_id_1, max_active_tasks=16): - dag1_task = DummyOperator(task_id=task_id, priority_weight=1) + session = settings.Session() + with dag_maker(dag_id=dag_id_1, max_active_tasks=16, session=session): + DummyOperator(task_id=task_id, priority_weight=1) dr1 = dag_maker.create_dagrun() - with dag_maker(dag_id=dag_id_2, max_active_tasks=16): - dag2_task = DummyOperator(task_id=task_id, priority_weight=4) + with dag_maker(dag_id=dag_id_2, max_active_tasks=16, session=session): + DummyOperator(task_id=task_id, priority_weight=4) dr2 = dag_maker.create_dagrun() + dr1 = session.merge(dr1, load=False) + self.scheduler_job = SchedulerJob(subdir=os.devnull) - session = settings.Session() - tis = [ - TaskInstance(dag1_task, dr1.execution_date), - TaskInstance(dag2_task, dr2.execution_date), - ] + tis = dr1.task_instances + dr2.task_instances for ti in tis: ti.state = State.SCHEDULED session.merge(ti) @@ -486,21 +443,19 @@ def test_find_executable_task_instances_order_execution_date_and_priority(self, dag_id_1 = 'SchedulerJobTest.test_find_executable_task_instances_order_execution_date_and_priority-a' dag_id_2 = 'SchedulerJobTest.test_find_executable_task_instances_order_execution_date_and_priority-b' task_id = 'task-a' - with dag_maker(dag_id=dag_id_1, max_active_tasks=16): - dag1_task = DummyOperator(task_id=task_id, priority_weight=1) + session = settings.Session() + with dag_maker(dag_id=dag_id_1, max_active_tasks=16, session=session): + DummyOperator(task_id=task_id, priority_weight=1) dr1 = dag_maker.create_dagrun() - with dag_maker(dag_id=dag_id_2, max_active_tasks=16): - dag2_task = DummyOperator(task_id=task_id, priority_weight=4) + with dag_maker(dag_id=dag_id_2, max_active_tasks=16, session=session): + DummyOperator(task_id=task_id, priority_weight=4) dr2 = dag_maker.create_dagrun(execution_date=DEFAULT_DATE + timedelta(hours=1)) + dr1 = session.merge(dr1, load=False) self.scheduler_job = SchedulerJob(subdir=os.devnull) - session = settings.Session() - tis = [ - TaskInstance(dag1_task, dr1.execution_date), - TaskInstance(dag2_task, dr2.execution_date), - ] + tis = dr1.task_instances + dr2.task_instances for ti in tis: ti.state = State.SCHEDULED session.merge(ti) @@ -530,13 +485,11 @@ def test_find_executable_task_instances_in_default_pool(self, dag_maker): state=State.RUNNING, ) - ti1 = TaskInstance(task=op1, execution_date=dr1.execution_date) - ti2 = TaskInstance(task=op2, execution_date=dr2.execution_date) + ti1 = dr1.get_task_instance(op1.task_id, session) + ti2 = 
dr2.get_task_instance(op2.task_id, session) ti1.state = State.SCHEDULED ti2.state = State.SCHEDULED - session.merge(ti1) - session.merge(ti2) session.flush() # Two tasks w/o pool up for execution and our default pool size is 1 @@ -544,7 +497,6 @@ def test_find_executable_task_instances_in_default_pool(self, dag_maker): assert 1 == len(res) ti2.state = State.RUNNING - session.merge(ti2) session.flush() # One task w/o pool up for execution and one task running @@ -556,16 +508,15 @@ def test_find_executable_task_instances_in_default_pool(self, dag_maker): def test_nonexistent_pool(self, dag_maker): dag_id = 'SchedulerJobTest.test_nonexistent_pool' - task_id = 'dummy_wrong_pool' with dag_maker(dag_id=dag_id, max_active_tasks=16): - task = DummyOperator(task_id=task_id, pool="this_pool_doesnt_exist") + DummyOperator(task_id="dummy_wrong_pool", pool="this_pool_doesnt_exist") self.scheduler_job = SchedulerJob(subdir=os.devnull) session = settings.Session() dr = dag_maker.create_dagrun() - ti = TaskInstance(task, dr.execution_date) + ti = dr.task_instances[0] ti.state = State.SCHEDULED session.merge(ti) session.commit() @@ -577,15 +528,14 @@ def test_nonexistent_pool(self, dag_maker): def test_infinite_pool(self, dag_maker): dag_id = 'SchedulerJobTest.test_infinite_pool' - task_id = 'dummy' with dag_maker(dag_id=dag_id, concurrency=16): - task = DummyOperator(task_id=task_id, pool="infinite_pool") + DummyOperator(task_id="dummy", pool="infinite_pool") self.scheduler_job = SchedulerJob(subdir=os.devnull) session = settings.Session() dr = dag_maker.create_dagrun() - ti = TaskInstance(task, dr.execution_date) + ti = dr.task_instances[0] ti.state = State.SCHEDULED session.merge(ti) infinite_pool = Pool(pool='infinite_pool', slots=-1, description='infinite pool') @@ -642,28 +592,27 @@ def test_tis_for_queued_dagruns_are_not_run(self, dag_maker): def test_find_executable_task_instances_concurrency(self, dag_maker): dag_id = 'SchedulerJobTest.test_find_executable_task_instances_concurrency' - task_id_1 = 'dummy' - with dag_maker(dag_id=dag_id, max_active_tasks=2) as dag: - task1 = DummyOperator(task_id=task_id_1) + session = settings.Session() + with dag_maker(dag_id=dag_id, max_active_tasks=2, session=session) as dag: + DummyOperator(task_id='dummy') self.scheduler_job = SchedulerJob(subdir=os.devnull) - session = settings.Session() dr1 = dag_maker.create_dagrun() - dr2 = dag.create_dagrun( + dr2 = dag_maker.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=dag.following_schedule(dr1.execution_date), state=State.RUNNING, ) - dr3 = dag.create_dagrun( + dr3 = dag_maker.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=dag.following_schedule(dr2.execution_date), state=State.RUNNING, ) - ti1 = TaskInstance(task1, dr1.execution_date) - ti2 = TaskInstance(task1, dr2.execution_date) - ti3 = TaskInstance(task1, dr3.execution_date) + ti1 = dr1.task_instances[0] + ti2 = dr2.task_instances[0] + ti3 = dr3.task_instances[0] ti1.state = State.RUNNING ti2.state = State.SCHEDULED ti3.state = State.SCHEDULED @@ -700,9 +649,9 @@ def test_find_executable_task_instances_concurrency_queued(self, dag_maker): dag_run = dag_maker.create_dagrun() - ti1 = TaskInstance(task1, dag_run.execution_date) - ti2 = TaskInstance(task2, dag_run.execution_date) - ti3 = TaskInstance(task3, dag_run.execution_date) + ti1 = dag_run.get_task_instance(task1.task_id) + ti2 = dag_run.get_task_instance(task2.task_id) + ti3 = dag_run.get_task_instance(task3.task_id) ti1.state = State.RUNNING ti2.state = State.QUEUED ti3.state = 
State.SCHEDULED @@ -744,8 +693,8 @@ def test_find_executable_task_instances_max_active_tis_per_dag(self, dag_maker): state=State.RUNNING, ) - ti1_1 = TaskInstance(task1, dr1.execution_date) - ti2 = TaskInstance(task2, dr1.execution_date) + ti1_1 = dr1.get_task_instance(task1.task_id) + ti2 = dr1.get_task_instance(task2.task_id) ti1_1.state = State.SCHEDULED ti2.state = State.SCHEDULED @@ -759,7 +708,7 @@ def test_find_executable_task_instances_max_active_tis_per_dag(self, dag_maker): ti1_1.state = State.RUNNING ti2.state = State.RUNNING - ti1_2 = TaskInstance(task1, dr2.execution_date) + ti1_2 = dr2.get_task_instance(task1.task_id) ti1_2.state = State.SCHEDULED session.merge(ti1_1) session.merge(ti2) @@ -771,7 +720,7 @@ def test_find_executable_task_instances_max_active_tis_per_dag(self, dag_maker): assert 1 == len(res) ti1_2.state = State.RUNNING - ti1_3 = TaskInstance(task1, dr3.execution_date) + ti1_3 = dr3.get_task_instance(task1.task_id) ti1_3.state = State.SCHEDULED session.merge(ti1_2) session.merge(ti1_3) @@ -830,9 +779,9 @@ def test_change_state_for_executable_task_instances_no_tis_with_state(self, dag_ state=State.RUNNING, ) - ti1 = TaskInstance(task1, dr1.execution_date) - ti2 = TaskInstance(task1, dr2.execution_date) - ti3 = TaskInstance(task1, dr3.execution_date) + ti1 = dr1.get_task_instance(task1.task_id) + ti2 = dr2.get_task_instance(task1.task_id) + ti3 = dr3.get_task_instance(task1.task_id) ti1.state = State.RUNNING ti2.state = State.RUNNING ti3.state = State.RUNNING @@ -850,18 +799,14 @@ def test_change_state_for_executable_task_instances_no_tis_with_state(self, dag_ def test_enqueue_task_instances_with_queued_state(self, dag_maker): dag_id = 'SchedulerJobTest.test_enqueue_task_instances_with_queued_state' task_id_1 = 'dummy' - with dag_maker(dag_id=dag_id, start_date=DEFAULT_DATE): + session = settings.Session() + with dag_maker(dag_id=dag_id, start_date=DEFAULT_DATE, session=session): task1 = DummyOperator(task_id=task_id_1) self.scheduler_job = SchedulerJob(subdir=os.devnull) - session = settings.Session() - - dag_model = dag_maker.dag_model dr1 = dag_maker.create_dagrun() - ti1 = TaskInstance(task1, dr1.execution_date) - ti1.dag_model = dag_model - session.merge(ti1) + ti1 = dr1.get_task_instance(task1.task_id, session) with patch.object(BaseExecutor, 'queue_command') as mock_queue_command: self.scheduler_job._enqueue_task_instances_with_queued_state([ti1]) @@ -873,49 +818,41 @@ def test_critical_section_execute_task_instances(self, dag_maker): dag_id = 'SchedulerJobTest.test_execute_task_instances' task_id_1 = 'dummy_task' task_id_2 = 'dummy_task_nonexistent_queue' + session = settings.Session() # important that len(tasks) is less than max_active_tasks # because before scheduler._execute_task_instances would only # check the num tasks once so if max_active_tasks was 3, # we could execute arbitrarily many tasks in the second run - with dag_maker(dag_id=dag_id, max_active_tasks=3) as dag: + with dag_maker(dag_id=dag_id, max_active_tasks=3, session=session) as dag: task1 = DummyOperator(task_id=task_id_1) task2 = DummyOperator(task_id=task_id_2) self.scheduler_job = SchedulerJob(subdir=os.devnull) - session = settings.Session() # create first dag run with 1 running and 1 queued dr1 = dag_maker.create_dagrun() - ti1 = TaskInstance(task1, dr1.execution_date) - ti2 = TaskInstance(task2, dr1.execution_date) - ti1.refresh_from_db() - ti2.refresh_from_db() + ti1 = dr1.get_task_instance(task1.task_id, session) + ti2 = dr1.get_task_instance(task2.task_id, session) ti1.state = 
State.RUNNING ti2.state = State.RUNNING - session.merge(ti1) - session.merge(ti2) session.flush() assert State.RUNNING == dr1.state assert 2 == DAG.get_num_task_instances(dag_id, dag.task_ids, states=[State.RUNNING], session=session) # create second dag run - dr2 = dag.create_dagrun( + dr2 = dag_maker.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=dag.following_schedule(dr1.execution_date), state=State.RUNNING, ) - ti3 = TaskInstance(task1, dr2.execution_date) - ti4 = TaskInstance(task2, dr2.execution_date) - ti3.refresh_from_db() - ti4.refresh_from_db() + ti3 = dr2.get_task_instance(task1.task_id, session) + ti4 = dr2.get_task_instance(task2.task_id, session) # manually set to scheduled so we can pick them up ti3.state = State.SCHEDULED ti4.state = State.SCHEDULED - session.merge(ti3) - session.merge(ti4) session.flush() assert State.RUNNING == dr2.state @@ -939,36 +876,30 @@ def test_execute_task_instances_limit(self, dag_maker): dag_id = 'SchedulerJobTest.test_execute_task_instances_limit' task_id_1 = 'dummy_task' task_id_2 = 'dummy_task_2' + session = settings.Session() # important that len(tasks) is less than max_active_tasks # because before scheduler._execute_task_instances would only # check the num tasks once so if max_active_tasks was 3, # we could execute arbitrarily many tasks in the second run - with dag_maker(dag_id=dag_id, max_active_tasks=16) as dag: + with dag_maker(dag_id=dag_id, max_active_tasks=16, session=session) as dag: task1 = DummyOperator(task_id=task_id_1) task2 = DummyOperator(task_id=task_id_2) self.scheduler_job = SchedulerJob(subdir=os.devnull) - session = settings.Session() date = dag.start_date tis = [] for _ in range(0, 4): date = dag.following_schedule(date) - dr = dag.create_dagrun( + dr = dag_maker.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=date, state=State.RUNNING, ) - ti1 = TaskInstance(task1, dr.execution_date) - ti2 = TaskInstance(task2, dr.execution_date) - tis.append(ti1) - tis.append(ti2) - ti1.refresh_from_db() - ti2.refresh_from_db() + ti1 = dr.get_task_instance(task1.task_id, session) + ti2 = dr.get_task_instance(task2.task_id, session) ti1.state = State.SCHEDULED ti2.state = State.SCHEDULED - session.merge(ti1) - session.merge(ti2) session.flush() self.scheduler_job.max_tis_per_query = 2 res = self.scheduler_job._critical_section_execute_task_instances(session) @@ -995,33 +926,27 @@ def test_execute_task_instances_unlimited(self, dag_maker): dag_id = 'SchedulerJobTest.test_execute_task_instances_unlimited' task_id_1 = 'dummy_task' task_id_2 = 'dummy_task_2' + session = settings.Session() - with dag_maker(dag_id=dag_id, max_active_tasks=1024) as dag: + with dag_maker(dag_id=dag_id, max_active_tasks=1024, session=session) as dag: task1 = DummyOperator(task_id=task_id_1) task2 = DummyOperator(task_id=task_id_2) self.scheduler_job = SchedulerJob(subdir=os.devnull) - session = settings.Session() date = dag.start_date - tis = [] for _ in range(0, 20): date = dag.following_schedule(date) - dr = dag.create_dagrun( + dr = dag_maker.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=date, state=State.RUNNING, ) - ti1 = TaskInstance(task1, dr.execution_date) - ti2 = TaskInstance(task2, dr.execution_date) - tis.append(ti1) - tis.append(ti2) - ti1.refresh_from_db() - ti2.refresh_from_db() + date = dag.following_schedule(date) + ti1 = dr.get_task_instance(task1.task_id, session) + ti2 = dr.get_task_instance(task2.task_id, session) ti1.state = State.SCHEDULED ti2.state = State.SCHEDULED - session.merge(ti1) - 
session.merge(ti2) session.flush() self.scheduler_job.max_tis_per_query = 0 self.scheduler_job.executor = MagicMock(slots_available=36) @@ -1031,86 +956,6 @@ def test_execute_task_instances_unlimited(self, dag_maker): assert res == 36 session.rollback() - def test_change_state_for_tis_without_dagrun(self, dag_maker): - with dag_maker(dag_id='test_change_state_for_tis_without_dagrun'): - DummyOperator(task_id='dummy') - DummyOperator(task_id='dummy_b') - dr1 = dag_maker.create_dagrun() - - with dag_maker(dag_id='test_change_state_for_tis_without_dagrun_dont_change'): - DummyOperator(task_id='dummy') - dr2 = dag_maker.create_dagrun() - - # Using dag_maker for below dag will create a dagrun and we don't want a dagrun - with dag_maker(dag_id='test_change_state_for_tis_without_dagrun_no_dagrun') as dag3: - DummyOperator(task_id='dummy') - - session = settings.Session() - - ti1a = dr1.get_task_instance(task_id='dummy', session=session) - ti1a.state = State.SCHEDULED - ti1b = dr1.get_task_instance(task_id='dummy_b', session=session) - ti1b.state = State.SUCCESS - session.commit() - - ti2 = dr2.get_task_instance(task_id='dummy', session=session) - ti2.state = State.SCHEDULED - session.commit() - - ti3 = TaskInstance(dag3.get_task('dummy'), DEFAULT_DATE) - ti3.state = State.SCHEDULED - session.merge(ti3) - session.commit() - - self.scheduler_job = SchedulerJob(num_runs=0) - self.scheduler_job.dagbag.collect_dags_from_db() - - self.scheduler_job._change_state_for_tis_without_dagrun( - old_states=[State.SCHEDULED, State.QUEUED], new_state=State.NONE, session=session - ) - - ti1a = dr1.get_task_instance(task_id='dummy', session=session) - ti1a.refresh_from_db(session=session) - assert ti1a.state == State.SCHEDULED - - ti1b = dr1.get_task_instance(task_id='dummy_b', session=session) - ti1b.refresh_from_db(session=session) - assert ti1b.state == State.SUCCESS - - ti2 = dr2.get_task_instance(task_id='dummy', session=session) - ti2.refresh_from_db(session=session) - assert ti2.state == State.SCHEDULED - - ti3.refresh_from_db(session=session) - assert ti3.state == State.NONE - assert ti3.start_date is not None - assert ti3.end_date is None - assert ti3.duration is None - - dr1.refresh_from_db(session=session) - dr1.state = State.FAILED - - # Push the changes to DB - session.merge(dr1) - session.commit() - - self.scheduler_job._change_state_for_tis_without_dagrun( - old_states=[State.SCHEDULED, State.QUEUED], new_state=State.NONE, session=session - ) - - # Clear the session objects - session.expunge_all() - ti1a.refresh_from_db(session=session) - assert ti1a.state == State.NONE - - # don't touch ti1b - ti1b.refresh_from_db(session=session) - assert ti1b.state == State.SUCCESS - - # don't touch ti2 - ti2.refresh_from_db(session=session) - assert ti2.state == State.SCHEDULED - def test_adopt_or_reset_orphaned_tasks(self, dag_maker): session = settings.Session() with dag_maker('test_execute_helper_reset_orphaned_tasks') as dag: @@ -1143,60 +988,6 @@ def test_adopt_or_reset_orphaned_tasks(self, dag_maker): ti2 = dr2.get_task_instance(task_id=op1.task_id, session=session) assert ti2.state == State.QUEUED, "Tasks run by Backfill Jobs should not be reset" - @pytest.mark.parametrize( - "initial_task_state, expected_task_state", - [ - [State.UP_FOR_RETRY, State.FAILED], - [State.QUEUED, State.NONE], - [State.SCHEDULED, State.NONE], - [State.UP_FOR_RESCHEDULE, State.NONE], - ], - ) - def test_scheduler_loop_should_change_state_for_tis_without_dagrun( - self, initial_task_state, expected_task_state, dag_maker - ): - 
session = settings.Session() - dag_id = 'test_execute_helper_should_change_state_for_tis_without_dagrun' - with dag_maker( - dag_id, - start_date=DEFAULT_DATE + timedelta(days=1), - ): - op1 = DummyOperator(task_id='op1') - - # Create DAG run with FAILED state - dr = dag_maker.create_dagrun( - state=State.FAILED, - execution_date=DEFAULT_DATE + timedelta(days=1), - start_date=DEFAULT_DATE + timedelta(days=1), - ) - ti = dr.get_task_instance(task_id=op1.task_id, session=session) - ti.state = initial_task_state - session.commit() - - # This poll interval is large, bug the scheduler doesn't sleep that - # long, instead we hit the clean_tis_without_dagrun interval instead - self.scheduler_job = SchedulerJob(num_runs=2, processor_poll_interval=30) - self.scheduler_job.dagbag = dag_maker.dagbag - executor = MockExecutor(do_update=False) - executor.queued_tasks - self.scheduler_job.executor = executor - processor = mock.MagicMock() - processor.done = False - self.scheduler_job.processor_agent = processor - - with mock.patch.object(settings, "USE_JOB_SCHEDULE", False), conf_vars( - {('scheduler', 'clean_tis_without_dagrun_interval'): '0.001'} - ): - self.scheduler_job._run_scheduler_loop() - - ti = dr.get_task_instance(task_id=op1.task_id, session=session) - assert ti.state == expected_task_state - assert ti.start_date is not None - if expected_task_state in State.finished: - assert ti.end_date is not None - assert ti.start_date == ti.end_date - assert ti.duration is not None - @mock.patch('airflow.jobs.scheduler_job.DagFileProcessorAgent') def test_executor_end_called(self, mock_processor_agent): """ @@ -1289,7 +1080,7 @@ def test_dagrun_timeout_verify_max_active_runs(self, dag_maker): full_filepath=dr.dag.fileloc, dag_id=dr.dag_id, is_failure_callback=True, - execution_date=dr.execution_date, + run_id=dr.run_id, msg="timed_out", ) @@ -1330,7 +1121,7 @@ def test_dagrun_timeout_fails_run(self, dag_maker): full_filepath=dr.dag.fileloc, dag_id=dr.dag_id, is_failure_callback=True, - execution_date=dr.execution_date, + run_id=dr.run_id, msg="timed_out", ) @@ -1374,7 +1165,7 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg, dag_mak full_filepath=dag.fileloc, dag_id=dr.dag_id, is_failure_callback=bool(state == State.FAILED), - execution_date=dr.execution_date, + run_id=dr.run_id, msg=expected_callback_msg, ) @@ -1531,7 +1322,7 @@ def evaluate_dagrun( for tid, state in expected_task_states.items(): if state != State.FAILED: continue - self.null_exec.mock_task_fail(dag_id, tid, ex_date) + self.null_exec.mock_task_fail(dag_id, tid, dr.run_id) try: dag = DagBag().get_dag(dag.dag_id) @@ -1541,13 +1332,6 @@ def evaluate_dagrun( except AirflowException: pass - # test tasks - for task_id, expected_state in expected_task_states.items(): - task = dag.get_task(task_id) - ti = TaskInstance(task, ex_date) - ti.refresh_from_db() - assert ti.state == expected_state - # load dagrun dr = DagRun.find(dag_id=dag_id, execution_date=ex_date) dr = dr[0] @@ -1555,6 +1339,11 @@ def evaluate_dagrun( assert dr.state == dagrun_state + # test tasks + for task_id, expected_state in expected_task_states.items(): + ti = dr.get_task_instance(task_id) + assert ti.state == expected_state + def test_dagrun_fail(self): """ DagRuns with one failed and one incomplete root task -> FAILED @@ -1607,7 +1396,7 @@ def test_dagrun_root_fail_unfinished(self): execution_date=DEFAULT_DATE, state=State.RUNNING, ) - self.null_exec.mock_task_fail(dag_id, 'test_dagrun_fail', DEFAULT_DATE) + self.null_exec.mock_task_fail(dag_id, 
'test_dagrun_fail', dr.run_id) with pytest.raises(AirflowException): dag.run(start_date=dr.execution_date, end_date=dr.execution_date, executor=self.null_exec) @@ -1710,7 +1499,10 @@ def test_scheduler_start_date(self): # one task ran assert len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()) == 1 assert [ - (TaskInstanceKey(dag.dag_id, 'dummy', DEFAULT_DATE, 1), (State.SUCCESS, None)), + ( + TaskInstanceKey(dag.dag_id, 'dummy', f'backfill__{DEFAULT_DATE.isoformat()}', 1), + (State.SUCCESS, None), + ), ] == bf_exec.sorted_tasks session.commit() @@ -1835,13 +1627,14 @@ def test_scheduler_verify_pool_full(self, dag_maker): assert len(task_instances_list) == 1 - def test_scheduler_verify_pool_full_2_slots_per_task(self, dag_maker): + @pytest.mark.need_serialized_dag + def test_scheduler_verify_pool_full_2_slots_per_task(self, dag_maker, session): """ Test task instances not queued when pool is full. Variation with non-default pool_slots """ - with dag_maker(dag_id='test_scheduler_verify_pool_full_2_slots_per_task') as dag: + with dag_maker(dag_id='test_scheduler_verify_pool_full_2_slots_per_task', session=session) as dag: BashOperator( task_id='dummy', pool='test_scheduler_verify_pool_full_2_slots_per_task', @@ -1849,7 +1642,6 @@ def test_scheduler_verify_pool_full_2_slots_per_task(self, dag_maker): bash_command='echo hi', ) - session = settings.Session() pool = Pool(pool='test_scheduler_verify_pool_full_2_slots_per_task', slots=6) session.add(pool) session.flush() @@ -1865,6 +1657,7 @@ def test_scheduler_verify_pool_full_2_slots_per_task(self, dag_maker): run_type=DagRunType.SCHEDULED, execution_date=date, state=State.RUNNING, + session=session, ) self.scheduler_job._schedule_dag_run(dr, session) @@ -1979,7 +1772,10 @@ def test_scheduler_verify_priority_and_slots(self, dag_maker): self.scheduler_job.processor_agent = mock.MagicMock() dr = dag_maker.create_dagrun() - self.scheduler_job._schedule_dag_run(dr, session) + for ti in dr.task_instances: + ti.state = State.SCHEDULED + session.merge(ti) + session.flush() task_instances_list = self.scheduler_job._executable_task_instances_to_queued( max_tis=32, session=session @@ -2358,25 +2154,6 @@ def test_adopt_or_reset_orphaned_tasks_backfill_dag(self, dag_maker): assert 0 == self.scheduler_job.adopt_or_reset_orphaned_tasks(session=session) session.rollback() - def test_reset_orphaned_tasks_nonexistent_dagrun(self, dag_maker): - """Make sure a task in an orphaned state is not reset if it has no dagrun.""" - dag_id = 'test_reset_orphaned_tasks_nonexistent_dagrun' - with dag_maker(dag_id=dag_id, schedule_interval='@daily'): - task_id = dag_id + '_task' - task = DummyOperator(task_id=task_id) - - self.scheduler_job = SchedulerJob(subdir=os.devnull) - session = settings.Session() - - ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) - ti.refresh_from_db() - ti.state = State.SCHEDULED - session.merge(ti) - session.flush() - - assert 0 == self.scheduler_job.adopt_or_reset_orphaned_tasks(session=session) - session.rollback() - def test_reset_orphaned_tasks_no_orphans(self, dag_maker): dag_id = 'test_reset_orphaned_tasks_no_orphans' with dag_maker(dag_id=dag_id, schedule_interval='@daily'): @@ -2770,58 +2547,52 @@ def test_do_schedule_max_active_runs_dag_timed_out(self, dag_maker): session=session, ) - dag.sync_to_db(session=session) self.scheduler_job = SchedulerJob(subdir=os.devnull) self.scheduler_job.executor = MockExecutor() self.scheduler_job.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) 
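The changed lines just below replace `session.add(run1)` with `run1 = session.merge(run1)` before `session.refresh(run1)`; as a comment in a later hunk of this file notes, `_do_scheduling` expunges everything from the session, so the test's handles are detached afterwards. The snippet below is a minimal, self-contained SQLAlchemy sketch of that behaviour, not Airflow code: the `Run` model and the in-memory SQLite engine are assumptions for illustration, and the imports target the SQLAlchemy 1.3/1.4 series in use at the time.

```python
# Illustrative only: why a detached object is merged back before refresh().
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

Base = declarative_base()


class Run(Base):
    """Stand-in model; any mapped class behaves the same way."""

    __tablename__ = "run"
    id = Column(Integer, primary_key=True)
    state = Column(String, default="running")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

run = Run(state="running")
session.add(run)
session.commit()

session.expunge_all()     # detaches everything, like the scheduler's cleanup
run = session.merge(run)  # merge() returns a copy attached to this session
session.refresh(run)      # safe now; refreshing the detached original raises
print(run.state)
```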
self.scheduler_job._do_scheduling(session) - session.add(run1) + run1 = session.merge(run1) session.refresh(run1) assert run1.state == State.FAILED assert run1_ti.state == State.SKIPPED # Run scheduling again to assert run2 has started self.scheduler_job._do_scheduling(session) - session.add(run2) + run2 = session.merge(run2) session.refresh(run2) assert run2.state == State.RUNNING run2_ti = run2.get_task_instance(task1.task_id, session) assert run2_ti.state == State.QUEUED - def test_do_schedule_max_active_runs_task_removed(self, dag_maker): + def test_do_schedule_max_active_runs_task_removed(self, session, dag_maker): """Test that tasks in removed state don't count as actively running.""" - with dag_maker( dag_id='test_do_schedule_max_active_runs_task_removed', start_date=DEFAULT_DATE, schedule_interval='@once', max_active_runs=1, - ) as dag: + session=session, + ): # Can't use DummyOperator as that goes straight to success - task1 = BashOperator(task_id='dummy1', bash_command='true') - - session = settings.Session() - session.add(TaskInstance(task1, DEFAULT_DATE, State.REMOVED)) - session.flush() + BashOperator(task_id='dummy1', bash_command='true') - run1 = dag.create_dagrun( + run1 = dag_maker.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE + timedelta(hours=1), state=State.RUNNING, - session=session, ) - dag.sync_to_db(session=session) # Update the date fields - self.scheduler_job = SchedulerJob(subdir=os.devnull) self.scheduler_job.executor = MockExecutor(do_update=False) self.scheduler_job.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) num_queued = self.scheduler_job._do_scheduling(session) - assert num_queued == 1 - ti = run1.get_task_instance(task1.task_id, session) + + session.flush() + ti = run1.task_instances[0] + ti.refresh_from_db(session=session) assert ti.state == State.QUEUED def test_do_schedule_max_active_runs_and_manual_trigger(self, dag_maker): @@ -2858,7 +2629,7 @@ def test_do_schedule_max_active_runs_and_manual_trigger(self, dag_maker): num_queued = self.scheduler_job._do_scheduling(session) # Add it back in to the session so we can refresh it. (_do_scheduling does an expunge_all to reduce # memory) - session.add(dag_run) + dag_run = session.merge(dag_run) session.refresh(dag_run) assert num_queued == 2 diff --git a/tests/jobs/test_triggerer_job.py b/tests/jobs/test_triggerer_job.py index 6718450015570..5adc91f527c97 100644 --- a/tests/jobs/test_triggerer_job.py +++ b/tests/jobs/test_triggerer_job.py @@ -22,10 +22,8 @@ import pytest -from airflow import DAG from airflow.jobs.triggerer_job import TriggererJob from airflow.models import Trigger -from airflow.models.taskinstance import TaskInstance from airflow.operators.dummy import DummyOperator from airflow.triggers.base import TriggerEvent from airflow.triggers.temporal import TimeDeltaTrigger @@ -293,7 +291,7 @@ def test_trigger_cleanup(session): @pytest.mark.skipif(sys.version_info.minor <= 6 and sys.version_info.major <= 3, reason="No triggerer on 3.6") -def test_invalid_trigger(session): +def test_invalid_trigger(session, dag_maker): """ Checks that the triggerer will correctly fail task instances that depend on triggers that can't even be loaded. 
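Across these test modules the refactoring pattern is the same: rather than instantiating `TaskInstance(task, execution_date)` directly, a test builds the DAG through the `dag_maker` fixture, creates a `DagRun`, and works with the task instances that run already owns. The sketch below condenses that pattern; it assumes Airflow's internal pytest fixtures (`dag_maker`, `session`) and mirrors calls that appear in the hunks above, so treat it as an illustration rather than an additional test in this patch.

```python
# Sketch of the dag_maker / create_dagrun pattern (assumes Airflow's test fixtures).
import os

from airflow.jobs.scheduler_job import SchedulerJob
from airflow.operators.dummy import DummyOperator
from airflow.utils.state import State


def test_pattern_sketch(dag_maker, session):
    with dag_maker(dag_id='pattern_sketch_dag', session=session):
        DummyOperator(task_id='dummy')

    # The DagRun owns its TaskInstances, so the test pulls the TI from the run
    # instead of constructing one from an execution_date.
    dr = dag_maker.create_dagrun()
    ti = dr.get_task_instance('dummy', session)

    ti.state = State.SCHEDULED
    session.merge(ti)
    session.flush()

    scheduler_job = SchedulerJob(subdir=os.devnull)
    res = scheduler_job._executable_task_instances_to_queued(max_tis=32, session=session)
    assert len(res) == 1
    assert res[0].key == ti.key
```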
@@ -305,22 +303,14 @@ def test_invalid_trigger(session): session.commit() # Create the test DAG and task - with DAG( - dag_id='test_invalid_trigger', - start_date=timezone.datetime(2016, 1, 1), - schedule_interval='@once', - max_active_runs=1, - ): - task1 = DummyOperator(task_id='dummy1') + with dag_maker(dag_id='test_invalid_trigger', session=session): + DummyOperator(task_id='dummy1') + dr = dag_maker.create_dagrun() + task_instance = dr.task_instances[0] # Make a task instance based on that and tie it to the trigger - task_instance = TaskInstance( - task1, - execution_date=timezone.datetime(2016, 1, 1), - state=TaskInstanceState.DEFERRED, - ) + task_instance.state = TaskInstanceState.DEFERRED task_instance.trigger_id = 1 - session.add(task_instance) session.commit() # Make a TriggererJob and have it retrieve DB tasks diff --git a/tests/lineage/test_lineage.py b/tests/lineage/test_lineage.py index b5ebbea1efa0e..171fad1ce0932 100644 --- a/tests/lineage/test_lineage.py +++ b/tests/lineage/test_lineage.py @@ -15,15 +15,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import unittest from unittest import mock from airflow.lineage import AUTO, apply_lineage, get_backend, prepare_lineage from airflow.lineage.backend import LineageBackend from airflow.lineage.entities import File -from airflow.models import DAG, TaskInstance as TI +from airflow.models import TaskInstance as TI from airflow.operators.dummy import DummyOperator from airflow.utils import timezone +from airflow.utils.types import DagRunType from tests.test_utils.config import conf_vars DEFAULT_DATE = timezone.datetime(2016, 1, 1) @@ -34,10 +34,8 @@ def send_lineage(self, operator, inlets=None, outlets=None, context=None): pass -class TestLineage(unittest.TestCase): - def test_lineage(self): - dag = DAG(dag_id='test_prepare_lineage', start_date=DEFAULT_DATE) - +class TestLineage: + def test_lineage(self, dag_maker): f1s = "/tmp/does_not_exist_1-{}" f2s = "/tmp/does_not_exist_2-{}" f3s = "/tmp/does_not_exist_3" @@ -45,7 +43,7 @@ def test_lineage(self): file2 = File(f2s.format("{{ execution_date }}")) file3 = File(f3s) - with dag: + with dag_maker(dag_id='test_prepare_lineage', start_date=DEFAULT_DATE) as dag: op1 = DummyOperator( task_id='leave1', inlets=file1, @@ -64,12 +62,13 @@ def test_lineage(self): op4.set_downstream(op5) dag.clear() + dag_run = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) # execution_date is set in the context in order to avoid creating task instances - ctx1 = {"ti": TI(task=op1, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE} - ctx2 = {"ti": TI(task=op2, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE} - ctx3 = {"ti": TI(task=op3, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE} - ctx5 = {"ti": TI(task=op5, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE} + ctx1 = {"ti": TI(task=op1, run_id=dag_run.run_id), "execution_date": DEFAULT_DATE} + ctx2 = {"ti": TI(task=op2, run_id=dag_run.run_id), "execution_date": DEFAULT_DATE} + ctx3 = {"ti": TI(task=op3, run_id=dag_run.run_id), "execution_date": DEFAULT_DATE} + ctx5 = {"ti": TI(task=op5, run_id=dag_run.run_id), "execution_date": DEFAULT_DATE} # prepare with manual inlets and outlets op1.pre_execute(ctx1) @@ -99,13 +98,12 @@ def test_lineage(self): assert len(op5.inlets) == 2 op5.post_execute(ctx5) - def test_lineage_render(self): + def test_lineage_render(self, dag_maker): # tests inlets / outlets are 
rendered if they are added # after initialization - dag = DAG(dag_id='test_lineage_render', start_date=DEFAULT_DATE) - - with dag: + with dag_maker(dag_id='test_lineage_render', start_date=DEFAULT_DATE): op1 = DummyOperator(task_id='task1') + dag_run = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) f1s = "/tmp/does_not_exist_1-{}" file1 = File(f1s.format("{{ execution_date }}")) @@ -114,14 +112,14 @@ def test_lineage_render(self): op1.outlets.append(file1) # execution_date is set in the context in order to avoid creating task instances - ctx1 = {"ti": TI(task=op1, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE} + ctx1 = {"ti": TI(task=op1, run_id=dag_run.run_id), "execution_date": DEFAULT_DATE} op1.pre_execute(ctx1) assert op1.inlets[0].url == f1s.format(DEFAULT_DATE) assert op1.outlets[0].url == f1s.format(DEFAULT_DATE) @mock.patch("airflow.lineage.get_backend") - def test_lineage_is_sent_to_backend(self, mock_get_backend): + def test_lineage_is_sent_to_backend(self, mock_get_backend, dag_maker): class TestBackend(LineageBackend): def send_lineage(self, operator, inlets=None, outlets=None, context=None): assert len(inlets) == 1 @@ -132,17 +130,17 @@ def send_lineage(self, operator, inlets=None, outlets=None, context=None): mock_get_backend.return_value = TestBackend() - dag = DAG(dag_id='test_lineage_is_sent_to_backend', start_date=DEFAULT_DATE) - - with dag: + with dag_maker(dag_id='test_lineage_is_sent_to_backend', start_date=DEFAULT_DATE): op1 = DummyOperator(task_id='task1') + dag_run = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) file1 = File("/tmp/some_file") op1.inlets.append(file1) op1.outlets.append(file1) - ctx1 = {"ti": TI(task=op1, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE} + (ti,) = dag_run.task_instances + ctx1 = {"ti": ti, "execution_date": DEFAULT_DATE} prep = prepare_lineage(func) prep(op1, ctx1) diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index bf3ea4a7f4481..a9662a94bfdd7 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -565,6 +565,16 @@ def hook(context, result): op_copy.post_execute({}) assert called + def test_task_naive_datetime(self): + naive_datetime = DEFAULT_DATE.replace(tzinfo=None) + + op_no_dag = DummyOperator( + task_id='test_task_naive_datetime', start_date=naive_datetime, end_date=naive_datetime + ) + + assert op_no_dag.start_date.tzinfo + assert op_no_dag.end_date.tzinfo + class CustomOp(DummyOperator): template_fields = ("field", "field2") diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py index 4f64347f8d604..e883bccac3e5c 100644 --- a/tests/models/test_cleartasks.py +++ b/tests/models/test_cleartasks.py @@ -17,44 +17,45 @@ # under the License. 
import datetime -import unittest -from parameterized import parameterized +import pytest from airflow import settings from airflow.models import DAG, TaskInstance as TI, TaskReschedule, clear_task_instances from airflow.operators.dummy import DummyOperator from airflow.sensors.python import PythonSensor from airflow.utils.session import create_session -from airflow.utils.state import State +from airflow.utils.state import State, TaskInstanceState from airflow.utils.types import DagRunType from tests.models import DEFAULT_DATE from tests.test_utils import db -class TestClearTasks(unittest.TestCase): - def setUp(self) -> None: +class TestClearTasks: + @pytest.fixture(autouse=True, scope="class") + def clean(self): db.clear_db_runs() - def tearDown(self): + yield + db.clear_db_runs() - def test_clear_task_instances(self): - dag = DAG( + def test_clear_task_instances(self, dag_maker): + with dag_maker( 'test_clear_task_instances', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10), - ) - task0 = DummyOperator(task_id='0', owner='test', dag=dag) - task1 = DummyOperator(task_id='1', owner='test', dag=dag, retries=2) - ti0 = TI(task=task0, execution_date=DEFAULT_DATE) - ti1 = TI(task=task1, execution_date=DEFAULT_DATE) + ) as dag: + task0 = DummyOperator(task_id='0') + task1 = DummyOperator(task_id='1', retries=2) - dag.create_dagrun( - execution_date=ti0.execution_date, + dr = dag_maker.create_dagrun( state=State.RUNNING, run_type=DagRunType.SCHEDULED, ) + ti0, ti1 = dr.task_instances + ti0.refresh_from_task(task0) + ti1.refresh_from_task(task1) ti0.run() ti1.run() @@ -66,19 +67,22 @@ def test_clear_task_instances(self): ti0.refresh_from_db() ti1.refresh_from_db() # Next try to run will be try 2 + assert ti0.state is None assert ti0.try_number == 2 assert ti0.max_tries == 1 + assert ti1.state is None assert ti1.try_number == 2 assert ti1.max_tries == 3 - def test_clear_task_instances_external_executor_id(self): - dag = DAG( + def test_clear_task_instances_external_executor_id(self, dag_maker): + with dag_maker( 'test_clear_task_instances_external_executor_id', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10), - ) - task0 = DummyOperator(task_id='task0', owner='test', dag=dag) - ti0 = TI(task=task0, execution_date=DEFAULT_DATE) + ) as dag: + DummyOperator(task_id='task0') + + ti0 = dag_maker.create_dagrun().task_instances[0] ti0.state = State.SUCCESS ti0.external_executor_id = "some_external_executor_id" @@ -94,58 +98,60 @@ def test_clear_task_instances_external_executor_id(self): assert ti0.state is None assert ti0.external_executor_id is None - @parameterized.expand([(State.QUEUED, None), (State.RUNNING, DEFAULT_DATE)]) - def test_clear_task_instances_dr_state(self, state, last_scheduling): + @pytest.mark.parametrize( + ["state", "last_scheduling"], [(State.QUEUED, None), (State.RUNNING, DEFAULT_DATE)] + ) + def test_clear_task_instances_dr_state(self, state, last_scheduling, dag_maker): """Test that DR state is set to None after clear. And that DR.last_scheduling_decision is handled OK. 
start_date is also set to None """ - dag = DAG( + with dag_maker( 'test_clear_task_instances', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10), - ) - task0 = DummyOperator(task_id='0', owner='test', dag=dag) - task1 = DummyOperator(task_id='1', owner='test', dag=dag, retries=2) - ti0 = TI(task=task0, execution_date=DEFAULT_DATE) - ti1 = TI(task=task1, execution_date=DEFAULT_DATE) - session = settings.Session() - dr = dag.create_dagrun( - execution_date=ti0.execution_date, + ) as dag: + DummyOperator(task_id='0') + DummyOperator(task_id='1', retries=2) + dr = dag_maker.create_dagrun( state=State.RUNNING, run_type=DagRunType.SCHEDULED, ) + ti0, ti1 = dr.task_instances dr.last_scheduling_decision = DEFAULT_DATE - session.add(dr) - session.commit() + ti0.state = TaskInstanceState.SUCCESS + ti1.state = TaskInstanceState.SUCCESS + session = dag_maker.session + session.flush() - ti0.run() - ti1.run() qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all() clear_task_instances(qry, session, dag_run_state=state, dag=dag) + session.flush() + + session.refresh(dr) - dr = ti0.get_dagrun() assert dr.state == state assert dr.start_date is None assert dr.last_scheduling_decision == last_scheduling - def test_clear_task_instances_without_task(self): - dag = DAG( + def test_clear_task_instances_without_task(self, dag_maker): + with dag_maker( 'test_clear_task_instances_without_task', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10), - ) - task0 = DummyOperator(task_id='task0', owner='test', dag=dag) - task1 = DummyOperator(task_id='task1', owner='test', dag=dag, retries=2) - ti0 = TI(task=task0, execution_date=DEFAULT_DATE) - ti1 = TI(task=task1, execution_date=DEFAULT_DATE) + ) as dag: + task0 = DummyOperator(task_id='task0') + task1 = DummyOperator(task_id='task1', retries=2) - dag.create_dagrun( - execution_date=ti0.execution_date, + dr = dag_maker.create_dagrun( state=State.RUNNING, run_type=DagRunType.SCHEDULED, ) + ti0, ti1 = dr.task_instances + ti0.refresh_from_task(task0) + ti1.refresh_from_task(task1) + ti0.run() ti1.run() @@ -167,23 +173,24 @@ def test_clear_task_instances_without_task(self): assert ti1.try_number == 2 assert ti1.max_tries == 2 - def test_clear_task_instances_without_dag(self): - dag = DAG( + def test_clear_task_instances_without_dag(self, dag_maker): + with dag_maker( 'test_clear_task_instances_without_dag', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10), - ) - task0 = DummyOperator(task_id='task_0', owner='test', dag=dag) - task1 = DummyOperator(task_id='task_1', owner='test', dag=dag, retries=2) - ti0 = TI(task=task0, execution_date=DEFAULT_DATE) - ti1 = TI(task=task1, execution_date=DEFAULT_DATE) + ) as dag: + task0 = DummyOperator(task_id='task0') + task1 = DummyOperator(task_id='task1', retries=2) - dag.create_dagrun( - execution_date=ti0.execution_date, + dr = dag_maker.create_dagrun( state=State.RUNNING, run_type=DagRunType.SCHEDULED, ) + ti0, ti1 = dr.task_instances + ti0.refresh_from_task(task0) + ti1.refresh_from_task(task1) + ti0.run() ti1.run() @@ -200,10 +207,10 @@ def test_clear_task_instances_without_dag(self): assert ti1.try_number == 2 assert ti1.max_tries == 2 - def test_clear_task_instances_with_task_reschedule(self): + def test_clear_task_instances_with_task_reschedule(self, dag_maker): """Test that TaskReschedules are deleted correctly when TaskInstances are cleared""" - with DAG( + with dag_maker( 'test_clear_task_instances_with_task_reschedule', 
start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10), @@ -211,15 +218,14 @@ def test_clear_task_instances_with_task_reschedule(self): task0 = PythonSensor(task_id='0', python_callable=lambda: False, mode="reschedule") task1 = PythonSensor(task_id='1', python_callable=lambda: False, mode="reschedule") - ti0 = TI(task=task0, execution_date=DEFAULT_DATE) - ti1 = TI(task=task1, execution_date=DEFAULT_DATE) - - dag.create_dagrun( - execution_date=ti0.execution_date, + dr = dag_maker.create_dagrun( state=State.RUNNING, run_type=DagRunType.SCHEDULED, ) + ti0, ti1 = dr.task_instances + ti0.refresh_from_task(task0) + ti1.refresh_from_task(task1) ti0.run() ti1.run() @@ -231,7 +237,7 @@ def count_task_reschedule(task_id): .filter( TaskReschedule.dag_id == dag.dag_id, TaskReschedule.task_id == task_id, - TaskReschedule.execution_date == DEFAULT_DATE, + TaskReschedule.run_id == dr.run_id, TaskReschedule.try_number == 1, ) .count() @@ -244,22 +250,27 @@ def count_task_reschedule(task_id): assert count_task_reschedule(ti0.task_id) == 0 assert count_task_reschedule(ti1.task_id) == 1 - def test_dag_clear(self): - dag = DAG( + def test_dag_clear(self, dag_maker): + with dag_maker( 'test_dag_clear', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10) - ) - task0 = DummyOperator(task_id='test_dag_clear_task_0', owner='test', dag=dag) - ti0 = TI(task=task0, execution_date=DEFAULT_DATE) + ) as dag: + task0 = DummyOperator(task_id='test_dag_clear_task_0') + task1 = DummyOperator(task_id='test_dag_clear_task_1', retries=2) - dag.create_dagrun( - execution_date=ti0.execution_date, + dr = dag_maker.create_dagrun( state=State.RUNNING, run_type=DagRunType.SCHEDULED, ) + session = dag_maker.session + + ti0, ti1 = dr.task_instances + ti0.refresh_from_task(task0) + ti1.refresh_from_task(task1) # Next try to run will be try 1 assert ti0.try_number == 1 ti0.run() + assert ti0.try_number == 2 dag.clear() ti0.refresh_from_db() @@ -267,12 +278,14 @@ def test_dag_clear(self): assert ti0.state == State.NONE assert ti0.max_tries == 1 - task1 = DummyOperator(task_id='test_dag_clear_task_1', owner='test', dag=dag, retries=2) - ti1 = TI(task=task1, execution_date=DEFAULT_DATE) assert ti1.max_tries == 2 ti1.try_number = 1 + session.merge(ti1) + session.commit() + # Next try will be 2 ti1.run() + assert ti1.try_number == 3 assert ti1.max_tries == 2 @@ -297,16 +310,16 @@ def test_dags_clear(self): start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10), ) - ti = TI( - task=DummyOperator(task_id='test_task_clear_' + str(i), owner='test', dag=dag), - execution_date=DEFAULT_DATE, - ) + task = DummyOperator(task_id='test_task_clear_' + str(i), owner='test', dag=dag) - dag.create_dagrun( - execution_date=ti.execution_date, + dr = dag.create_dagrun( + execution_date=DEFAULT_DATE, state=State.RUNNING, run_type=DagRunType.SCHEDULED, + session=session, ) + ti = dr.task_instances[0] + ti.task = task dags.append(dag) tis.append(ti) @@ -361,26 +374,25 @@ def test_dags_clear(self): assert tis[i].try_number == 3 assert tis[i].max_tries == 2 - def test_operator_clear(self): - dag = DAG( + def test_operator_clear(self, dag_maker): + with dag_maker( 'test_operator_clear', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10), - ) - op1 = DummyOperator(task_id='bash_op', owner='test', dag=dag) - op2 = DummyOperator(task_id='dummy_op', owner='test', dag=dag, retries=1) - - op2.set_upstream(op1) + ): + op1 = DummyOperator(task_id='bash_op') + op2 = 
DummyOperator(task_id='dummy_op', retries=1) + op1 >> op2 - ti1 = TI(task=op1, execution_date=DEFAULT_DATE) - ti2 = TI(task=op2, execution_date=DEFAULT_DATE) - - dag.create_dagrun( - execution_date=ti1.execution_date, + dr = dag_maker.create_dagrun( state=State.RUNNING, run_type=DagRunType.SCHEDULED, ) + ti1, ti2 = dr.task_instances + ti1.task = op1 + ti2.task = op2 + ti2.run() # Dependency not met assert ti2.try_number == 1 diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 0ee0baaf24f2b..c6d54ed198942 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -36,6 +36,7 @@ from dateutil.relativedelta import relativedelta from freezegun import freeze_time from parameterized import parameterized +from sqlalchemy import inspect from airflow import models, settings from airflow.configuration import conf @@ -66,6 +67,13 @@ TEST_DATE = datetime_tz(2015, 1, 2, 0, 0) +@pytest.fixture +def session(): + with create_session() as session: + yield session + session.rollback() + + class TestDag(unittest.TestCase): def setUp(self) -> None: clear_db_runs() @@ -448,13 +456,24 @@ def test_get_num_task_instances(self): test_dag = DAG(dag_id=test_dag_id, start_date=DEFAULT_DATE) test_task = DummyOperator(task_id=test_task_id, dag=test_dag) - ti1 = TI(task=test_task, execution_date=DEFAULT_DATE) + dr1 = test_dag.create_dagrun(state=None, run_id="test1", execution_date=DEFAULT_DATE) + dr2 = test_dag.create_dagrun( + state=None, run_id="test2", execution_date=DEFAULT_DATE + datetime.timedelta(days=1) + ) + dr3 = test_dag.create_dagrun( + state=None, run_id="test3", execution_date=DEFAULT_DATE + datetime.timedelta(days=2) + ) + dr4 = test_dag.create_dagrun( + state=None, run_id="test4", execution_date=DEFAULT_DATE + datetime.timedelta(days=3) + ) + + ti1 = TI(task=test_task, run_id=dr1.run_id) ti1.state = None - ti2 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1)) + ti2 = TI(task=test_task, run_id=dr2.run_id) ti2.state = State.RUNNING - ti3 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=2)) + ti3 = TI(task=test_task, run_id=dr3.run_id) ti3.state = State.QUEUED - ti4 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=3)) + ti4 = TI(task=test_task, run_id=dr4.run_id) ti4.state = State.RUNNING session = settings.Session() session.merge(ti1) @@ -1094,10 +1113,12 @@ def test_dag_handle_callback_crash(self, mock_stats): when = TEST_DATE dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=when)) - dag_run = dag.create_dagrun(State.RUNNING, when, run_type=DagRunType.MANUAL) - # should not raise any exception - dag.handle_callback(dag_run, success=False) - dag.handle_callback(dag_run, success=True) + with create_session() as session: + dag_run = dag.create_dagrun(State.RUNNING, when, run_type=DagRunType.MANUAL, session=session) + + # should not raise any exception + dag.handle_callback(dag_run, success=False) + dag.handle_callback(dag_run, success=True) mock_stats.incr.assert_called_with("dag.callback_exceptions") @@ -1970,11 +1991,11 @@ def return_num(num): assert dag.params['value'] == self.VALUE -def test_set_task_instance_state(): +def test_set_task_instance_state(session, dag_maker): """Test that set_task_instance_state updates the TaskInstance state and clear downstream failed""" start_date = datetime_tz(2020, 1, 1) - with DAG("test_set_task_instance_state", start_date=start_date) as dag: + with dag_maker("test_set_task_instance_state", start_date=start_date, 
session=session) as dag: task_1 = DummyOperator(task_id="task_1") task_2 = DummyOperator(task_id="task_2") task_3 = DummyOperator(task_id="task_3") @@ -1983,49 +2004,48 @@ def test_set_task_instance_state(): task_1 >> [task_2, task_3, task_4, task_5] - dagrun = dag.create_dagrun( - start_date=start_date, execution_date=start_date, state=State.FAILED, run_type=DagRunType.SCHEDULED - ) + dagrun = dag_maker.create_dagrun(state=State.FAILED, run_type=DagRunType.SCHEDULED) - def get_task_instance(session, task): + def get_ti_from_db(task): return ( session.query(TI) .filter( TI.dag_id == dag.dag_id, TI.task_id == task.task_id, - TI.execution_date == start_date, + TI.run_id == dagrun.run_id, ) .one() ) - with create_session() as session: - get_task_instance(session, task_1).state = State.FAILED - get_task_instance(session, task_2).state = State.SUCCESS - get_task_instance(session, task_3).state = State.UPSTREAM_FAILED - get_task_instance(session, task_4).state = State.FAILED - get_task_instance(session, task_5).state = State.SKIPPED + get_ti_from_db(task_1).state = State.FAILED + get_ti_from_db(task_2).state = State.SUCCESS + get_ti_from_db(task_3).state = State.UPSTREAM_FAILED + get_ti_from_db(task_4).state = State.FAILED + get_ti_from_db(task_5).state = State.SKIPPED - session.commit() + session.flush() altered = dag.set_task_instance_state( - task_id=task_1.task_id, execution_date=start_date, state=State.SUCCESS + task_id=task_1.task_id, execution_date=start_date, state=State.SUCCESS, session=session ) - with create_session() as session: - # After _mark_task_instance_state, task_1 is marked as SUCCESS - assert get_task_instance(session, task_1).state == State.SUCCESS - # task_2 remains as SUCCESS - assert get_task_instance(session, task_2).state == State.SUCCESS - # task_3 and task_4 are cleared because they were in FAILED/UPSTREAM_FAILED state - assert get_task_instance(session, task_3).state == State.NONE - assert get_task_instance(session, task_4).state == State.NONE - # task_5 remains as SKIPPED - assert get_task_instance(session, task_5).state == State.SKIPPED - dagrun.refresh_from_db(session=session) - # dagrun should be set to QUEUED - assert dagrun.get_state() == State.QUEUED - - assert {t.key for t in altered} == {('test_set_task_instance_state', 'task_1', start_date, 1)} + # After _mark_task_instance_state, task_1 is marked as SUCCESS + ti1 = get_ti_from_db(task_1) + assert ti1.state == State.SUCCESS + # TIs should have DagRun pre-loaded + assert isinstance(inspect(ti1).attrs.dag_run.loaded_value, DagRun) + # task_2 remains as SUCCESS + assert get_ti_from_db(task_2).state == State.SUCCESS + # task_3 and task_4 are cleared because they were in FAILED/UPSTREAM_FAILED state + assert get_ti_from_db(task_3).state == State.NONE + assert get_ti_from_db(task_4).state == State.NONE + # task_5 remains as SKIPPED + assert get_ti_from_db(task_5).state == State.SKIPPED + dagrun.refresh_from_db(session=session) + # dagrun should be set to QUEUED + assert dagrun.get_state() == State.QUEUED + + assert {t.key for t in altered} == {('test_set_task_instance_state', 'task_1', dagrun.run_id, 1)} @pytest.mark.parametrize( diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index fb6a099532cc1..c9b9bcae35955 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -24,10 +24,8 @@ from parameterized import parameterized from airflow import models, settings -from airflow.jobs.base_job import BaseJob from airflow.models import DAG, DagBag, DagModel, TaskInstance as TI, 
clear_task_instances from airflow.models.dagrun import DagRun -from airflow.operators.bash import BashOperator from airflow.operators.dummy import DummyOperator from airflow.operators.python import ShortCircuitOperator from airflow.serialization.serialized_objects import SerializedDAG @@ -39,7 +37,7 @@ from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import DagRunType from tests.models import DEFAULT_DATE -from tests.test_utils.db import clear_db_dags, clear_db_jobs, clear_db_pools, clear_db_runs +from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs class TestDagRun(unittest.TestCase): @@ -87,8 +85,7 @@ def create_dag_run( for task_id, task_state in task_states.items(): ti = dag_run.get_task_instance(task_id) ti.set_state(task_state, session) - session.commit() - session.close() + session.flush() return dag_run @@ -377,7 +374,7 @@ def on_success_callable(context): assert callback == DagCallbackRequest( full_filepath=dag_run.dag.fileloc, dag_id="test_dagrun_update_state_with_handle_callback_success", - execution_date=dag_run.execution_date, + run_id=dag_run.run_id, is_failure_callback=False, msg="success", ) @@ -412,7 +409,7 @@ def on_failure_callable(context): assert callback == DagCallbackRequest( full_filepath=dag_run.dag.fileloc, dag_id="test_dagrun_update_state_with_handle_callback_failure", - execution_date=dag_run.execution_date, + run_id=dag_run.run_id, is_failure_callback=True, msg="task_failure", ) @@ -825,40 +822,3 @@ def test_states_sets(self): ti_failed = dag_run.get_task_instance(dag_task_failed.task_id) assert ti_success.state in State.success_states assert ti_failed.state in State.failed_states - - def test_delete_dag_run_and_task_instance_does_not_raise_error(self): - clear_db_jobs() - clear_db_runs() - - job_id = 22 - dag = DAG(dag_id='test_delete_dag_run', start_date=days_ago(1)) - _ = BashOperator(task_id='task1', dag=dag, bash_command="echo hi") - - # Simulate DagRun is created by a job inherited by BaseJob with an id - # This is so that same foreign key exists on DagRun.creating_job_id & BaseJob.id - dag_run = self.create_dag_run(dag=dag, creating_job_id=job_id) - assert dag_run is not None - - session = settings.Session() - - job = BaseJob(id=job_id) - session.add(job) - - # Simulate TaskInstance is created by a job inherited by BaseJob with an id - # This is so that same foreign key exists on TaskInstance.queued_by_job_id & BaseJob.id - ti1 = dag_run.get_task_instance(task_id="task1") - ti1.queued_by_job_id = job_id - session.merge(ti1) - session.commit() - - # Test Deleting DagRun does not raise an error - session.delete(dag_run) - - # Test Deleting TaskInstance does not raise an error - ti1 = dag_run.get_task_instance(task_id="task1") - session.delete(ti1) - session.commit() - - # CleanUp - clear_db_runs() - clear_db_jobs() diff --git a/tests/models/test_renderedtifields.py b/tests/models/test_renderedtifields.py index e2093f2415b86..47e4956fd469c 100644 --- a/tests/models/test_renderedtifields.py +++ b/tests/models/test_renderedtifields.py @@ -23,12 +23,12 @@ from unittest import mock import pytest +from sqlalchemy.orm.session import make_transient from airflow import settings from airflow.configuration import TEST_DAGS_FOLDER from airflow.models import Variable from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF -from airflow.models.taskinstance import TaskInstance as TI from airflow.operators.bash import BashOperator from airflow.utils.session import create_session from 
airflow.utils.timezone import datetime @@ -115,27 +115,32 @@ def test_get_templated_fields(self, templated_field, expected_rendered_field, da Test that template_fields are rendered correctly, stored in the Database, and are correctly fetched using RTIF.get_templated_fields """ - with dag_maker("test_serialized_rendered_fields") as dag: + with dag_maker("test_serialized_rendered_fields"): task = BashOperator(task_id="test", bash_command=templated_field) - dag_maker.create_dagrun() - ti = TI(task=task, execution_date=EXECUTION_DATE) + task_2 = BashOperator(task_id="test2", bash_command=templated_field) + dr = dag_maker.create_dagrun() + + session = dag_maker.session + + ti, ti2 = dr.task_instances + ti.task = task + ti2.task = task_2 rtif = RTIF(ti=ti) + assert ti.dag_id == rtif.dag_id assert ti.task_id == rtif.task_id assert ti.execution_date == rtif.execution_date assert expected_rendered_field == rtif.rendered_fields.get("bash_command") - with create_session() as session: - session.add(rtif) - - assert {"bash_command": expected_rendered_field, "env": None} == RTIF.get_templated_fields(ti=ti) + session.add(rtif) + session.flush() + assert {"bash_command": expected_rendered_field, "env": None} == RTIF.get_templated_fields( + ti=ti, session=session + ) # Test the else part of get_templated_fields # i.e. for the TIs that are not stored in RTIF table # Fetching them will return None - task_2 = BashOperator(task_id="test2", bash_command=templated_field, dag=dag) - - ti2 = TI(task_2, EXECUTION_DATE) assert RTIF.get_templated_fields(ti=ti2) is None @pytest.mark.parametrize( @@ -159,14 +164,15 @@ def test_delete_old_records( session = settings.Session() with dag_maker("test_delete_old_records") as dag: task = BashOperator(task_id="test", bash_command="echo {{ ds }}") - dag_maker.create_dagrun() - rtif_list = [ - RTIF(TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=num))) - for num in range(rtif_num) - ] + rtif_list = [] + for num in range(rtif_num): + dr = dag_maker.create_dagrun(run_id=str(num), execution_date=dag.start_date + timedelta(days=num)) + ti = dr.task_instances[0] + ti.task = task + rtif_list.append(RTIF(ti)) session.add_all(rtif_list) - session.commit() + session.flush() result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all() @@ -201,7 +207,11 @@ def test_write(self, dag_maker): with dag_maker("test_write"): task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}") - rtif = RTIF(TI(task=task, execution_date=EXECUTION_DATE)) + dr = dag_maker.create_dagrun() + ti = dr.task_instances[0] + ti.task = task + + rtif = RTIF(ti) rtif.write() result = ( session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields) @@ -220,8 +230,10 @@ def test_write(self, dag_maker): self.clean_db() with dag_maker("test_write"): updated_task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}") - dag_maker.create_dagrun() - rtif_updated = RTIF(TI(task=updated_task, execution_date=EXECUTION_DATE)) + dr = dag_maker.create_dagrun() + ti = dr.task_instances[0] + ti.task = updated_task + rtif_updated = RTIF(ti) rtif_updated.write() result_updated = ( @@ -248,10 +260,11 @@ def test_get_k8s_pod_yaml(self, redact, dag_maker): """ with dag_maker("test_get_k8s_pod_yaml") as dag: task = BashOperator(task_id="test", bash_command="echo hi") - dag_maker.create_dagrun() + dr = dag_maker.create_dagrun() dag.fileloc = TEST_DAGS_FOLDER + '/test_get_k8s_pod_yaml.py' - ti = TI(task=task, execution_date=EXECUTION_DATE) + ti = 
dr.task_instances[0] + ti.task = task render_k8s_pod_yaml = mock.patch.object( ti, 'render_k8s_pod_yaml', return_value={"I'm a": "pod"} @@ -274,6 +287,8 @@ def test_get_k8s_pod_yaml(self, redact, dag_maker): session.flush() assert expected_pod_yaml == RTIF.get_k8s_pod_yaml(ti=ti, session=session) + make_transient(ti) + # "Delete" it from the DB session.rollback() # Test the else part of get_k8s_pod_yaml @@ -290,13 +305,14 @@ def test_redact(self, redact, dag_maker): bash_command="echo {{ var.value.api_key }}", env={'foo': 'secret', 'other_api_key': 'masked based on key name'}, ) - dag_maker.create_dagrun() + dr = dag_maker.create_dagrun() redact.side_effect = [ 'val 1', 'val 2', ] - ti = TI(task=task, execution_date=EXECUTION_DATE) + ti = dr.task_instances[0] + ti.task = task rtif = RTIF(ti=ti) assert rtif.rendered_fields == { 'bash_command': 'val 1', diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py index 8734cc7f416f5..21ddf4d838d33 100644 --- a/tests/models/test_skipmixin.py +++ b/tests/models/test_skipmixin.py @@ -74,9 +74,10 @@ def test_skip_none_dagrun(self, mock_now, dag_maker): mock_now.return_value = now with dag_maker( 'dag', + session=session, ): tasks = [DummyOperator(task_id='task')] - dag_maker.create_dagrun() + dag_maker.create_dagrun(execution_date=now) SkipMixin().skip(dag_run=None, execution_date=now, tasks=tasks, session=session) session.query(TI).filter( diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index ded890182cd4e..a2f2ec5d60a24 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -29,7 +29,6 @@ import pendulum import pytest from freezegun import freeze_time -from sqlalchemy.orm.session import Session from airflow import models, settings from airflow.exceptions import ( @@ -47,6 +46,7 @@ TaskInstance as TI, TaskReschedule, Variable, + XCom, ) from airflow.models.taskinstance import load_error_file, set_error_file from airflow.operators.bash import BashOperator @@ -70,7 +70,17 @@ from tests.test_utils import db from tests.test_utils.asserts import assert_queries_count from tests.test_utils.config import conf_vars -from tests.test_utils.db import clear_db_connections +from tests.test_utils.db import clear_db_connections, clear_db_runs + + +@pytest.fixture +def test_pool(): + with create_session() as session: + test_pool = Pool(pool='test_pool', slots=1) + session.add(test_pool) + session.flush() + yield test_pool + session.rollback() class CallbackWrapper: @@ -112,10 +122,10 @@ def clean_db(): def setup_method(self): self.clean_db() - with create_session() as session: - test_pool = Pool(pool='test_pool', slots=1) - session.add(test_pool) - session.commit() + + # We don't want to store any code for (test) dags created in this file + with patch.object(settings, "STORE_DAG_CODE", False): + yield def teardown_method(self): self.clean_db() @@ -167,40 +177,6 @@ def test_set_task_dates(self, dag_maker): assert op3.start_date == DEFAULT_DATE + datetime.timedelta(days=1) assert op3.end_date == DEFAULT_DATE + datetime.timedelta(days=9) - def test_timezone_awareness(self, dag_maker): - naive_datetime = DEFAULT_DATE.replace(tzinfo=None) - - # check ti without dag (just for bw compat) - op_no_dag = DummyOperator(task_id='op_no_dag') - ti = TI(task=op_no_dag, execution_date=naive_datetime) - - assert ti.execution_date == DEFAULT_DATE - - # check with dag without localized execution_date - with dag_maker('dag'): - op1 = DummyOperator(task_id='op_1') - dag_maker.create_dagrun() - ti 
= TI(task=op1, execution_date=naive_datetime) - - assert ti.execution_date == DEFAULT_DATE - - # with dag and localized execution_date - tzinfo = pendulum.timezone("Europe/Amsterdam") - execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo) - utc_date = timezone.convert_to_utc(execution_date) - ti = TI(task=op1, execution_date=execution_date) - assert ti.execution_date == utc_date - - def test_task_naive_datetime(self): - naive_datetime = DEFAULT_DATE.replace(tzinfo=None) - - op_no_dag = DummyOperator( - task_id='test_task_naive_datetime', start_date=naive_datetime, end_date=naive_datetime - ) - - assert op_no_dag.start_date.tzinfo - assert op_no_dag.end_date.tzinfo - def test_set_dag(self, dag_maker): """ Test assigning Operators to Dags, including deferred assignment @@ -271,60 +247,47 @@ def test_bitshift_compose_operators(self, dag_maker): assert op2 in op3.downstream_list @patch.object(DAG, 'get_concurrency_reached') - def test_requeue_over_dag_concurrency(self, mock_concurrency_reached, create_dummy_dag): + def test_requeue_over_dag_concurrency(self, mock_concurrency_reached, create_task_instance): mock_concurrency_reached.return_value = True - _, task = create_dummy_dag( + ti = create_task_instance( dag_id='test_requeue_over_dag_concurrency', task_id='test_requeue_over_dag_concurrency_op', max_active_runs=1, max_active_tasks=2, + dagrun_state=State.QUEUED, ) - - ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED) - # TI.run() will sync from DB before validating deps. - with create_session() as session: - session.add(ti) - session.commit() ti.run() assert ti.state == State.NONE - def test_requeue_over_max_active_tis_per_dag(self, create_dummy_dag): - _, task = create_dummy_dag( + def test_requeue_over_max_active_tis_per_dag(self, create_task_instance): + ti = create_task_instance( dag_id='test_requeue_over_max_active_tis_per_dag', task_id='test_requeue_over_max_active_tis_per_dag_op', max_active_tis_per_dag=0, max_active_runs=1, max_active_tasks=2, + dagrun_state=State.QUEUED, ) - ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED) - # TI.run() will sync from DB before validating deps. - with create_session() as session: - session.add(ti) - session.commit() ti.run() assert ti.state == State.NONE - def test_requeue_over_pool_concurrency(self, create_dummy_dag): - _, task = create_dummy_dag( + def test_requeue_over_pool_concurrency(self, create_task_instance, test_pool): + ti = create_task_instance( dag_id='test_requeue_over_pool_concurrency', task_id='test_requeue_over_pool_concurrency_op', max_active_tis_per_dag=0, max_active_runs=1, max_active_tasks=2, ) - - ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED) - # TI.run() will sync from DB before validating deps. 
with create_session() as session: - pool = session.query(Pool).filter(Pool.pool == 'test_pool').one() - pool.slots = 0 - session.add(ti) - session.commit() - ti.run() - assert ti.state == State.NONE + test_pool.slots = 0 + session.flush() + ti.run() + assert ti.state == State.NONE + @pytest.mark.usefixtures('test_pool') def test_not_requeue_non_requeueable_task_instance(self, dag_maker): # Use BaseSensorOperator because sensor got # one additional DEP in BaseSensorOperator().deps @@ -333,8 +296,9 @@ def test_not_requeue_non_requeueable_task_instance(self, dag_maker): task_id='test_not_requeue_non_requeueable_task_instance_op', pool='test_pool', ) - dag_maker.create_dagrun() - ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED) + ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] + ti.task = task + ti.state = State.QUEUED with create_session() as session: session.add(ti) session.commit() @@ -358,69 +322,61 @@ def test_not_requeue_non_requeueable_task_instance(self, dag_maker): for (dep_patch, method_patch) in patch_dict.values(): dep_patch.stop() - def test_mark_non_runnable_task_as_success(self, create_dummy_dag): + def test_mark_non_runnable_task_as_success(self, create_task_instance): """ test that running task with mark_success param update task state as SUCCESS without running task despite it fails dependency checks. """ non_runnable_state = (set(State.task_states) - RUNNABLE_STATES - set(State.SUCCESS)).pop() - _, task = create_dummy_dag( + ti = create_task_instance( dag_id='test_mark_non_runnable_task_as_success', task_id='test_mark_non_runnable_task_as_success_op', + dagrun_state=non_runnable_state, ) - ti = TI(task=task, execution_date=timezone.utcnow(), state=non_runnable_state) - # TI.run() will sync from DB before validating deps. - with create_session() as session: - session.add(ti) - session.commit() ti.run(mark_success=True) assert ti.state == State.SUCCESS - def test_run_pooling_task(self, create_dummy_dag): + @pytest.mark.usefixtures('test_pool') + def test_run_pooling_task(self, create_task_instance): """ test that running a task in an existing pool update task state as SUCCESS. 
""" - _, task = create_dummy_dag( + ti = create_task_instance( dag_id='test_run_pooling_task', task_id='test_run_pooling_task_op', pool='test_pool', ) - ti = TI(task=task, execution_date=timezone.utcnow()) ti.run() - db.clear_db_pools() assert ti.state == State.SUCCESS + @pytest.mark.usefixtures('test_pool') def test_pool_slots_property(self): """ test that try to create a task with pool_slots less than 1 """ - def create_task_instance(): + with pytest.raises(AirflowException): dag = models.DAG(dag_id='test_run_pooling_task') - task = DummyOperator( + DummyOperator( task_id='test_run_pooling_task_op', dag=dag, pool='test_pool', pool_slots=0, ) - return TI(task=task, execution_date=timezone.utcnow()) - - with pytest.raises(AirflowException): - create_task_instance() @provide_session - def test_ti_updates_with_task(self, create_dummy_dag, session=None): + def test_ti_updates_with_task(self, create_task_instance, session=None): """ test that updating the executor_config propagates to the TaskInstance DB """ - dag, task = create_dummy_dag( + ti = create_task_instance( dag_id='test_run_pooling_task', task_id='test_run_pooling_task_op', executor_config={'foo': 'bar'}, ) - ti = TI(task=task, execution_date=timezone.utcnow()) + dag = ti.task.dag ti.run(session=session) tis = dag.get_task_instances() @@ -432,24 +388,27 @@ def test_ti_updates_with_task(self, create_dummy_dag, session=None): dag=dag, ) - ti = TI(task=task2, execution_date=timezone.utcnow()) + ti2 = TI(task=task2, run_id=ti.run_id) + session.add(ti2) + session.flush() - ti.run(session=session) - tis = dag.get_task_instances() - assert {'bar': 'baz'} == tis[1].executor_config + ti2.run(session=session) + # Ensure it's reloaded + ti2.executor_config = None + ti2.refresh_from_db(session) + assert {'bar': 'baz'} == ti2.executor_config session.rollback() - def test_run_pooling_task_with_mark_success(self, create_dummy_dag): + def test_run_pooling_task_with_mark_success(self, create_task_instance): """ test that running task in an existing pool with mark_success param update task state as SUCCESS without running task despite it fails dependency checks. 
""" - _, task = create_dummy_dag( + ti = create_task_instance( dag_id='test_run_pooling_task_with_mark_success', task_id='test_run_pooling_task_with_mark_success_op', ) - ti = TI(task=task, execution_date=timezone.utcnow()) ti.run(mark_success=True) assert ti.state == State.SUCCESS @@ -468,7 +427,10 @@ def raise_skip_exception(): task_id='test_run_pooling_task_with_skip', python_callable=raise_skip_exception, ) - ti = TI(task=task, execution_date=timezone.utcnow()) + + dr = dag_maker.create_dagrun(execution_date=timezone.utcnow()) + ti = dr.task_instances[0] + ti.task = task ti.run() assert State.SKIPPED == ti.state @@ -489,9 +451,9 @@ def task_function(ti): retry_delay=datetime.timedelta(seconds=2), ) - dag_maker.create_dagrun() - ti = TI(task=task, execution_date=DEFAULT_DATE) - ti.refresh_from_db() + dr = dag_maker.create_dagrun() + ti = dr.task_instances[0] + ti.task = task with pytest.raises(AirflowException): ti.run() ti.refresh_from_db() @@ -508,7 +470,6 @@ def test_retry_delay(self, dag_maker): retries=1, retry_delay=datetime.timedelta(seconds=3), ) - dag_maker.create_dagrun() def run_with_error(ti): try: @@ -516,7 +477,8 @@ def run_with_error(ti): except AirflowException: pass - ti = TI(task=task, execution_date=timezone.utcnow()) + ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] + ti.task = task assert ti.try_number == 1 # first run -- up for retry @@ -553,7 +515,8 @@ def run_with_error(ti): except AirflowException: pass - ti = TI(task=task, execution_date=timezone.utcnow()) + ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] + ti.task = task assert ti.try_number == 1 # first run -- up for retry @@ -599,8 +562,8 @@ def test_next_retry_datetime(self, dag_maker): retry_exponential_backoff=True, max_retry_delay=max_delay, ) - dag_maker.create_dagrun() - ti = TI(task=task, execution_date=DEFAULT_DATE) + ti = dag_maker.create_dagrun().task_instances[0] + ti.task = task ti.end_date = pendulum.instance(timezone.utcnow()) date = ti.next_retry_datetime() @@ -641,8 +604,8 @@ def test_next_retry_datetime_short_intervals(self, dag_maker): retry_exponential_backoff=True, max_retry_delay=max_delay, ) - dag_maker.create_dagrun() - ti = TI(task=task, execution_date=DEFAULT_DATE) + ti = dag_maker.create_dagrun().task_instances[0] + ti.task = task ti.end_date = pendulum.instance(timezone.utcnow()) date = ti.next_retry_datetime() @@ -673,12 +636,11 @@ def func(): retry_delay=datetime.timedelta(seconds=0), ) - ti = TI(task=task, execution_date=timezone.utcnow()) + ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] + ti.task = task assert ti._try_number == 0 assert ti.try_number == 1 - dag_maker.create_dagrun() - def run_ti_and_assert( run_date, expected_start_date, @@ -749,6 +711,7 @@ def run_ti_and_assert( done, fail = True, False run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0) + @pytest.mark.usefixtures('test_pool') def test_reschedule_handling_clear_reschedules(self, dag_maker): """ Test that task reschedules clearing are handled properly @@ -772,8 +735,8 @@ def func(): retry_delay=datetime.timedelta(seconds=0), pool='test_pool', ) - dag_maker.create_dagrun() - ti = TI(task=task, execution_date=timezone.utcnow()) + ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] + ti.task = task assert ti._try_number == 0 assert ti.try_number == 1 @@ -829,12 +792,13 @@ def test_depends_on_past(self, dag_maker): run_date = task.start_date + datetime.timedelta(days=5) - 
dag_maker.create_dagrun( + dr = dag_maker.create_dagrun( execution_date=run_date, run_type=DagRunType.SCHEDULED, ) - ti = TI(task, run_date) + ti = dr.task_instances[0] + ti.task = task # depends_on_past prevents the run task.run(start_date=run_date, end_date=run_date, ignore_first_depends_on_past=False) @@ -911,15 +875,17 @@ def test_check_task_dependencies( flag_upstream_failed: bool, expect_state: State, expect_completed: bool, - create_dummy_dag, + dag_maker, ): - dag, downstream = create_dummy_dag('test-dag', task_id='downstream', trigger_rule=trigger_rule) - for i in range(5): - task = DummyOperator(task_id=f'runme_{i}', dag=dag) - task.set_downstream(downstream) + with dag_maker() as dag: + downstream = DummyOperator(task_id="downstream", trigger_rule=trigger_rule) + for i in range(5): + task = DummyOperator(task_id=f'runme_{i}', dag=dag) + task.set_downstream(downstream) run_date = task.start_date + datetime.timedelta(days=5) - ti = TI(downstream, run_date) + ti = dag_maker.create_dagrun(execution_date=run_date).get_task_instance(downstream.task_id) + ti.task = downstream dep_results = TriggerRuleDep()._evaluate_trigger_rule( ti=ti, successes=successes, @@ -934,9 +900,8 @@ def test_check_task_dependencies( assert completed == expect_completed assert ti.state == expect_state - def test_respects_prev_dagrun_dep(self, create_dummy_dag): - _, task = create_dummy_dag(dag_id='test_dag') - ti = TI(task, DEFAULT_DATE) + def test_respects_prev_dagrun_dep(self, create_task_instance): + ti = create_task_instance() failing_status = [TIDepStatus('test fail status name', False, 'test fail reason')] passing_status = [TIDepStatus('test pass status name', True, 'test passing reason')] with patch( @@ -958,38 +923,42 @@ def test_respects_prev_dagrun_dep(self, create_dummy_dag): (State.NONE, False), ], ) - def test_are_dependents_done(self, downstream_ti_state, expected_are_dependents_done, create_dummy_dag): - dag, task = create_dummy_dag() + @provide_session + def test_are_dependents_done( + self, downstream_ti_state, expected_are_dependents_done, create_task_instance, session=None + ): + ti = create_task_instance(session=session) + dag = ti.task.dag downstream_task = DummyOperator(task_id='downstream_task', dag=dag) - task >> downstream_task + ti.task >> downstream_task - ti = TI(task, DEFAULT_DATE) - downstream_ti = TI(downstream_task, DEFAULT_DATE) + downstream_ti = TI(downstream_task, run_id=ti.run_id) - downstream_ti.set_state(downstream_ti_state) - assert ti.are_dependents_done() == expected_are_dependents_done + downstream_ti.set_state(downstream_ti_state, session) + session.flush() + assert ti.are_dependents_done(session) == expected_are_dependents_done - def test_xcom_pull(self, create_dummy_dag): + def test_xcom_pull(self, create_task_instance): """ Test xcom_pull, using different filtering methods. 
""" - dag, task1 = create_dummy_dag( + ti1 = create_task_instance( dag_id='test_xcom', task_id='test_xcom_1', - schedule_interval='@monthly', start_date=timezone.datetime(2016, 6, 1, 0, 0, 0), ) - exec_date = DEFAULT_DATE - # Push a value - ti1 = TI(task=task1, execution_date=exec_date) ti1.xcom_push(key='foo', value='bar') # Push another value with the same key (but by a different task) - task2 = DummyOperator(task_id='test_xcom_2', dag=dag) - ti2 = TI(task=task2, execution_date=exec_date) - ti2.xcom_push(key='foo', value='baz') + XCom.set( + key='foo', + value='baz', + task_id='test_xcom_2', + dag_id=ti1.dag_id, + execution_date=ti1.execution_date, + ) # Pull with no arguments result = ti1.xcom_pull() @@ -1007,21 +976,19 @@ def test_xcom_pull(self, create_dummy_dag): result = ti1.xcom_pull(task_ids=['test_xcom_1', 'test_xcom_2'], key='foo') assert result == ['bar', 'baz'] - def test_xcom_pull_after_success(self, create_dummy_dag): + def test_xcom_pull_after_success(self, create_task_instance): """ tests xcom set/clear relative to a task in a 'success' rerun scenario """ key = 'xcom_key' value = 'xcom_value' - _, task = create_dummy_dag( + ti = create_task_instance( dag_id='test_xcom', schedule_interval='@monthly', task_id='test_xcom', pool='test_xcom', ) - exec_date = DEFAULT_DATE - ti = TI(task=task, execution_date=exec_date) ti.run(mark_success=True) ti.xcom_push(key=key, value=value) @@ -1039,7 +1006,7 @@ def test_xcom_pull_after_success(self, create_dummy_dag): ti.run(ignore_all_deps=True) assert ti.xcom_pull(task_ids='test_xcom', key=key) is None - def test_xcom_pull_different_execution_date(self, create_dummy_dag): + def test_xcom_pull_different_execution_date(self, create_task_instance): """ tests xcom fetch behavior with different execution dates, using both xcom_pull with "include_prior_dates" and without @@ -1047,21 +1014,21 @@ def test_xcom_pull_different_execution_date(self, create_dummy_dag): key = 'xcom_key' value = 'xcom_value' - dag, task = create_dummy_dag( + ti = create_task_instance( dag_id='test_xcom', schedule_interval='@monthly', task_id='test_xcom', pool='test_xcom', ) - exec_date = DEFAULT_DATE - ti = TI(task=task, execution_date=exec_date) + exec_date = ti.dag_run.execution_date ti.run(mark_success=True) ti.xcom_push(key=key, value=value) assert ti.xcom_pull(task_ids='test_xcom', key=key) == value ti.run() exec_date += datetime.timedelta(days=1) - ti = TI(task=task, execution_date=exec_date) + dr = ti.task.dag.create_dagrun(run_id="test2", execution_date=exec_date, state=None) + ti = TI(task=ti.task, run_id=dr.run_id) ti.run() # We have set a new execution date (and did not pass in # 'include_prior_dates'which means this task should now have a cleared @@ -1084,8 +1051,8 @@ def test_xcom_push_flag(self, dag_maker): python_callable=lambda: value, do_xcom_push=False, ) - ti = TI(task=task, execution_date=DEFAULT_DATE) - dag_maker.create_dagrun() + ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] + ti.task = task ti.run() assert ti.xcom_pull(task_ids=task_id, key=models.XCOM_RETURN_KEY) is None @@ -1108,33 +1075,31 @@ def post_execute(self, context, result=None): task_id='test_operator', python_callable=lambda: 'error', ) - ti = TI(task=task, execution_date=DEFAULT_DATE) - dag_maker.create_dagrun() + ti = dag_maker.create_dagrun(execution_date=DEFAULT_DATE).task_instances[0] + ti.task = task with pytest.raises(TestError): ti.run() - def test_check_and_change_state_before_execution(self, create_dummy_dag): - _, task = 
create_dummy_dag(dag_id='test_check_and_change_state_before_execution') - ti = TI(task=task, execution_date=DEFAULT_DATE) + def test_check_and_change_state_before_execution(self, create_task_instance): + ti = create_task_instance(dag_id='test_check_and_change_state_before_execution') assert ti._try_number == 0 assert ti.check_and_change_state_before_execution() # State should be running, and try_number column should be incremented assert ti.state == State.RUNNING assert ti._try_number == 1 - def test_check_and_change_state_before_execution_dep_not_met(self, create_dummy_dag): - dag, task = create_dummy_dag(dag_id='test_check_and_change_state_before_execution') - task2 = DummyOperator(task_id='task2', dag=dag, start_date=DEFAULT_DATE) - task >> task2 - ti = TI(task=task2, execution_date=timezone.utcnow()) + def test_check_and_change_state_before_execution_dep_not_met(self, create_task_instance): + ti = create_task_instance(dag_id='test_check_and_change_state_before_execution') + task2 = DummyOperator(task_id='task2', dag=ti.task.dag, start_date=DEFAULT_DATE) + ti.task >> task2 + ti = TI(task=task2, run_id=ti.run_id) assert not ti.check_and_change_state_before_execution() - def test_try_number(self, create_dummy_dag): + def test_try_number(self, create_task_instance): """ Test the try_number accessor behaves in various running states """ - _, task = create_dummy_dag(dag_id='test_check_and_change_state_before_execution') - ti = TI(task=task, execution_date=timezone.utcnow()) + ti = create_task_instance(dag_id='test_check_and_change_state_before_execution') assert 1 == ti.try_number ti.try_number = 2 ti.state = State.RUNNING @@ -1142,20 +1107,33 @@ def test_try_number(self, create_dummy_dag): ti.state = State.SUCCESS assert 3 == ti.try_number - def test_get_num_running_task_instances(self, create_dummy_dag): + def test_get_num_running_task_instances(self, create_task_instance): session = settings.Session() - _, task = create_dummy_dag(dag_id='test_get_num_running_task_instances', task_id='task1') - _, task2 = create_dummy_dag(dag_id='test_get_num_running_task_instances_dummy', task_id='task2') - ti1 = TI(task=task, execution_date=DEFAULT_DATE) - ti2 = TI(task=task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1)) - ti3 = TI(task=task2, execution_date=DEFAULT_DATE) + ti1 = create_task_instance( + dag_id='test_get_num_running_task_instances', task_id='task1', session=session + ) + + dr = ti1.task.dag.create_dagrun( + execution_date=DEFAULT_DATE + datetime.timedelta(days=1), + state=None, + run_id='2', + session=session, + ) + assert ti1 in session + ti2 = dr.task_instances[0] + ti2.task = ti1.task + + ti3 = create_task_instance( + dag_id='test_get_num_running_task_instances_dummy', task_id='task2', session=session + ) + assert ti3 in session + assert ti1 in session + ti1.state = State.RUNNING ti2.state = State.QUEUED ti3.state = State.RUNNING - session.merge(ti1) - session.merge(ti2) - session.merge(ti3) + assert ti3 in session session.commit() assert 1 == ti1.get_num_running_task_instances(session=session) @@ -1174,9 +1152,8 @@ def test_get_num_running_task_instances(self, create_dummy_dag): # self.assertEqual(d['task_id'][0], 'op') # self.assertEqual(pendulum.parse(d['execution_date'][0]), now) - def test_log_url(self, create_dummy_dag): - _, task = create_dummy_dag('dag', task_id='op') - ti = TI(task=task, execution_date=datetime.datetime(2018, 1, 1)) + def test_log_url(self, create_task_instance): + ti = create_task_instance(dag_id='dag', task_id='op', 
execution_date=timezone.datetime(2018, 1, 1)) expected_url = ( 'http://localhost:8080/log?' @@ -1186,10 +1163,9 @@ def test_log_url(self, create_dummy_dag): ) assert ti.log_url == expected_url - def test_mark_success_url(self, create_dummy_dag): + def test_mark_success_url(self, create_task_instance): now = pendulum.now('Europe/Brussels') - _, task = create_dummy_dag('dag', task_id='op') - ti = TI(task=task, execution_date=now) + ti = create_task_instance(dag_id='dag', task_id='op', execution_date=now) query = urllib.parse.parse_qs( urllib.parse.urlparse(ti.mark_success_url).query, keep_blank_values=True, strict_parsing=True ) @@ -1197,10 +1173,9 @@ def test_mark_success_url(self, create_dummy_dag): assert query['task_id'][0] == 'op' assert pendulum.parse(query['execution_date'][0]) == now - def test_overwrite_params_with_dag_run_conf(self): - task = DummyOperator(task_id='op') - ti = TI(task=task, execution_date=datetime.datetime.now()) - dag_run = DagRun() + def test_overwrite_params_with_dag_run_conf(self, create_task_instance): + ti = create_task_instance() + dag_run = ti.dag_run dag_run.conf = {"override": True} params = {"override": False} @@ -1208,20 +1183,18 @@ def test_overwrite_params_with_dag_run_conf(self): assert params["override"] is True - def test_overwrite_params_with_dag_run_none(self): - task = DummyOperator(task_id='op') - ti = TI(task=task, execution_date=datetime.datetime.now()) + def test_overwrite_params_with_dag_run_none(self, create_task_instance): + ti = create_task_instance() params = {"override": False} ti.overwrite_params_with_dag_run_conf(params, None) assert params["override"] is False - def test_overwrite_params_with_dag_run_conf_none(self): - task = DummyOperator(task_id='op') - ti = TI(task=task, execution_date=datetime.datetime.now()) + def test_overwrite_params_with_dag_run_conf_none(self, create_task_instance): + ti = create_task_instance() params = {"override": False} - dag_run = DagRun() + dag_run = ti.dag_run ti.overwrite_params_with_dag_run_conf(params, dag_run) @@ -1231,8 +1204,8 @@ def test_overwrite_params_with_dag_run_conf_none(self): def test_email_alert(self, mock_send_email, dag_maker): with dag_maker(dag_id='test_failure_email'): task = BashOperator(task_id='test_email_alert', bash_command='exit 1', email='to') - dag_maker.create_dagrun() - ti = TI(task=task, execution_date=timezone.utcnow()) + ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] + ti.task = task try: ti.run() @@ -1259,8 +1232,8 @@ def test_email_alert_with_config(self, mock_send_email, dag_maker): bash_command='exit 1', email='to', ) - dag_maker.create_dagrun() - ti = TI(task=task, execution_date=timezone.utcnow()) + ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] + ti.task = task opener = mock_open(read_data='template: {{ti.task_id}}') with patch('airflow.models.taskinstance.open', opener, create=True): @@ -1276,10 +1249,7 @@ def test_email_alert_with_config(self, mock_send_email, dag_maker): def test_set_duration(self): task = DummyOperator(task_id='op', email='test@test.test') - ti = TI( - task=task, - execution_date=datetime.datetime.now(), - ) + ti = TI(task=task) ti.start_date = datetime.datetime(2018, 10, 1, 1) ti.end_date = datetime.datetime(2018, 10, 1, 2) ti.set_duration() @@ -1287,20 +1257,19 @@ def test_set_duration(self): def test_set_duration_empty_dates(self): task = DummyOperator(task_id='op', email='test@test.test') - ti = TI(task=task, execution_date=datetime.datetime.now()) + ti = TI(task=task) 
ti.set_duration() assert ti.duration is None - def test_success_callback_no_race_condition(self, create_dummy_dag): + def test_success_callback_no_race_condition(self, create_task_instance): callback_wrapper = CallbackWrapper() - _, task = create_dummy_dag( - 'test_success_callback_no_race_condition', + ti = create_task_instance( on_success_callback=callback_wrapper.success_handler, end_date=DEFAULT_DATE + datetime.timedelta(days=10), + execution_date=timezone.utcnow(), + state=State.RUNNING, ) - ti = TI(task=task, execution_date=datetime.datetime.now()) - ti.state = State.RUNNING session = settings.Session() session.merge(ti) session.commit() @@ -1321,30 +1290,29 @@ def _test_previous_dates_setup( with dag_maker(dag_id=dag_id, schedule_interval=schedule_interval, catchup=catchup): task = DummyOperator(task_id='task') - def get_test_ti(session, execution_date: pendulum.DateTime, state: str) -> TI: - dag_maker.create_dagrun( + def get_test_ti(execution_date: pendulum.DateTime, state: str) -> TI: + dr = dag_maker.create_dagrun( + run_id=f'test__{execution_date.isoformat()}', run_type=DagRunType.SCHEDULED, state=state, execution_date=execution_date, start_date=pendulum.now('UTC'), - session=session, ) - ti = TI(task=task, execution_date=execution_date) - ti.set_state(state=State.SUCCESS, session=session) + ti = dr.task_instances[0] + ti.task = task + ti.set_state(state=State.SUCCESS, session=dag_maker.session) return ti - with create_session() as session: # type: Session + date = cast(pendulum.DateTime, pendulum.parse('2019-01-01T00:00:00+00:00')) - date = cast(pendulum.DateTime, pendulum.parse('2019-01-01T00:00:00+00:00')) + ret = [] - ret = [] + for idx, state in enumerate(scenario): + new_date = date.add(days=idx) + ti = get_test_ti(new_date, state) + ret.append(ti) - for idx, state in enumerate(scenario): - new_date = date.add(days=idx) - ti = get_test_ti(session, new_date, state) - ret.append(ti) - - return ret + return ret _prev_dates_param_list = [ pytest.param('0 0 * * * ', True, id='cron/catchup'), @@ -1364,9 +1332,9 @@ def test_previous_ti(self, schedule_interval, catchup, dag_maker) -> None: assert ti_list[0].get_previous_ti() is None - assert ti_list[2].get_previous_ti().execution_date == ti_list[1].execution_date + assert ti_list[2].get_previous_ti().run_id == ti_list[1].run_id - assert ti_list[2].get_previous_ti().execution_date != ti_list[0].execution_date + assert ti_list[2].get_previous_ti().run_id != ti_list[0].run_id @pytest.mark.parametrize("schedule_interval, catchup", _prev_dates_param_list) def test_previous_ti_success(self, schedule_interval, catchup, dag_maker) -> None: @@ -1378,9 +1346,9 @@ def test_previous_ti_success(self, schedule_interval, catchup, dag_maker) -> Non assert ti_list[0].get_previous_ti(state=State.SUCCESS) is None assert ti_list[1].get_previous_ti(state=State.SUCCESS) is None - assert ti_list[3].get_previous_ti(state=State.SUCCESS).execution_date == ti_list[1].execution_date + assert ti_list[3].get_previous_ti(state=State.SUCCESS).run_id == ti_list[1].run_id - assert ti_list[3].get_previous_ti(state=State.SUCCESS).execution_date != ti_list[2].execution_date + assert ti_list[3].get_previous_ti(state=State.SUCCESS).run_id != ti_list[2].run_id @pytest.mark.parametrize("schedule_interval, catchup", _prev_dates_param_list) def test_previous_execution_date_success(self, schedule_interval, catchup, dag_maker) -> None: @@ -1388,6 +1356,9 @@ def test_previous_execution_date_success(self, schedule_interval, catchup, dag_m scenario = [State.FAILED, 
State.SUCCESS, State.FAILED, State.SUCCESS] ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario, dag_maker) + # vivify + for ti in ti_list: + ti.execution_date assert ti_list[0].get_previous_execution_date(state=State.SUCCESS) is None assert ti_list[1].get_previous_execution_date(state=State.SUCCESS) is None @@ -1439,23 +1410,13 @@ def test_get_previous_start_date_none(self, dag_maker): assert ti_2.get_previous_start_date() == ti_1.start_date assert ti_1.start_date is None - def test_pendulum_template_dates(self, create_dummy_dag): - dag, task = create_dummy_dag( + def test_pendulum_template_dates(self, create_task_instance): + ti = create_task_instance( dag_id='test_pendulum_template_dates', task_id='test_pendulum_template_dates_task', schedule_interval='0 12 * * *', ) - execution_date = timezone.utcnow() - - dag.create_dagrun( - execution_date=execution_date, - state=State.RUNNING, - run_type=DagRunType.MANUAL, - ) - - ti = TI(task=task, execution_date=execution_date) - template_context = ti.get_template_context() assert isinstance(template_context["data_interval_start"], pendulum.DateTime) @@ -1474,7 +1435,7 @@ def test_pendulum_template_dates(self, create_dummy_dag): ('{{ conn.a_connection.extra_dejson.extra__asana__workspace }}', 'extra1'), ], ) - def test_template_with_connection(self, content, expected_output, create_dummy_dag): + def test_template_with_connection(self, content, expected_output, create_task_instance): """ Test the availability of variables in templates """ @@ -1496,11 +1457,10 @@ def test_template_with_connection(self, content, expected_output, create_dummy_d session, ) - _, task = create_dummy_dag() + ti = create_task_instance() - ti = TI(task=task, execution_date=DEFAULT_DATE) context = ti.get_template_context() - result = task.render_template(content, context) + result = ti.task.render_template(content, context) assert result == expected_output @pytest.mark.parametrize( @@ -1512,29 +1472,25 @@ def test_template_with_connection(self, content, expected_output, create_dummy_d ('{{ var.value.get("missing_variable", "fallback") }}', 'fallback'), ], ) - def test_template_with_variable(self, content, expected_output, create_dummy_dag): + def test_template_with_variable(self, content, expected_output, create_task_instance): """ Test the availability of variables in templates """ Variable.set('a_variable', 'a test value') - _, task = create_dummy_dag() - - ti = TI(task=task, execution_date=DEFAULT_DATE) + ti = create_task_instance() context = ti.get_template_context() - result = task.render_template(content, context) + result = ti.task.render_template(content, context) assert result == expected_output - def test_template_with_variable_missing(self, create_dummy_dag): + def test_template_with_variable_missing(self, create_task_instance): """ Test the availability of variables in templates """ - _, task = create_dummy_dag() - - ti = TI(task=task, execution_date=DEFAULT_DATE) + ti = create_task_instance() context = ti.get_template_context() with pytest.raises(KeyError): - task.render_template('{{ var.value.get("missing_variable") }}', context) + ti.task.render_template('{{ var.value.get("missing_variable") }}', context) @pytest.mark.parametrize( "content, expected_output", @@ -1546,28 +1502,24 @@ def test_template_with_variable_missing(self, create_dummy_dag): ('{{ var.json.get("missing_variable", {"a": {"test": "fallback"}})["a"]["test"] }}', 'fallback'), ], ) - def test_template_with_json_variable(self, content, expected_output, 
create_dummy_dag): + def test_template_with_json_variable(self, content, expected_output, create_task_instance): """ Test the availability of variables in templates """ Variable.set('a_variable', {'a': {'test': 'value'}}, serialize_json=True) - _, task = create_dummy_dag() - - ti = TI(task=task, execution_date=DEFAULT_DATE) + ti = create_task_instance() context = ti.get_template_context() - result = task.render_template(content, context) + result = ti.task.render_template(content, context) assert result == expected_output - def test_template_with_json_variable_missing(self, create_dummy_dag): - _, task = create_dummy_dag() - - ti = TI(task=task, execution_date=DEFAULT_DATE) + def test_template_with_json_variable_missing(self, create_task_instance): + ti = create_task_instance() context = ti.get_template_context() with pytest.raises(KeyError): - task.render_template('{{ var.json.get("missing_variable") }}', context) + ti.task.render_template('{{ var.json.get("missing_variable") }}', context) - def test_execute_callback(self, create_dummy_dag): + def test_execute_callback(self, create_task_instance): called = False def on_execute_callable(context): @@ -1575,14 +1527,12 @@ def on_execute_callable(context): called = True assert context['dag_run'].dag_id == 'test_dagrun_execute_callback' - _, task = create_dummy_dag( - 'test_execute_callback', + ti = create_task_instance( + dag_id='test_execute_callback', on_execute_callback=on_execute_callable, - end_date=DEFAULT_DATE + datetime.timedelta(days=10), + state=State.RUNNING, ) - ti = TI(task=task, execution_date=datetime.datetime.now()) - ti.state = State.RUNNING session = settings.Session() session.merge(ti) @@ -1601,7 +1551,9 @@ def on_execute_callable(context): (State.FAILED, "Error when executing on_failure_callback"), ], ) - def test_finished_callbacks_handle_and_log_exception(self, finished_state, expected_message, dag_maker): + def test_finished_callbacks_handle_and_log_exception( + self, finished_state, expected_message, create_task_instance + ): called = completed = False def on_finish_callable(context): @@ -1610,28 +1562,24 @@ def on_finish_callable(context): raise KeyError completed = True - with dag_maker( - 'test_success_callback_handles_exception', + ti = create_task_instance( end_date=DEFAULT_DATE + datetime.timedelta(days=10), - ): - task = DummyOperator( - task_id='op', - on_success_callback=on_finish_callable, - on_retry_callback=on_finish_callable, - on_failure_callback=on_finish_callable, - ) - dag_maker.create_dagrun() - ti = TI(task=task, execution_date=datetime.datetime.now()) + on_success_callback=on_finish_callable, + on_retry_callback=on_finish_callable, + on_failure_callback=on_finish_callable, + state=finished_state, + ) ti._log = mock.Mock() - ti.state = finished_state ti._run_finished_callback() assert called assert not completed ti.log.exception.assert_called_once_with(expected_message) - def test_handle_failure(self, create_dummy_dag): + @provide_session + def test_handle_failure(self, create_dummy_dag, session=None): start_date = timezone.datetime(2016, 6, 1) + clear_db_runs() mock_on_failure_1 = mock.MagicMock() mock_on_retry_1 = mock.MagicMock() @@ -1642,8 +1590,12 @@ def test_handle_failure(self, create_dummy_dag): task_id="test_handle_failure_on_failure", on_failure_callback=mock_on_failure_1, on_retry_callback=mock_on_retry_1, + session=session, ) - ti1 = TI(task=task1, execution_date=start_date) + dr = dag.create_dagrun(run_id="test2", execution_date=timezone.utcnow(), state=None, session=session) + + ti1 = 
dr.get_task_instance(task1.task_id, session=session) + ti1.task = task1 ti1.state = State.FAILED ti1.handle_failure("test failure handling") ti1._run_finished_callback() @@ -1661,8 +1613,10 @@ def test_handle_failure(self, create_dummy_dag): retries=1, dag=dag, ) - ti2 = TI(task=task2, execution_date=start_date) + ti2 = TI(task=task2, run_id=dr.run_id) ti2.state = State.FAILED + session.add(ti2) + session.flush() ti2.handle_failure("test retry handling") ti2._run_finished_callback() @@ -1681,7 +1635,9 @@ def test_handle_failure(self, create_dummy_dag): retries=1, dag=dag, ) - ti3 = TI(task=task3, execution_date=start_date) + ti3 = TI(task=task3, run_id=dr.run_id) + session.add(ti3) + session.flush() ti3.state = State.FAILED ti3.handle_failure("test force_fail handling", force_fail=True) ti3._run_finished_callback() @@ -1700,7 +1656,8 @@ def fail(): python_callable=fail, retries=1, ) - ti = TI(task=task, execution_date=timezone.utcnow()) + ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] + ti.task = task try: ti.run() except AirflowFailException: @@ -1717,7 +1674,8 @@ def fail(): python_callable=fail, retries=1, ) - ti = TI(task=task, execution_date=timezone.utcnow()) + ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] + ti.task = task try: ti.run() except AirflowException: @@ -1737,11 +1695,11 @@ def test_echo_env_variables(self, dag_maker): end_date=DEFAULT_DATE + datetime.timedelta(days=10), ): op = PythonOperator(task_id='hive_in_python_op', python_callable=self._env_var_check_callback) - dag_maker.create_dagrun( + dr = dag_maker.create_dagrun( run_type=DagRunType.MANUAL, external_trigger=False, ) - ti = TI(task=op, execution_date=DEFAULT_DATE) + ti = TI(task=op, run_id=dr.run_id) ti.state = State.RUNNING session = settings.Session() session.merge(ti) @@ -1751,38 +1709,33 @@ def test_echo_env_variables(self, dag_maker): assert ti.state == State.SUCCESS @patch.object(Stats, 'incr') - def test_task_stats(self, stats_mock, create_dummy_dag): - dag, op = create_dummy_dag( - 'test_task_start_end_stats', + def test_task_stats(self, stats_mock, create_task_instance): + ti = create_task_instance( + dag_id='test_task_start_end_stats', end_date=DEFAULT_DATE + datetime.timedelta(days=10), + state=State.RUNNING, ) + stats_mock.reset_mock() - ti = TI(task=op, execution_date=DEFAULT_DATE) - ti.state = State.RUNNING session = settings.Session() session.merge(ti) session.commit() ti._run_raw_task() ti.refresh_from_db() - stats_mock.assert_called_with(f'ti.finish.{dag.dag_id}.{op.task_id}.{ti.state}') - assert call(f'ti.start.{dag.dag_id}.{op.task_id}') in stats_mock.mock_calls - assert stats_mock.call_count == 5 + stats_mock.assert_called_with(f'ti.finish.{ti.dag_id}.{ti.task_id}.{ti.state}') + assert call(f'ti.start.{ti.dag_id}.{ti.task_id}') in stats_mock.mock_calls + assert stats_mock.call_count == 4 - def test_command_as_list(self, dag_maker): - with dag_maker( - 'test_dag', - end_date=DEFAULT_DATE + datetime.timedelta(days=10), - ) as dag: - op = DummyOperator(task_id='dummy_op', dag=dag) - dag.fileloc = os.path.join(TEST_DAGS_FOLDER, 'x.py') - ti = TI(task=op, execution_date=DEFAULT_DATE) + def test_command_as_list(self, create_task_instance): + ti = create_task_instance() + ti.task.dag.fileloc = os.path.join(TEST_DAGS_FOLDER, 'x.py') assert ti.command_as_list() == [ 'airflow', 'tasks', 'run', - dag.dag_id, - op.task_id, - DEFAULT_DATE.isoformat(), + ti.dag_id, + ti.task_id, + ti.run_id, '--subdir', 'DAGS_FOLDER/x.py', ] @@ -1811,22 
+1764,23 @@ def test_generate_command_specific_param(self): ) assert assert_command == generate_command - def test_get_rendered_template_fields(self, dag_maker): + @provide_session + def test_get_rendered_template_fields(self, dag_maker, session=None): - with dag_maker('test-dag') as dag: + with dag_maker('test-dag', session=session) as dag: task = BashOperator(task_id='op1', bash_command="{{ task.task_id }}") dag.fileloc = TEST_DAGS_FOLDER + '/test_get_k8s_pod_yaml.py' + ti = dag_maker.create_dagrun().task_instances[0] + ti.task = task - ti = TI(task=task, execution_date=DEFAULT_DATE) - - with create_session() as session: - session.add(RenderedTaskInstanceFields(ti)) + session.add(RenderedTaskInstanceFields(ti)) + session.flush() # Create new TI for the same Task new_task = BashOperator(task_id='op12', bash_command="{{ task.task_id }}", dag=dag) - new_ti = TI(task=new_task, execution_date=DEFAULT_DATE) - new_ti.get_rendered_template_fields() + new_ti = TI(task=new_task, run_id=ti.run_id) + new_ti.get_rendered_template_fields(session=session) assert "op1" == ti.task.bash_command @@ -1836,17 +1790,18 @@ def test_get_rendered_template_fields(self, dag_maker): @mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"}) @mock.patch("airflow.settings.pod_mutation_hook") - def test_render_k8s_pod_yaml(self, pod_mutation_hook, dag_maker): - with dag_maker('test_get_rendered_k8s_spec'): - task = BashOperator(task_id='op1', bash_command="{{ task.task_id }}") - dr = dag_maker.create_dagrun(run_id='test_run_id') - ti = dr.get_task_instance(task.task_id) - ti.task = task + def test_render_k8s_pod_yaml(self, pod_mutation_hook, create_task_instance): + ti = create_task_instance( + dag_id='test_render_k8s_pod_yaml', + run_id='test_run_id', + task_id='op1', + execution_date=DEFAULT_DATE, + ) expected_pod_spec = { 'metadata': { 'annotations': { - 'dag_id': 'test_get_rendered_k8s_spec', + 'dag_id': 'test_render_k8s_pod_yaml', 'execution_date': '2016-01-01T00:00:00+00:00', 'task_id': 'op1', 'try_number': '1', @@ -1854,7 +1809,7 @@ def test_render_k8s_pod_yaml(self, pod_mutation_hook, dag_maker): 'labels': { 'airflow-worker': 'worker-config', 'airflow_version': version, - 'dag_id': 'test_get_rendered_k8s_spec', + 'dag_id': 'test_render_k8s_pod_yaml', 'execution_date': '2016-01-01T00_00_00_plus_00_00', 'kubernetes_executor': 'True', 'task_id': 'op1', @@ -1870,7 +1825,7 @@ def test_render_k8s_pod_yaml(self, pod_mutation_hook, dag_maker): 'airflow', 'tasks', 'run', - 'test_get_rendered_k8s_spec', + 'test_render_k8s_pod_yaml', 'op1', 'test_run_id', '--subdir', @@ -1889,12 +1844,9 @@ def test_render_k8s_pod_yaml(self, pod_mutation_hook, dag_maker): @mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"}) @mock.patch.object(RenderedTaskInstanceFields, 'get_k8s_pod_yaml') - def test_get_rendered_k8s_spec(self, rtif_get_k8s_pod_yaml, dag_maker): + def test_get_rendered_k8s_spec(self, rtif_get_k8s_pod_yaml, create_task_instance): # Create new TI for the same Task - with dag_maker('test_get_rendered_k8s_spec'): - task = BashOperator(task_id='op1', bash_command="{{ task.task_id }}") - - ti = TI(task=task, execution_date=DEFAULT_DATE) + ti = create_task_instance() patcher = mock.patch.object(ti, 'render_k8s_pod_yaml', autospec=True) @@ -1917,10 +1869,9 @@ def test_get_rendered_k8s_spec(self, rtif_get_k8s_pod_yaml, dag_maker): render_k8s_pod_yaml.assert_called_once() - def test_set_state_up_for_retry(self, create_dummy_dag): - dag, op1 = create_dummy_dag('dag') + def 
test_set_state_up_for_retry(self, create_task_instance): + ti = create_task_instance(state=State.RUNNING) - ti = TI(task=op1, execution_date=timezone.utcnow(), state=State.RUNNING) start_date = timezone.utcnow() ti.start_date = start_date @@ -1930,13 +1881,13 @@ def test_set_state_up_for_retry(self, create_dummy_dag): assert ti.start_date < ti.end_date assert ti.duration > 0 - def test_refresh_from_db(self): + def test_refresh_from_db(self, create_task_instance): run_date = timezone.utcnow() expected_values = { "task_id": "test_refresh_from_db_task", "dag_id": "test_refresh_from_db_dag", - "execution_date": run_date, + "run_id": "test", "start_date": run_date + datetime.timedelta(days=1), "end_date": run_date + datetime.timedelta(days=1, seconds=1, milliseconds=234), "duration": 1.234, @@ -1968,8 +1919,7 @@ def test_refresh_from_db(self): "This prevents refresh_from_db() from missing a field." ) - operator = DummyOperator(task_id=expected_values['task_id']) - ti = TI(task=operator, execution_date=expected_values['execution_date']) + ti = create_task_instance(task_id=expected_values['task_id'], dag_id=expected_values['dag_id']) for key, expected_value in expected_values.items(): setattr(ti, key, expected_value) with create_session() as session: @@ -1980,7 +1930,7 @@ def test_refresh_from_db(self): mock_task.task_id = expected_values["task_id"] mock_task.dag_id = expected_values["dag_id"] - ti = TI(task=mock_task, execution_date=run_date) + ti = TI(task=mock_task, run_id="test") ti.refresh_from_db() for key, expected_value in expected_values.items(): assert hasattr(ti, key), f"Key {key} is missing in the TaskInstance." @@ -1988,6 +1938,21 @@ def test_refresh_from_db(self): getattr(ti, key) == expected_value ), f"Key: {key} had different values. Make sure it loads it in the refresh refresh_from_db()" + def test_operator_field_with_serialization(self, create_task_instance): + + ti = create_task_instance() + assert ti.task.task_type == 'DummyOperator' + + # Verify that ti.operator field renders correctly "without" Serialization + assert ti.operator == "DummyOperator" + + serialized_op = SerializedBaseOperator.serialize_operator(ti.task) + deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op) + assert deserialized_op.task_type == 'DummyOperator' + # Verify that ti.operator field renders correctly "with" Serialization + ser_ti = TI(task=deserialized_op, run_id=None) + assert ser_ti.operator == "DummyOperator" + @pytest.mark.parametrize("pool_override", [None, "test_pool2"]) def test_refresh_from_task(pool_override): @@ -2001,7 +1966,7 @@ def test_refresh_from_task(pool_override): retries=30, executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}}, ) - ti = TI(task, execution_date=pendulum.datetime(2020, 1, 1)) + ti = TI(task, run_id=None) ti.refresh_from_task(task, pool_override=pool_override) assert ti.queue == task.queue @@ -2044,17 +2009,16 @@ def teardown_method(self) -> None: [ # Expected queries, mark_success (12, False), - (7, True), + (5, True), ], ) - def test_execute_queries_count(self, expected_query_count, mark_success, create_dummy_dag): - _, task = create_dummy_dag() - with create_session() as session: - - ti = TI(task=task, execution_date=datetime.datetime.now()) - ti.state = State.RUNNING + @provide_session + def test_execute_queries_count( + self, expected_query_count, mark_success, create_task_instance, session=None + ): + ti = create_task_instance(session=session, state=State.RUNNING) + assert ti.dag_run - session.merge(ti) # an extra 
query is fired in RenderedTaskInstanceFields.delete_old_records # for other DBs. delete_old_records is called only when mark_success is False expected_query_count_based_on_db = ( @@ -2063,38 +2027,24 @@ def test_execute_queries_count(self, expected_query_count, mark_success, create_ else expected_query_count ) + session.flush() + with assert_queries_count(expected_query_count_based_on_db): - ti._run_raw_task(mark_success=mark_success) + ti._run_raw_task(mark_success=mark_success, session=session) - def test_execute_queries_count_store_serialized(self, create_dummy_dag): - _, task = create_dummy_dag() - with create_session() as session: - ti = TI(task=task, execution_date=datetime.datetime.now()) - ti.state = State.RUNNING + @provide_session + def test_execute_queries_count_store_serialized(self, create_task_instance, session=None): + ti = create_task_instance(session=session, state=State.RUNNING) + assert ti.dag_run - session.merge(ti) # an extra query is fired in RenderedTaskInstanceFields.delete_old_records # for other DBs - expected_query_count_based_on_db = 13 if session.bind.dialect.name == "mssql" else 12 - - with assert_queries_count(expected_query_count_based_on_db): - ti._run_raw_task() + expected_query_count_based_on_db = 5 - def test_operator_field_with_serialization(self, create_dummy_dag): + session.flush() - _, task = create_dummy_dag() - assert task.task_type == 'DummyOperator' - - # Verify that ti.operator field renders correctly "without" Serialization - ti = TI(task=task, execution_date=datetime.datetime.now()) - assert ti.operator == "DummyOperator" - - serialized_op = SerializedBaseOperator.serialize_operator(task) - deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op) - assert deserialized_op.task_type == 'DummyOperator' - # Verify that ti.operator field renders correctly "with" Serialization - ser_ti = TI(task=deserialized_op, execution_date=datetime.datetime.now()) - assert ser_ti.operator == "DummyOperator" + with assert_queries_count(expected_query_count_based_on_db): + ti._run_raw_task(session) @pytest.mark.parametrize("mode", ["poke", "reschedule"]) @@ -2116,7 +2066,8 @@ def timeout(): retries=retries, mode=mode, ) - ti = TI(task=task, execution_date=timezone.utcnow()) + ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] + ti.task = task with pytest.raises(AirflowSensorTimeout): ti.run() diff --git a/tests/models/test_trigger.py b/tests/models/test_trigger.py index fc4d02497a9fa..aacfa8859e150 100644 --- a/tests/models/test_trigger.py +++ b/tests/models/test_trigger.py @@ -42,7 +42,7 @@ def clear_db(session): session.commit() -def test_clean_unused(session): +def test_clean_unused(session, create_task_instance): """ Tests that unused triggers (those with no task instances referencing them) are cleaned out automatically. 
@@ -60,11 +60,14 @@ def test_clean_unused(session): session.commit() assert session.query(Trigger).count() == 3 # Tie one to a fake TaskInstance that is not deferred, and one to one that is - fake_task = DummyOperator(task_id="fake") - task_instance = TaskInstance(task=fake_task, execution_date=timezone.utcnow(), state=State.DEFERRED) + task_instance = create_task_instance( + session=session, task_id="fake", state=State.DEFERRED, execution_date=timezone.utcnow() + ) task_instance.trigger_id = trigger1.id session.add(task_instance) - task_instance = TaskInstance(task=fake_task, execution_date=timezone.utcnow(), state=State.SUCCESS) + fake_task = DummyOperator(task_id="fake2", dag=task_instance.task.dag) + task_instance = TaskInstance(task=fake_task, run_id=task_instance.run_id) + task_instance.state = State.SUCCESS task_instance.trigger_id = trigger2.id session.add(task_instance) session.commit() @@ -74,7 +77,7 @@ def test_clean_unused(session): assert session.query(Trigger).one().id == trigger1.id -def test_submit_event(session): +def test_submit_event(session, create_task_instance): """ Tests that events submitted to a trigger re-wake their dependent task instances. @@ -85,11 +88,11 @@ def test_submit_event(session): session.add(trigger) session.commit() # Make a TaskInstance that's deferred and waiting on it - fake_task = DummyOperator(task_id="fake") - task_instance = TaskInstance(task=fake_task, execution_date=timezone.utcnow(), state=State.DEFERRED) + task_instance = create_task_instance( + session=session, execution_date=timezone.utcnow(), state=State.DEFERRED + ) task_instance.trigger_id = trigger.id task_instance.next_kwargs = {"cheesecake": True} - session.add(task_instance) session.commit() # Call submit_event Trigger.submit_event(trigger.id, TriggerEvent(42), session=session) @@ -99,7 +102,7 @@ def test_submit_event(session): assert updated_task_instance.next_kwargs == {"event": 42, "cheesecake": True} -def test_submit_failure(session): +def test_submit_failure(session, create_task_instance): """ Tests that failures submitted to a trigger fail their dependent task instances. 
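(Editorial aside; the example below is not part of the patch. The trigger tests above and below replace `TaskInstance(task=..., execution_date=..., state=State.DEFERRED)` with task instances attached to a DagRun, matching the Airflow 2.2 model in which a TaskInstance is keyed by `run_id`. A minimal sketch of that construction using plain Airflow APIs follows; it assumes an initialized Airflow 2.2+ metadata database, and the dag/task/run ids are made up purely for illustration.)

```python
from airflow.models import DAG, TaskInstance
from airflow.operators.dummy import DummyOperator
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.types import DagRunType

# Hypothetical dag/task ids, used only for illustration.
with DAG(dag_id="example_trigger_dag", start_date=timezone.datetime(2021, 1, 1)) as dag:
    op = DummyOperator(task_id="example_task")

with create_session() as session:
    # The DagRun is created first -- it carries the execution_date now.
    dr = dag.create_dagrun(
        run_id="manual__example",
        run_type=DagRunType.MANUAL,
        execution_date=timezone.utcnow(),
        state=State.RUNNING,
        session=session,
    )
    # The TaskInstance is tied to the run via run_id instead of an execution_date.
    ti = TaskInstance(task=op, run_id=dr.run_id, state=State.DEFERRED)
    session.add(ti)
```

This mirrors what the `create_task_instance` fixture used in these tests appears to do: every TaskInstance it returns already has a DagRun behind it (hence the `assert ti.dag_run` checks elsewhere in the patch).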
@@ -110,10 +113,10 @@ def test_submit_failure(session): session.add(trigger) session.commit() # Make a TaskInstance that's deferred and waiting on it - fake_task = DummyOperator(task_id="fake") - task_instance = TaskInstance(task=fake_task, execution_date=timezone.utcnow(), state=State.DEFERRED) + task_instance = create_task_instance( + task_id="fake", execution_date=timezone.utcnow(), state=State.DEFERRED + ) task_instance.trigger_id = trigger.id - session.add(task_instance) session.commit() # Call submit_event Trigger.submit_failure(trigger.id, session=session) diff --git a/tests/operators/test_latest_only_operator.py b/tests/operators/test_latest_only_operator.py index ee0a0a4f0044c..90714afb71fb3 100644 --- a/tests/operators/test_latest_only_operator.py +++ b/tests/operators/test_latest_only_operator.py @@ -42,8 +42,9 @@ def get_task_instances(task_id): session = settings.Session() return ( session.query(TaskInstance) + .join(TaskInstance.dag_run) .filter(TaskInstance.task_id == task_id) - .order_by(TaskInstance.execution_date) + .order_by(DagRun.execution_date) .all() ) diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py index 9357affb9fddc..6ca821ab83769 100644 --- a/tests/operators/test_python.py +++ b/tests/operators/test_python.py @@ -344,60 +344,6 @@ def tearDown(self): session.query(DagRun).delete() session.query(TI).delete() - def test_without_dag_run(self): - """This checks the defensive against non existent tasks in a dag run""" - branch_op = BranchPythonOperator( - task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1' - ) - self.branch_1.set_upstream(branch_op) - self.branch_2.set_upstream(branch_op) - self.dag.clear() - - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - with create_session() as session: - tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE) - - for ti in tis: - if ti.task_id == 'make_choice': - assert ti.state == State.SUCCESS - elif ti.task_id == 'branch_1': - # should exist with state None - assert ti.state == State.NONE - elif ti.task_id == 'branch_2': - assert ti.state == State.SKIPPED - else: - raise ValueError(f'Invalid task id {ti.task_id} found!') - - def test_branch_list_without_dag_run(self): - """This checks if the BranchPythonOperator supports branching off to a list of tasks.""" - branch_op = BranchPythonOperator( - task_id='make_choice', dag=self.dag, python_callable=lambda: ['branch_1', 'branch_2'] - ) - self.branch_1.set_upstream(branch_op) - self.branch_2.set_upstream(branch_op) - self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag) - self.branch_3.set_upstream(branch_op) - self.dag.clear() - - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - with create_session() as session: - tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE) - - expected = { - "make_choice": State.SUCCESS, - "branch_1": State.NONE, - "branch_2": State.NONE, - "branch_3": State.SKIPPED, - } - - for ti in tis: - if ti.task_id in expected: - assert ti.state == expected[ti.task_id] - else: - raise ValueError(f'Invalid task id {ti.task_id} found!') - def test_with_dag_run(self): branch_op = BranchPythonOperator( task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1' @@ -581,54 +527,6 @@ def tearDown(self): session.query(DagRun).delete() session.query(TI).delete() - def test_without_dag_run(self): - """This checks the defensive against non existent tasks in a dag run""" - value = False - 
dag = DAG( - 'shortcircuit_operator_test_without_dag_run', - default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE}, - schedule_interval=INTERVAL, - ) - short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: value) - branch_1 = DummyOperator(task_id='branch_1', dag=dag) - branch_1.set_upstream(short_op) - branch_2 = DummyOperator(task_id='branch_2', dag=dag) - branch_2.set_upstream(branch_1) - upstream = DummyOperator(task_id='upstream', dag=dag) - upstream.set_downstream(short_op) - dag.clear() - - short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - with create_session() as session: - tis = session.query(TI).filter(TI.dag_id == dag.dag_id, TI.execution_date == DEFAULT_DATE) - - for ti in tis: - if ti.task_id == 'make_choice': - assert ti.state == State.SUCCESS - elif ti.task_id == 'upstream': - # should not exist - raise ValueError(f'Invalid task id {ti.task_id} found!') - elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2': - assert ti.state == State.SKIPPED - else: - raise ValueError(f'Invalid task id {ti.task_id} found!') - - value = True - dag.clear() - - short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - for ti in tis: - if ti.task_id == 'make_choice': - assert ti.state == State.SUCCESS - elif ti.task_id == 'upstream': - # should not exist - raise ValueError(f'Invalid task id {ti.task_id} found!') - elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2': - assert ti.state == State.NONE - else: - raise ValueError(f'Invalid task id {ti.task_id} found!') - def test_with_dag_run(self): value = False dag = DAG( @@ -1144,11 +1042,11 @@ def test_get_context_in_old_style_context_task(self): ("join", [State.SUCCESS, State.SKIPPED, State.SUCCESS]), ], ) -def test_empty_branch(choice, expected_states): +def test_empty_branch(dag_maker, choice, expected_states): """ Tests that BranchPythonOperator handles empty branches properly. """ - with DAG( + with dag_maker( 'test_empty_branch', start_date=DEFAULT_DATE, ) as dag: @@ -1160,13 +1058,14 @@ def test_empty_branch(choice, expected_states): task1 >> join dag.clear(start_date=DEFAULT_DATE) + dag_run = dag_maker.create_dagrun() task_ids = ["branch", "task1", "join"] + tis = {ti.task_id: ti for ti in dag_run.task_instances} - tis = {} - for task_id in task_ids: - task_instance = TI(dag.get_task(task_id), execution_date=DEFAULT_DATE) - tis[task_id] = task_instance + for task_id in task_ids: # Mimic the specific order the scheduling would run the tests. + task_instance = tis[task_id] + task_instance.refresh_from_task(dag.get_task(task_id)) task_instance.run() def get_state(ti): diff --git a/tests/operators/test_subdag_operator.py b/tests/operators/test_subdag_operator.py index dcdb508e51784..1ea387eaeaad0 100644 --- a/tests/operators/test_subdag_operator.py +++ b/tests/operators/test_subdag_operator.py @@ -16,12 +16,10 @@ # specific language governing permissions and limitations # under the License. 
-import unittest from unittest import mock from unittest.mock import Mock import pytest -from parameterized import parameterized import airflow from airflow.exceptions import AirflowException @@ -36,14 +34,11 @@ DEFAULT_DATE = datetime(2016, 1, 1) -default_args = dict( - owner='airflow', - start_date=DEFAULT_DATE, -) +default_args = {"start_date": DEFAULT_DATE} -class TestSubDagOperator(unittest.TestCase): - def setUp(self): +class TestSubDagOperator: + def setup_method(self): clear_db_runs() self.dag_run_running = DagRun() self.dag_run_running.state = State.RUNNING @@ -52,6 +47,9 @@ def setUp(self): self.dag_run_failed = DagRun() self.dag_run_failed.state = State.FAILED + def teardown_class(self): + clear_db_runs() + def test_subdag_name(self): """ Subdag names must be {parent_dag}.{subdag task} @@ -256,31 +254,28 @@ def test_execute_skip_if_dagrun_success(self): subdag.create_dagrun.assert_not_called() assert 3 == len(subdag_task._get_dagrun.mock_calls) - def test_rerun_failed_subdag(self): + def test_rerun_failed_subdag(self, dag_maker): """ When there is an existing DagRun with failed state, reset the DagRun and the corresponding TaskInstances """ - dag = DAG('parent', default_args=default_args) - subdag = DAG('parent.test', default_args=default_args) - subdag_task = SubDagOperator(task_id='test', subdag=subdag, dag=dag, poke_interval=1) - dummy_task = DummyOperator(task_id='dummy', dag=subdag) - with create_session() as session: - dummy_task_instance = TaskInstance( - task=dummy_task, + with dag_maker('parent.test', default_args=default_args, session=session) as subdag: + dummy_task = DummyOperator(task_id='dummy') + sub_dagrun = dag_maker.create_dagrun( + run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, state=State.FAILED, + external_trigger=True, ) - session.add(dummy_task_instance) - session.commit() - sub_dagrun = subdag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=DEFAULT_DATE, - state=State.FAILED, - external_trigger=True, - ) + (dummy_task_instance,) = sub_dagrun.task_instances + dummy_task_instance.refresh_from_task(dummy_task) + dummy_task_instance.state == State.FAILED + + with dag_maker('parent', default_args=default_args, session=session): + subdag_task = SubDagOperator(task_id='test', subdag=subdag, poke_interval=1) + dag_maker.create_dagrun(execution_date=DEFAULT_DATE, run_type=DagRunType.SCHEDULED) subdag_task._reset_dag_run_and_task_instances(sub_dagrun, execution_date=DEFAULT_DATE) @@ -290,43 +285,54 @@ def test_rerun_failed_subdag(self): sub_dagrun.refresh_from_db() assert sub_dagrun.state == State.RUNNING - @parameterized.expand( + @pytest.mark.parametrize( + "propagate_option, states, skip_parent", [ (SkippedStatePropagationOptions.ALL_LEAVES, [State.SKIPPED, State.SKIPPED], True), (SkippedStatePropagationOptions.ALL_LEAVES, [State.SKIPPED, State.SUCCESS], False), (SkippedStatePropagationOptions.ANY_LEAF, [State.SKIPPED, State.SUCCESS], True), (SkippedStatePropagationOptions.ANY_LEAF, [State.FAILED, State.SKIPPED], True), (None, [State.SKIPPED, State.SKIPPED], False), - ] + ], ) @mock.patch('airflow.operators.subdag.SubDagOperator.skip') @mock.patch('airflow.operators.subdag.get_task_instance') def test_subdag_with_propagate_skipped_state( - self, propagate_option, states, skip_parent, mock_get_task_instance, mock_skip + self, + mock_get_task_instance, + mock_skip, + dag_maker, + propagate_option, + states, + skip_parent, ): """ Tests that skipped state of leaf tasks propagates to the parent dag. 
Note that the skipped state propagation only takes effect when the dagrun's state is SUCCESS. """ - dag = DAG('parent', default_args=default_args) - subdag = DAG('parent.test', default_args=default_args) - subdag_task = SubDagOperator( - task_id='test', subdag=subdag, dag=dag, poke_interval=1, propagate_skipped_state=propagate_option - ) - dummy_subdag_tasks = [ - DummyOperator(task_id=f'dummy_subdag_{i}', dag=subdag) for i in range(len(states)) - ] - dummy_dag_task = DummyOperator(task_id='dummy_dag', dag=dag) - subdag_task >> dummy_dag_task + with dag_maker('parent.test', default_args=default_args) as subdag: + dummy_subdag_tasks = [DummyOperator(task_id=f'dummy_subdag_{i}') for i in range(len(states))] + dag_maker.create_dagrun(execution_date=DEFAULT_DATE) + + with dag_maker('parent', default_args=default_args): + subdag_task = SubDagOperator( + task_id='test', + subdag=subdag, + poke_interval=1, + propagate_skipped_state=propagate_option, + ) + dummy_dag_task = DummyOperator(task_id='dummy_dag') + subdag_task >> dummy_dag_task + dag_run = dag_maker.create_dagrun(execution_date=DEFAULT_DATE) + + subdag_task._get_dagrun = Mock(return_value=self.dag_run_success) - subdag_task._get_dagrun = Mock() - subdag_task._get_dagrun.return_value = self.dag_run_success mock_get_task_instance.side_effect = [ - TaskInstance(task=task, execution_date=DEFAULT_DATE, state=state) + TaskInstance(task=task, run_id=dag_run.run_id, state=state) for task, state in zip(dummy_subdag_tasks, states) ] - context = {'execution_date': DEFAULT_DATE, 'dag_run': DagRun(), 'task': subdag_task} + context = {'execution_date': DEFAULT_DATE, 'dag_run': dag_run, 'task': subdag_task} subdag_task.post_execute(context) if skip_parent: diff --git a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py index 8acc7398d2fee..db8ddcd7c6953 100644 --- a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py +++ b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py @@ -22,7 +22,7 @@ from watchtower import CloudWatchLogHandler -from airflow.models import DAG, TaskInstance +from airflow.models import DAG, DagRun, TaskInstance from airflow.operators.dummy import DummyOperator from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.utils.log.cloudwatch_task_handler import CloudwatchTaskHandler @@ -55,11 +55,13 @@ def setUp(self): self.cloudwatch_task_handler.hook date = datetime(2020, 1, 1) - dag_id = 'dag_for_testing_file_task_handler' - task_id = 'task_for_testing_file_log_handler' + dag_id = 'dag_for_testing_cloudwatch_task_handler' + task_id = 'task_for_testing_cloudwatch_log_handler' self.dag = DAG(dag_id=dag_id, start_date=date) task = DummyOperator(task_id=task_id, dag=self.dag) - self.ti = TaskInstance(task=task, execution_date=date) + dag_run = DagRun(dag_id=self.dag.dag_id, execution_date=date, run_id="test") + self.ti = TaskInstance(task=task) + self.ti.dag_run = dag_run self.ti.try_number = 1 self.ti.state = State.RUNNING diff --git a/tests/providers/amazon/aws/log/test_s3_task_handler.py b/tests/providers/amazon/aws/log/test_s3_task_handler.py index 323bef4f4e636..b931242a14275 100644 --- a/tests/providers/amazon/aws/log/test_s3_task_handler.py +++ b/tests/providers/amazon/aws/log/test_s3_task_handler.py @@ -24,7 +24,7 @@ import pytest from botocore.exceptions import ClientError -from airflow.models import DAG, TaskInstance +from airflow.models import DAG, DagRun, TaskInstance from airflow.operators.dummy import
DummyOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.log.s3_task_handler import S3TaskHandler @@ -58,9 +58,11 @@ def setUp(self): assert self.s3_task_handler.hook is not None date = datetime(2016, 1, 1) - self.dag = DAG('dag_for_testing_file_task_handler', start_date=date) - task = DummyOperator(task_id='task_for_testing_file_log_handler', dag=self.dag) - self.ti = TaskInstance(task=task, execution_date=date) + self.dag = DAG('dag_for_testing_s3_task_handler', start_date=date) + task = DummyOperator(task_id='task_for_testing_s3_log_handler', dag=self.dag) + dag_run = DagRun(dag_id=self.dag.dag_id, execution_date=date, run_id="test") + self.ti = TaskInstance(task=task) + self.ti.dag_run = dag_run self.ti.try_number = 1 self.ti.state = State.RUNNING self.addCleanup(self.dag.clear) diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index 263db769cced0..97ec4d205065d 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -20,7 +20,7 @@ import pytest -from airflow.models import DAG, TaskInstance +from airflow.models import DAG, DagRun, TaskInstance from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook from airflow.providers.amazon.aws.operators.athena import AWSAthenaOperator from airflow.utils import timezone @@ -205,8 +205,10 @@ def test_hook_run_failed_query_with_max_tries(self, mock_conn, mock_run_query, m @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("SUCCESS",)) @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID) @mock.patch.object(AWSAthenaHook, 'get_conn') - def test_xcom_push_and_pull(self, mock_conn, mock_run_query, mock_check_query_status): - ti = TaskInstance(task=self.athena, execution_date=timezone.utcnow()) - ti.run() + def test_return_value(self, mock_conn, mock_run_query, mock_check_query_status): + """Test we return the right value -- that will get put in to XCom by the execution engine""" + dag_run = DagRun(dag_id=self.dag.dag_id, execution_date=timezone.utcnow(), run_id="test") + ti = TaskInstance(task=self.athena) + ti.dag_run = dag_run - assert ti.xcom_pull(task_ids='test_aws_athena_operator') == ATHENA_QUERY_ID + assert self.athena.execute(ti.get_template_context()) == ATHENA_QUERY_ID diff --git a/tests/providers/amazon/aws/operators/test_datasync.py b/tests/providers/amazon/aws/operators/test_datasync.py index 6a4f9111ca7ae..774e0353b442e 100644 --- a/tests/providers/amazon/aws/operators/test_datasync.py +++ b/tests/providers/amazon/aws/operators/test_datasync.py @@ -22,7 +22,7 @@ from moto import mock_datasync from airflow.exceptions import AirflowException -from airflow.models import DAG, TaskInstance +from airflow.models import DAG, DagRun, TaskInstance from airflow.providers.amazon.aws.hooks.datasync import AWSDataSyncHook from airflow.providers.amazon.aws.operators.datasync import AWSDataSyncOperator from airflow.utils import timezone @@ -306,16 +306,17 @@ def test_execute_specific_task(self, mock_get_conn): # ### Check mocks: mock_get_conn.assert_called() - def test_xcom_push(self, mock_get_conn): + def test_return_value(self, mock_get_conn): + """Test we return the right value -- that will get put in to XCom by the execution engine""" # ### Set up mocks: mock_get_conn.return_value = self.client # ### Begin tests: self.set_up_operator() - ti = TaskInstance(task=self.datasync, execution_date=timezone.utcnow()) - 
ti.run() - xcom_result = ti.xcom_pull(task_ids=self.datasync.task_id, key="return_value") - assert xcom_result is not None + dag_run = DagRun(dag_id=self.dag.dag_id, execution_date=timezone.utcnow(), run_id="test") + ti = TaskInstance(task=self.datasync) + ti.dag_run = dag_run + assert self.datasync.execute(ti.get_template_context()) is not None # ### Check mocks: mock_get_conn.assert_called() @@ -497,16 +498,18 @@ def test_execute_specific_task(self, mock_get_conn): # ### Check mocks: mock_get_conn.assert_called() - def test_xcom_push(self, mock_get_conn): + def test_return_value(self, mock_get_conn): + """Test we return the right value -- that will get put in to XCom by the execution engine""" # ### Set up mocks: mock_get_conn.return_value = self.client # ### Begin tests: self.set_up_operator() - ti = TaskInstance(task=self.datasync, execution_date=timezone.utcnow()) - ti.run() - pushed_task_arn = ti.xcom_pull(task_ids=self.datasync.task_id, key="return_value")["TaskArn"] - assert pushed_task_arn == self.task_arn + dag_run = DagRun(dag_id=self.dag.dag_id, execution_date=timezone.utcnow(), run_id="test") + ti = TaskInstance(task=self.datasync) + ti.dag_run = dag_run + result = self.datasync.execute(ti.get_template_context()) + assert result["TaskArn"] == self.task_arn # ### Check mocks: mock_get_conn.assert_called() @@ -597,16 +600,18 @@ def test_execute_specific_task(self, mock_get_conn): # ### Check mocks: mock_get_conn.assert_called() - def test_xcom_push(self, mock_get_conn): + def test_return_value(self, mock_get_conn): + """Test we return the right value -- that will get put in to XCom by the execution engine""" # ### Set up mocks: mock_get_conn.return_value = self.client # ### Begin tests: self.set_up_operator() - ti = TaskInstance(task=self.datasync, execution_date=timezone.utcnow()) - ti.run() - pushed_task_arn = ti.xcom_pull(task_ids=self.datasync.task_id, key="return_value")["TaskArn"] - assert pushed_task_arn == self.task_arn + dag_run = DagRun(dag_id=self.dag.dag_id, execution_date=timezone.utcnow(), run_id="test") + ti = TaskInstance(task=self.datasync) + ti.dag_run = dag_run + result = self.datasync.execute(ti.get_template_context()) + assert result["TaskArn"] == self.task_arn # ### Check mocks: mock_get_conn.assert_called() @@ -750,16 +755,17 @@ def test_execute_specific_task(self, mock_get_conn): # ### Check mocks: mock_get_conn.assert_called() - def test_xcom_push(self, mock_get_conn): + def test_return_value(self, mock_get_conn): + """Test we return the right value -- that will get put in to XCom by the execution engine""" # ### Set up mocks: mock_get_conn.return_value = self.client # ### Begin tests: self.set_up_operator() - ti = TaskInstance(task=self.datasync, execution_date=timezone.utcnow()) - ti.run() - xcom_result = ti.xcom_pull(task_ids=self.datasync.task_id, key="return_value") - assert xcom_result is not None + dag_run = DagRun(dag_id=self.dag.dag_id, execution_date=timezone.utcnow(), run_id="test") + ti = TaskInstance(task=self.datasync) + ti.dag_run = dag_run + assert self.datasync.execute(ti.get_template_context()) is not None # ### Check mocks: mock_get_conn.assert_called() @@ -843,15 +849,17 @@ def test_execute_specific_task(self, mock_get_conn): # ### Check mocks: mock_get_conn.assert_called() - def test_xcom_push(self, mock_get_conn): + def test_return_value(self, mock_get_conn): + """Test we return the right value -- that will get put in to XCom by the execution engine""" # ### Set up mocks: mock_get_conn.return_value = self.client # ### Begin tests: 
self.set_up_operator() - ti = TaskInstance(task=self.datasync, execution_date=timezone.utcnow()) - ti.run() - pushed_task_arn = ti.xcom_pull(task_ids=self.datasync.task_id, key="return_value")["TaskArn"] - assert pushed_task_arn == self.task_arn + dag_run = DagRun(dag_id=self.dag.dag_id, execution_date=timezone.utcnow(), run_id="test") + ti = TaskInstance(task=self.datasync) + ti.dag_run = dag_run + result = self.datasync.execute(ti.get_template_context()) + assert result["TaskArn"] == self.task_arn # ### Check mocks: mock_get_conn.assert_called() diff --git a/tests/providers/amazon/aws/operators/test_dms_describe_tasks.py b/tests/providers/amazon/aws/operators/test_dms_describe_tasks.py index 587620cd9cabd..a099ebe494187 100644 --- a/tests/providers/amazon/aws/operators/test_dms_describe_tasks.py +++ b/tests/providers/amazon/aws/operators/test_dms_describe_tasks.py @@ -18,7 +18,7 @@ import unittest from unittest import mock -from airflow.models import DAG, TaskInstance +from airflow.models import DAG, DagRun, TaskInstance from airflow.providers.amazon.aws.hooks.dms import DmsHook from airflow.providers.amazon.aws.operators.dms_describe_tasks import DmsDescribeTasksOperator from airflow.utils import timezone @@ -88,9 +88,10 @@ def test_describe_tasks_return_value(self, mock_conn, mock_describe_replication_ task_id='describe_tasks', dag=self.dag, describe_tasks_kwargs={'Filters': [FILTER]} ) - ti = TaskInstance(task=describe_task, execution_date=timezone.utcnow()) - ti.run() - marker, response = ti.xcom_pull(task_ids=describe_task.task_id, key="return_value") + dag_run = DagRun(dag_id=self.dag.dag_id, execution_date=timezone.utcnow(), run_id="test") + ti = TaskInstance(task=describe_task) + ti.dag_run = dag_run + marker, response = describe_task.execute(ti.get_template_context()) assert marker is None assert response == MOCK_RESPONSE diff --git a/tests/providers/amazon/aws/operators/test_emr_add_steps.py b/tests/providers/amazon/aws/operators/test_emr_add_steps.py index 77aef3e314159..f51dc0526e76b 100644 --- a/tests/providers/amazon/aws/operators/test_emr_add_steps.py +++ b/tests/providers/amazon/aws/operators/test_emr_add_steps.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. 
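# Not part of the patch: a condensed sketch of the pattern the provider operator tests above
# switch to. Rather than ti.run() followed by xcom_pull(), the operator's return value is
# asserted directly, and the hand-built TaskInstance is attached to an in-memory DagRun so
# get_template_context() still works. `build_ti` and EXPECTED_RESULT are illustrative names,
# not part of Airflow.
from airflow.models import DagRun, TaskInstance
from airflow.utils import timezone


def build_ti(operator):
    # TaskInstance no longer accepts execution_date; the date lives on the DagRun.
    dag_run = DagRun(dag_id=operator.dag_id, execution_date=timezone.utcnow(), run_id="test")
    ti = TaskInstance(task=operator)
    ti.dag_run = dag_run
    return ti


# usage inside a test, given some operator instance `op`:
#     ti = build_ti(op)
#     assert op.execute(ti.get_template_context()) == EXPECTED_RESULT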
+import json import os import unittest from datetime import timedelta @@ -25,9 +26,7 @@ from jinja2 import StrictUndefined from airflow.exceptions import AirflowException -from airflow.models import TaskInstance -from airflow.models.dag import DAG -from airflow.operators.dummy import DummyOperator +from airflow.models import DAG, DagRun, TaskInstance from airflow.providers.amazon.aws.operators.emr_add_steps import EmrAddStepsOperator from airflow.utils import timezone from tests.test_utils import AIRFLOW_MAIN_FOLDER @@ -80,7 +79,9 @@ def test_init(self): assert self.operator.aws_conn_id == 'aws_default' def test_render_template(self): - ti = TaskInstance(self.operator, DEFAULT_DATE) + dag_run = DagRun(dag_id=self.operator.dag.dag_id, execution_date=DEFAULT_DATE, run_id="test") + ti = TaskInstance(task=self.operator) + ti.dag_run = dag_run ti.render_templates() expected_args = [ @@ -100,45 +101,6 @@ def test_render_template(self): assert self.operator.steps == expected_args - def test_render_template_2(self): - dag = DAG(dag_id='test_xcom', default_args=self.args) - - xcom_steps = [ - { - 'Name': 'test_step1', - 'ActionOnFailure': 'CONTINUE', - 'HadoopJarStep': {'Jar': 'command-runner.jar', 'Args': ['/usr/lib/spark/bin/run-example1']}, - }, - { - 'Name': 'test_step2', - 'ActionOnFailure': 'CONTINUE', - 'HadoopJarStep': {'Jar': 'command-runner.jar', 'Args': ['/usr/lib/spark/bin/run-example2']}, - }, - ] - - make_steps = DummyOperator(task_id='make_steps', dag=dag, owner='airflow') - execution_date = timezone.utcnow() - ti1 = TaskInstance(task=make_steps, execution_date=execution_date) - ti1.xcom_push(key='steps', value=xcom_steps) - - self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN - - test_task = EmrAddStepsOperator( - task_id='test_task', - job_flow_id='j-8989898989', - aws_conn_id='aws_default', - steps="{{ ti.xcom_pull(task_ids='make_steps',key='steps') }}", - dag=dag, - ) - - with patch('boto3.session.Session', self.boto3_session_mock): - ti = TaskInstance(task=test_task, execution_date=execution_date) - ti.run() - - self.emr_client_mock.add_job_flow_steps.assert_called_once_with( - JobFlowId='j-8989898989', Steps=xcom_steps - ) - def test_render_template_from_file(self): dag = DAG( dag_id='test_file', @@ -155,8 +117,6 @@ def test_render_template_from_file(self): } ] - execution_date = timezone.utcnow() - self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN test_task = EmrAddStepsOperator( @@ -165,11 +125,18 @@ def test_render_template_from_file(self): aws_conn_id='aws_default', steps='steps.j2.json', dag=dag, + do_xcom_push=False, ) + dag_run = DagRun(dag_id=dag.dag_id, execution_date=timezone.utcnow(), run_id="test") + ti = TaskInstance(task=test_task) + ti.dag_run = dag_run + ti.render_templates() + + assert json.loads(test_task.steps) == file_steps + # String in job_flow_overrides (i.e. 
loaded from a file) is not "parsed" until inside execute() with patch('boto3.session.Session', self.boto3_session_mock): - ti = TaskInstance(task=test_task, execution_date=execution_date) - ti.run() + test_task.execute(None) self.emr_client_mock.add_job_flow_steps.assert_called_once_with( JobFlowId='j-8989898989', Steps=file_steps diff --git a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py index 7e8ad2c72b211..73b6f22a8bbac 100644 --- a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py +++ b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py @@ -24,8 +24,7 @@ from jinja2 import StrictUndefined -from airflow.models import TaskInstance -from airflow.models.dag import DAG +from airflow.models import DAG, DagRun, TaskInstance from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator from airflow.utils import timezone from tests.test_utils import AIRFLOW_MAIN_FOLDER @@ -81,7 +80,9 @@ def test_init(self): def test_render_template(self): self.operator.job_flow_overrides = self._config - ti = TaskInstance(self.operator, DEFAULT_DATE) + dag_run = DagRun(dag_id=self.operator.dag_id, execution_date=DEFAULT_DATE, run_id="test") + ti = TaskInstance(task=self.operator) + ti.dag_run = dag_run ti.render_templates() expected_args = { @@ -109,7 +110,9 @@ def test_render_template_from_file(self): self.operator.job_flow_overrides = 'job.j2.json' self.operator.params = {'releaseLabel': '5.11.0'} - ti = TaskInstance(self.operator, DEFAULT_DATE) + dag_run = DagRun(dag_id=self.operator.dag_id, execution_date=DEFAULT_DATE, run_id="test") + ti = TaskInstance(task=self.operator) + ti.dag_run = dag_run ti.render_templates() self.emr_client_mock.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN @@ -117,6 +120,7 @@ def test_render_template_from_file(self): emr_session_mock.client.return_value = self.emr_client_mock boto3_session_mock = MagicMock(return_value=emr_session_mock) + # String in job_flow_overrides (i.e.
loaded from a file) is not "parsed" until inside execute() with patch('boto3.session.Session', boto3_session_mock): self.operator.execute(None) diff --git a/tests/providers/amazon/aws/sensors/test_s3_key.py b/tests/providers/amazon/aws/sensors/test_s3_key.py index a675bdf5325fc..d8d4ee4b85d99 100644 --- a/tests/providers/amazon/aws/sensors/test_s3_key.py +++ b/tests/providers/amazon/aws/sensors/test_s3_key.py @@ -24,8 +24,7 @@ from parameterized import parameterized from airflow.exceptions import AirflowException -from airflow.models import TaskInstance -from airflow.models.dag import DAG +from airflow.models import DAG, DagRun, TaskInstance from airflow.models.variable import Variable from airflow.providers.amazon.aws.sensors.s3_key import S3KeySensor, S3KeySizeSensor @@ -90,7 +89,9 @@ def test_parse_bucket_key_from_jinja(self, mock_hook): dag=dag, ) - ti = TaskInstance(task=op, execution_date=execution_date) + dag_run = DagRun(dag_id=dag.dag_id, execution_date=execution_date, run_id="test") + ti = TaskInstance(task=op) + ti.dag_run = dag_run context = ti.get_template_context() ti.render_templates(context) diff --git a/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py b/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py index 10a8e20357c79..bac5eac1a3f12 100644 --- a/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py @@ -18,8 +18,7 @@ import unittest from unittest import mock -from airflow.models import TaskInstance -from airflow.models.dag import DAG +from airflow.models import DAG, DagRun, TaskInstance from airflow.providers.amazon.aws.transfers.mongo_to_s3 import MongoToS3Operator from airflow.utils import timezone @@ -76,7 +75,9 @@ def test_template_field_overrides(self): ) def test_render_template(self): - ti = TaskInstance(self.mock_operator, DEFAULT_DATE) + dag_run = DagRun(dag_id=self.mock_operator.dag_id, execution_date=DEFAULT_DATE, run_id="test") + ti = TaskInstance(task=self.mock_operator) + ti.dag_run = dag_run ti.render_templates() expected_rendered_template = {'$lt': '2017-01-01T00:00:00+00:00Z'} diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py b/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py index e7255c328b4d9..7d4330235fa11 100644 --- a/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py +++ b/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py @@ -20,10 +20,9 @@ import boto3 from moto import mock_s3 -from airflow.models import DAG, TaskInstance +from airflow.models import DAG from airflow.providers.amazon.aws.transfers.s3_to_sftp import S3ToSFTPOperator from airflow.providers.ssh.operators.ssh import SSHOperator -from airflow.utils import timezone from airflow.utils.timezone import datetime from tests.test_utils.config import conf_vars @@ -117,11 +116,8 @@ def test_s3_to_sftp_operation(self): dag=self.dag, ) assert check_file_task is not None - ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow()) - ti3.run() - assert ti3.xcom_pull( - task_ids='test_check_file', key='return_value' - ).strip() == test_remote_file_content.encode('utf-8') + result = check_file_task.execute(None) + assert result.strip() == test_remote_file_content.encode('utf-8') # Clean up after finishing with test conn.delete_object(Bucket=self.s3_bucket, Key=self.s3_key) @@ -138,8 +134,7 @@ def delete_remote_resource(self): dag=self.dag, ) assert remove_file_task is not None - ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow()) - ti3.run()
+ remove_file_task.execute(None) def tearDown(self): self.delete_remote_resource() diff --git a/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py b/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py index 409809e059a85..8a62bf2b5bb33 100644 --- a/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py @@ -21,12 +21,11 @@ import boto3 from moto import mock_s3 -from airflow.models import DAG, TaskInstance +from airflow.models import DAG from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.transfers.sftp_to_s3 import SFTPToS3Operator from airflow.providers.ssh.hooks.ssh import SSHHook from airflow.providers.ssh.operators.ssh import SSHOperator -from airflow.utils import timezone from airflow.utils.timezone import datetime from tests.test_utils.config import conf_vars @@ -85,8 +84,7 @@ def test_sftp_to_s3_operation(self): dag=self.dag, ) assert create_file_task is not None - ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) - ti1.run() + create_file_task.execute(None) # Test for creation of s3 bucket conn = boto3.client('s3') diff --git a/tests/providers/apache/druid/operators/test_druid.py b/tests/providers/apache/druid/operators/test_druid.py index c366e695e4885..55867d5a34fa3 100644 --- a/tests/providers/apache/druid/operators/test_druid.py +++ b/tests/providers/apache/druid/operators/test_druid.py @@ -16,76 +16,60 @@ # specific language governing permissions and limitations # under the License. # +import json -import os -import unittest -from tempfile import NamedTemporaryFile - -from airflow.models import TaskInstance -from airflow.models.dag import DAG from airflow.providers.apache.druid.operators.druid import DruidOperator from airflow.utils import timezone DEFAULT_DATE = timezone.datetime(2017, 1, 1) - -class TestDruidOperator(unittest.TestCase): - def setUp(self): - args = {'owner': 'airflow', 'start_date': timezone.datetime(2017, 1, 1)} - self.dag = DAG('test_dag_id', default_args=args) - self.json_index_str = ''' - { - "type": "{{ params.index_type }}", - "datasource": "{{ params.datasource }}", - "spec": { - "dataSchema": { - "granularitySpec": { - "intervals": ["{{ ds }}/{{ macros.ds_add(ds, 1) }}"] - } - } - } - } - ''' - self.rendered_index_str = ''' - { - "type": "index_hadoop", - "datasource": "datasource_prd", - "spec": { - "dataSchema": { - "granularitySpec": { - "intervals": ["2017-01-01/2017-01-02"] - } - } +JSON_INDEX_STR = """ + { + "type": "{{ params.index_type }}", + "datasource": "{{ params.datasource }}", + "spec": { + "dataSchema": { + "granularitySpec": { + "intervals": ["{{ ds }}/{{ macros.ds_add(ds, 1) }}"] } } - ''' + } + } +""" - def test_render_template(self): +RENDERED_INDEX = { + "type": "index_hadoop", + "datasource": "datasource_prd", + "spec": {"dataSchema": {"granularitySpec": {"intervals": ["2017-01-01/2017-01-02"]}}}, +} + + +def test_render_template(dag_maker): + with dag_maker("test_druid_render_template", default_args={"start_date": DEFAULT_DATE}): operator = DruidOperator( - task_id='spark_submit_job', - json_index_file=self.json_index_str, - params={'index_type': 'index_hadoop', 'datasource': 'datasource_prd'}, - dag=self.dag, + task_id="spark_submit_job", + json_index_file=JSON_INDEX_STR, + params={"index_type": "index_hadoop", "datasource": "datasource_prd"}, ) - ti = TaskInstance(operator, DEFAULT_DATE) - ti.render_templates() - assert self.rendered_index_str == operator.json_index_file + 
dag_maker.create_dagrun().task_instances[0].render_templates() + assert RENDERED_INDEX == json.loads(operator.json_index_file) - def test_render_template_from_file(self): - with NamedTemporaryFile("w", suffix='.json') as f: - f.write(self.json_index_str) - f.flush() - self.dag.template_searchpath = os.path.dirname(f.name) +def test_render_template_from_file(tmp_path, dag_maker): + json_index_file = tmp_path.joinpath("json_index.json") + json_index_file.write_text(JSON_INDEX_STR) - operator = DruidOperator( - task_id='spark_submit_job', - json_index_file=f.name, - params={'index_type': 'index_hadoop', 'datasource': 'datasource_prd'}, - dag=self.dag, - ) - ti = TaskInstance(operator, DEFAULT_DATE) - ti.render_templates() + with dag_maker( + "test_druid_render_template_from_file", + template_searchpath=[str(tmp_path)], + default_args={"start_date": DEFAULT_DATE}, + ): + operator = DruidOperator( + task_id="spark_submit_job", + json_index_file=json_index_file.name, + params={"index_type": "index_hadoop", "datasource": "datasource_prd"}, + ) - assert self.rendered_index_str == operator.json_index_file + dag_maker.create_dagrun().task_instances[0].render_templates() + assert RENDERED_INDEX == json.loads(operator.json_index_file) diff --git a/tests/providers/apache/hive/operators/test_hive.py b/tests/providers/apache/hive/operators/test_hive.py index 548c96c773871..944d194dfc0b6 100644 --- a/tests/providers/apache/hive/operators/test_hive.py +++ b/tests/providers/apache/hive/operators/test_hive.py @@ -21,7 +21,7 @@ from unittest import mock from airflow.configuration import conf -from airflow.models import TaskInstance +from airflow.models import DagRun, TaskInstance from airflow.providers.apache.hive.operators.hive import HiveOperator from airflow.utils import timezone from tests.providers.apache.hive import DEFAULT_DATE, TestHiveEnvironment @@ -83,8 +83,10 @@ def test_mapred_job_name(self, mock_get_hook): mock_get_hook.return_value = mock_hook op = MockHiveOperator(task_id='test_mapred_job_name', hql=self.hql, dag=self.dag) + fake_dagrun_id = "test_mapred_job_name" fake_execution_date = timezone.datetime(2018, 6, 19) - fake_ti = TaskInstance(task=op, execution_date=fake_execution_date) + fake_ti = TaskInstance(task=op) + fake_ti.dag_run = DagRun(run_id=fake_dagrun_id, execution_date=fake_execution_date) fake_ti.hostname = 'fake_hostname' fake_context = {'ti': fake_ti} diff --git a/tests/providers/apache/kylin/operators/test_kylin_cube.py b/tests/providers/apache/kylin/operators/test_kylin_cube.py index 4ad927477378e..7d306ad84e9d2 100644 --- a/tests/providers/apache/kylin/operators/test_kylin_cube.py +++ b/tests/providers/apache/kylin/operators/test_kylin_cube.py @@ -23,7 +23,7 @@ import pytest from airflow.exceptions import AirflowException -from airflow.models import TaskInstance +from airflow.models import DagRun, TaskInstance from airflow.models.dag import DAG from airflow.providers.apache.kylin.operators.kylin_cube import KylinCubeOperator from airflow.utils import timezone @@ -166,7 +166,8 @@ def test_render_template(self): 'end_time': '1483286400000', }, ) - ti = TaskInstance(operator, DEFAULT_DATE) + ti = TaskInstance(operator, run_id="kylin_test") + ti.dag_run = DagRun(run_id="kylin_test", execution_date=DEFAULT_DATE) ti.render_templates() assert 'learn_kylin' == getattr(operator, 'project') assert 'kylin_sales_cube' == getattr(operator, 'cube') diff --git a/tests/providers/apache/spark/operators/test_spark_submit.py b/tests/providers/apache/spark/operators/test_spark_submit.py index 
dba7a226b343e..746bcd5b6c628 100644 --- a/tests/providers/apache/spark/operators/test_spark_submit.py +++ b/tests/providers/apache/spark/operators/test_spark_submit.py @@ -20,7 +20,7 @@ import unittest from datetime import timedelta -from airflow.models import TaskInstance +from airflow.models import DagRun, TaskInstance from airflow.models.dag import DAG from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator from airflow.utils import timezone @@ -147,7 +147,8 @@ def test_execute(self): def test_render_template(self): # Given operator = SparkSubmitOperator(task_id='spark_submit_job', dag=self.dag, **self._config) - ti = TaskInstance(operator, DEFAULT_DATE) + ti = TaskInstance(operator, run_id="spark_test") + ti.dag_run = DagRun(run_id="spark_test", execution_date=DEFAULT_DATE) # When ti.render_templates() diff --git a/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py b/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py index b0f3aff36fa8d..47a20f8985882 100644 --- a/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py +++ b/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py @@ -22,7 +22,7 @@ from kubernetes.client import ApiClient, models as k8s from airflow.exceptions import AirflowException -from airflow.models import DAG, TaskInstance +from airflow.models import DAG, DagRun, TaskInstance from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator from airflow.utils import timezone from airflow.utils.state import State @@ -49,7 +49,8 @@ def setUp(self): @staticmethod def create_context(task): dag = DAG(dag_id="dag") - task_instance = TaskInstance(task=task, execution_date=DEFAULT_DATE) + task_instance = TaskInstance(task=task, run_id="kub_pod_test") + task_instance.dag_run = DagRun(run_id="kub_pod_test", execution_date=DEFAULT_DATE) return { "dag": dag, "ts": DEFAULT_DATE.isoformat(), @@ -651,7 +652,8 @@ def test_push_xcom_pod_info(self): do_xcom_push=False, ) pod = self.run_pod(k) - ti = TaskInstance(task=k, execution_date=DEFAULT_DATE) + ti = TaskInstance(task=k, run_id="test_push_xcom_pod_info") + ti.dag_run = DagRun(run_id="test_push_xcom_pod_info", execution_date=DEFAULT_DATE) pod_name = ti.xcom_pull(task_ids=k.task_id, key='pod_name') pod_namespace = ti.xcom_pull(task_ids=k.task_id, key='pod_namespace') assert pod_name and pod_name == pod.metadata.name diff --git a/tests/providers/elasticsearch/log/test_es_task_handler.py b/tests/providers/elasticsearch/log/test_es_task_handler.py index 22c228a15cfdb..004a8bcca4461 100644 --- a/tests/providers/elasticsearch/log/test_es_task_handler.py +++ b/tests/providers/elasticsearch/log/test_es_task_handler.py @@ -20,36 +20,48 @@ import logging import os import shutil -import unittest from unittest import mock from urllib.parse import quote import elasticsearch import freezegun import pendulum -from parameterized import parameterized +import pytest from airflow.configuration import conf -from airflow.models import DAG, TaskInstance -from airflow.operators.dummy import DummyOperator from airflow.providers.elasticsearch.log.es_task_handler import ElasticsearchTaskHandler from airflow.utils import timezone -from airflow.utils.state import State +from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.timezone import datetime +from tests.test_utils.db import clear_db_dags, clear_db_runs from .elasticmock import elasticmock -class TestElasticsearchTaskHandler(unittest.TestCase): - DAG_ID = 
'dag_for_testing_file_task_handler' - TASK_ID = 'task_for_testing_file_log_handler' +class TestElasticsearchTaskHandler: + DAG_ID = 'dag_for_testing_es_task_handler' + TASK_ID = 'task_for_testing_es_log_handler' EXECUTION_DATE = datetime(2016, 1, 1) LOG_ID = f'{DAG_ID}-{TASK_ID}-2016-01-01T00:00:00+00:00-1' JSON_LOG_ID = f'{DAG_ID}-{TASK_ID}-{ElasticsearchTaskHandler._clean_execution_date(EXECUTION_DATE)}-1' + @pytest.fixture() + def ti(self, create_task_instance): + ti = create_task_instance( + dag_id=self.DAG_ID, + task_id=self.TASK_ID, + execution_date=self.EXECUTION_DATE, + dagrun_state=DagRunState.RUNNING, + state=TaskInstanceState.RUNNING, + ) + ti.try_number = 1 + ti.raw = False + yield ti + clear_db_runs() + clear_db_dags() + @elasticmock - def setUp(self): - super().setUp() + def setup(self): self.local_log_location = 'local/log/location' self.filename_template = '{try_number}.log' self.log_id_template = '{dag_id}-{task_id}-{execution_date}-{try_number}' @@ -76,17 +88,9 @@ def setUp(self): self.doc_type = 'log' self.test_message = 'some random stuff' self.body = {'message': self.test_message, 'log_id': self.LOG_ID, 'offset': 1} - self.es.index(index=self.index_name, doc_type=self.doc_type, body=self.body, id=1) - self.dag = DAG(self.DAG_ID, start_date=self.EXECUTION_DATE) - task = DummyOperator(task_id=self.TASK_ID, dag=self.dag) - self.ti = TaskInstance(task=task, execution_date=self.EXECUTION_DATE) - self.ti.try_number = 1 - self.ti.state = State.RUNNING - self.addCleanup(self.dag.clear) - - def tearDown(self): + def teardown(self): shutil.rmtree(self.local_log_location.split(os.path.sep)[0], ignore_errors=True) def test_client(self): @@ -113,10 +117,10 @@ def test_client_with_config(self): es_kwargs=es_conf, ) - def test_read(self): + def test_read(self, ti): ts = pendulum.now() logs, metadatas = self.es_task_handler.read( - self.ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False} + ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False} ) assert 1 == len(logs) @@ -127,7 +131,7 @@ def test_read(self): assert '1' == metadatas[0]['offset'] assert timezone.parse(metadatas[0]['last_log_timestamp']) > ts - def test_read_with_match_phrase_query(self): + def test_read_with_match_phrase_query(self, ti): similar_log_id = '{task_id}-{dag_id}-2016-01-01T00:00:00+00:00-1'.format( dag_id=TestElasticsearchTaskHandler.DAG_ID, task_id=TestElasticsearchTaskHandler.TASK_ID ) @@ -138,7 +142,7 @@ def test_read_with_match_phrase_query(self): ts = pendulum.now() logs, metadatas = self.es_task_handler.read( - self.ti, 1, {'offset': '0', 'last_log_timestamp': str(ts), 'end_of_log': False, 'max_offset': 2} + ti, 1, {'offset': '0', 'last_log_timestamp': str(ts), 'end_of_log': False, 'max_offset': 2} ) assert 1 == len(logs) assert len(logs) == len(metadatas) @@ -149,8 +153,8 @@ def test_read_with_match_phrase_query(self): assert '1' == metadatas[0]['offset'] assert timezone.parse(metadatas[0]['last_log_timestamp']) > ts - def test_read_with_none_metadata(self): - logs, metadatas = self.es_task_handler.read(self.ti, 1) + def test_read_with_none_metadata(self, ti): + logs, metadatas = self.es_task_handler.read(ti, 1) assert 1 == len(logs) assert len(logs) == len(metadatas) assert self.test_message == logs[0][0][-1] @@ -158,14 +162,14 @@ def test_read_with_none_metadata(self): assert '1' == metadatas[0]['offset'] assert timezone.parse(metadatas[0]['last_log_timestamp']) < pendulum.now() - def test_read_nonexistent_log(self): + def test_read_nonexistent_log(self, ti): ts 
= pendulum.now() # In ElasticMock, search is going to return all documents with matching index # and doc_type regardless of match filters, so we delete the log entry instead # of making a new TaskInstance to query. self.es.delete(index=self.index_name, doc_type=self.doc_type, id=1) logs, metadatas = self.es_task_handler.read( - self.ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False} + ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False} ) assert 1 == len(logs) assert len(logs) == len(metadatas) @@ -175,9 +179,9 @@ def test_read_nonexistent_log(self): # last_log_timestamp won't change if no log lines read. assert timezone.parse(metadatas[0]['last_log_timestamp']) == ts - def test_read_with_empty_metadata(self): + def test_read_with_empty_metadata(self, ti): ts = pendulum.now() - logs, metadatas = self.es_task_handler.read(self.ti, 1, {}) + logs, metadatas = self.es_task_handler.read(ti, 1, {}) assert 1 == len(logs) assert len(logs) == len(metadatas) assert self.test_message == logs[0][0][-1] @@ -190,7 +194,7 @@ def test_read_with_empty_metadata(self): # case where offset is missing but metadata not empty. self.es.delete(index=self.index_name, doc_type=self.doc_type, id=1) - logs, metadatas = self.es_task_handler.read(self.ti, 1, {'end_of_log': False}) + logs, metadatas = self.es_task_handler.read(ti, 1, {'end_of_log': False}) assert 1 == len(logs) assert len(logs) == len(metadatas) assert [[]] == logs @@ -201,12 +205,12 @@ def test_read_with_empty_metadata(self): # if not last_log_timestamp is provided. assert timezone.parse(metadatas[0]['last_log_timestamp']) > ts - def test_read_timeout(self): + def test_read_timeout(self, ti): ts = pendulum.now().subtract(minutes=5) self.es.delete(index=self.index_name, doc_type=self.doc_type, id=1) logs, metadatas = self.es_task_handler.read( - self.ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False} + ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False} ) assert 1 == len(logs) assert len(logs) == len(metadatas) @@ -216,10 +220,10 @@ def test_read_timeout(self): assert '0' == metadatas[0]['offset'] assert timezone.parse(metadatas[0]['last_log_timestamp']) == ts - def test_read_as_download_logs(self): + def test_read_as_download_logs(self, ti): ts = pendulum.now() logs, metadatas = self.es_task_handler.read( - self.ti, + ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'download_logs': True, 'end_of_log': False}, ) @@ -232,11 +236,11 @@ def test_read_as_download_logs(self): assert '1' == metadatas[0]['offset'] assert timezone.parse(metadatas[0]['last_log_timestamp']) > ts - def test_read_raises(self): + def test_read_raises(self, ti): with mock.patch.object(self.es_task_handler.log, 'exception') as mock_exception: with mock.patch("elasticsearch_dsl.Search.execute") as mock_execute: mock_execute.side_effect = Exception('Failed to read') - logs, metadatas = self.es_task_handler.read(self.ti, 1) + logs, metadatas = self.es_task_handler.read(ti, 1) assert mock_exception.call_count == 1 args, kwargs = mock_exception.call_args assert "Could not read log with log_id:" in args[0] @@ -247,18 +251,18 @@ def test_read_raises(self): assert not metadatas[0]['end_of_log'] assert '0' == metadatas[0]['offset'] - def test_set_context(self): - self.es_task_handler.set_context(self.ti) + def test_set_context(self, ti): + self.es_task_handler.set_context(ti) assert self.es_task_handler.mark_end_on_close - def test_set_context_w_json_format_and_write_stdout(self): + def 
test_set_context_w_json_format_and_write_stdout(self, ti): formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') self.es_task_handler.formatter = formatter self.es_task_handler.write_stdout = True self.es_task_handler.json_format = True - self.es_task_handler.set_context(self.ti) + self.es_task_handler.set_context(ti) - def test_read_with_json_format(self): + def test_read_with_json_format(self, ti): ts = pendulum.now() formatter = logging.Formatter( '[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s - %(exc_text)s' @@ -275,15 +279,15 @@ def test_read_with_json_format(self): 'lineno': 851, 'levelname': 'INFO', } - self.es_task_handler.set_context(self.ti) + self.es_task_handler.set_context(ti) self.es.index(index=self.index_name, doc_type=self.doc_type, body=self.body, id=id) logs, _ = self.es_task_handler.read( - self.ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False} + ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False} ) assert "[2020-12-24 19:25:00,962] {taskinstance.py:851} INFO - some random stuff - " == logs[0][0][1] - def test_read_with_json_format_with_custom_offset_and_host_fields(self): + def test_read_with_json_format_with_custom_offset_and_host_fields(self, ti): ts = pendulum.now() formatter = logging.Formatter( '[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s - %(exc_text)s' @@ -303,15 +307,15 @@ def test_read_with_json_format_with_custom_offset_and_host_fields(self): 'lineno': 851, 'levelname': 'INFO', } - self.es_task_handler.set_context(self.ti) + self.es_task_handler.set_context(ti) self.es.index(index=self.index_name, doc_type=self.doc_type, body=self.body, id=id) logs, _ = self.es_task_handler.read( - self.ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False} + ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False} ) assert "[2020-12-24 19:25:00,962] {taskinstance.py:851} INFO - some random stuff - " == logs[0][0][1] - def test_read_with_custom_offset_and_host_fields(self): + def test_read_with_custom_offset_and_host_fields(self, ti): ts = pendulum.now() # Delete the existing log entry as it doesn't have the new offset and host fields self.es.delete(index=self.index_name, doc_type=self.doc_type, id=1) @@ -328,15 +332,15 @@ def test_read_with_custom_offset_and_host_fields(self): self.es.index(index=self.index_name, doc_type=self.doc_type, body=self.body, id=id) logs, _ = self.es_task_handler.read( - self.ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False} + ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False} ) assert self.test_message == logs[0][0][1] - def test_close(self): + def test_close(self, ti): formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') self.es_task_handler.formatter = formatter - self.es_task_handler.set_context(self.ti) + self.es_task_handler.set_context(ti) self.es_task_handler.close() with open( os.path.join(self.local_log_location, self.filename_template.format(try_number=1)) @@ -348,9 +352,9 @@ def test_close(self): assert self.end_of_log_mark.strip() == log_line assert self.es_task_handler.closed - def test_close_no_mark_end(self): - self.ti.raw = True - self.es_task_handler.set_context(self.ti) + def test_close_no_mark_end(self, ti): + ti.raw = True + self.es_task_handler.set_context(ti) self.es_task_handler.close() with open( os.path.join(self.local_log_location, self.filename_template.format(try_number=1)) @@ -358,17 
+362,17 @@ def test_close_no_mark_end(self): assert self.end_of_log_mark not in log_file.read() assert self.es_task_handler.closed - def test_close_closed(self): + def test_close_closed(self, ti): self.es_task_handler.closed = True - self.es_task_handler.set_context(self.ti) + self.es_task_handler.set_context(ti) self.es_task_handler.close() with open( os.path.join(self.local_log_location, self.filename_template.format(try_number=1)) ) as log_file: assert 0 == len(log_file.read()) - def test_close_with_no_handler(self): - self.es_task_handler.set_context(self.ti) + def test_close_with_no_handler(self, ti): + self.es_task_handler.set_context(ti) self.es_task_handler.handler = None self.es_task_handler.close() with open( @@ -377,8 +381,8 @@ def test_close_with_no_handler(self): assert 0 == len(log_file.read()) assert self.es_task_handler.closed - def test_close_with_no_stream(self): - self.es_task_handler.set_context(self.ti) + def test_close_with_no_stream(self, ti): + self.es_task_handler.set_context(ti) self.es_task_handler.handler.stream = None self.es_task_handler.close() with open( @@ -387,7 +391,7 @@ def test_close_with_no_stream(self): assert self.end_of_log_mark in log_file.read() assert self.es_task_handler.closed - self.es_task_handler.set_context(self.ti) + self.es_task_handler.set_context(ti) self.es_task_handler.handler.stream.close() self.es_task_handler.close() with open( @@ -396,17 +400,18 @@ def test_close_with_no_stream(self): assert self.end_of_log_mark in log_file.read() assert self.es_task_handler.closed - def test_render_log_id(self): - assert self.LOG_ID == self.es_task_handler._render_log_id(self.ti, 1) + def test_render_log_id(self, ti): + assert self.LOG_ID == self.es_task_handler._render_log_id(ti, 1) self.es_task_handler.json_format = True - assert self.JSON_LOG_ID == self.es_task_handler._render_log_id(self.ti, 1) + assert self.JSON_LOG_ID == self.es_task_handler._render_log_id(ti, 1) def test_clean_execution_date(self): clean_execution_date = self.es_task_handler._clean_execution_date(datetime(2016, 7, 8, 9, 10, 11, 12)) assert '2016_07_08T09_10_11_000012' == clean_execution_date - @parameterized.expand( + @pytest.mark.parametrize( + "json_format, es_frontend, expected_url", [ # Common cases (True, 'localhost:5601/{log_id}', 'https://localhost:5601/' + quote(JSON_LOG_ID)), @@ -417,9 +422,9 @@ def test_clean_execution_date(self): (False, 'https://localhost:5601/path/{log_id}', 'https://localhost:5601/path/' + quote(LOG_ID)), (False, 'http://localhost:5601/path/{log_id}', 'http://localhost:5601/path/' + quote(LOG_ID)), (False, 'other://localhost:5601/path/{log_id}', 'other://localhost:5601/path/' + quote(LOG_ID)), - ] + ], ) - def test_get_external_log_url(self, json_format, es_frontend, expected_url): + def test_get_external_log_url(self, ti, json_format, es_frontend, expected_url): es_task_handler = ElasticsearchTaskHandler( self.local_log_location, self.filename_template, @@ -432,21 +437,22 @@ def test_get_external_log_url(self, json_format, es_frontend, expected_url): self.offset_field, frontend=es_frontend, ) - url = es_task_handler.get_external_log_url(self.ti, self.ti.try_number) + url = es_task_handler.get_external_log_url(ti, ti.try_number) assert expected_url == url - @parameterized.expand( + @pytest.mark.parametrize( + "frontend, expected", [ ('localhost:5601/{log_id}', True), (None, False), - ] + ], ) def test_supports_external_link(self, frontend, expected): self.es_task_handler.frontend = frontend assert self.es_task_handler.supports_external_link 
== expected @mock.patch('sys.__stdout__', new_callable=io.StringIO) - def test_dynamic_offset(self, stdout_mock): + def test_dynamic_offset(self, stdout_mock, ti): # arrange handler = ElasticsearchTaskHandler( base_log_folder=self.local_log_location, @@ -465,19 +471,19 @@ def test_dynamic_offset(self, stdout_mock): logger.handlers = [handler] logger.propagate = False - self.ti._log = logger - handler.set_context(self.ti) + ti._log = logger + handler.set_context(ti) t1 = pendulum.naive(year=2017, month=1, day=1, hour=1, minute=1, second=15) t2, t3 = t1 + pendulum.duration(seconds=5), t1 + pendulum.duration(seconds=10) # act with freezegun.freeze_time(t1): - self.ti.log.info("Test") + ti.log.info("Test") with freezegun.freeze_time(t2): - self.ti.log.info("Test2") + ti.log.info("Test2") with freezegun.freeze_time(t3): - self.ti.log.info("Test3") + ti.log.info("Test3") # assert first_log, second_log, third_log = map(json.loads, stdout_mock.getvalue().strip().split("\n")) diff --git a/tests/providers/google/cloud/log/test_gcs_task_handler.py b/tests/providers/google/cloud/log/test_gcs_task_handler.py index dcf372066becc..6517be8f31245 100644 --- a/tests/providers/google/cloud/log/test_gcs_task_handler.py +++ b/tests/providers/google/cloud/log/test_gcs_task_handler.py @@ -15,45 +15,49 @@ # specific language governing permissions and limitations # under the License. import logging -import shutil import tempfile -import unittest -from datetime import datetime from unittest import mock -from airflow.models import TaskInstance -from airflow.models.dag import DAG -from airflow.operators.dummy import DummyOperator +import pytest + from airflow.providers.google.cloud.log.gcs_task_handler import GCSTaskHandler -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState +from airflow.utils.timezone import datetime from tests.test_utils.config import conf_vars -from tests.test_utils.db import clear_db_runs - - -class TestGCSTaskHandler(unittest.TestCase): - def setUp(self) -> None: - date = datetime(2020, 1, 1) - self.gcs_log_folder = "test" - self.logger = logging.getLogger("logger") - self.dag = DAG("dag_for_testing_task_handler", start_date=date) - task = DummyOperator(task_id="task_for_testing_gcs_task_handler") - self.ti = TaskInstance(task=task, execution_date=date) - self.ti.try_number = 1 - self.ti.state = State.RUNNING +from tests.test_utils.db import clear_db_dags, clear_db_runs + + +class TestGCSTaskHandler: + @pytest.fixture(autouse=True) + def task_instance(self, create_task_instance): + self.ti = ti = create_task_instance( + dag_id="dag_for_testing_gcs_task_handler", + task_id="task_for_testing_gcs_task_handler", + execution_date=datetime(2020, 1, 1), + state=TaskInstanceState.RUNNING, + ) + ti.try_number = 1 + ti.raw = False + yield + clear_db_runs() + clear_db_dags() + + @pytest.fixture(autouse=True) + def local_log_location(self): + with tempfile.TemporaryDirectory() as td: + self.local_log_location = td + yield td + + @pytest.fixture(autouse=True) + def gcs_task_handler(self, local_log_location): self.remote_log_base = "gs://bucket/remote/log/location" - self.remote_log_location = "gs://my-bucket/path/to/1.log" - self.local_log_location = tempfile.mkdtemp() self.filename_template = "{try_number}.log" - self.addCleanup(self.dag.clear) self.gcs_task_handler = GCSTaskHandler( - base_log_folder=self.local_log_location, + base_log_folder=local_log_location, gcs_log_folder=self.remote_log_base, filename_template=self.filename_template, ) - - def tearDown(self) -> 
None: - clear_db_runs() - shutil.rmtree(self.local_log_location, ignore_errors=True) + yield self.gcs_task_handler @mock.patch( "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id", @@ -147,7 +151,8 @@ def test_write_to_remote_on_close(self, mock_blob, mock_client, mock_creds): ) @mock.patch("google.cloud.storage.Client") @mock.patch("google.cloud.storage.Blob") - def test_failed_write_to_remote_on_close(self, mock_blob, mock_client, mock_creds): + def test_failed_write_to_remote_on_close(self, mock_blob, mock_client, mock_creds, caplog): + caplog.at_level(logging.ERROR, logger=self.gcs_task_handler.log.name) mock_blob.from_string.return_value.upload_from_string.side_effect = Exception("Failed to connect") mock_blob.from_string.return_value.download_as_bytes.return_value = b"Old log" @@ -163,12 +168,14 @@ def test_failed_write_to_remote_on_close(self, mock_blob, mock_client, mock_cred exc_info=None, ) ) - with self.assertLogs(self.gcs_task_handler.log) as cm: - self.gcs_task_handler.close() + self.gcs_task_handler.close() - assert cm.output == [ - 'ERROR:airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler:Could ' - 'not write logs to gs://bucket/remote/log/location/1.log: Failed to connect', + assert caplog.record_tuples == [ + ( + "airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler", + logging.ERROR, + "Could not write logs to gs://bucket/remote/log/location/1.log: Failed to connect", + ), ] mock_blob.assert_has_calls( [ diff --git a/tests/providers/google/cloud/log/test_stackdriver_task_handler.py b/tests/providers/google/cloud/log/test_stackdriver_task_handler.py index 4cd61805ebbbc..eec743ed2df5d 100644 --- a/tests/providers/google/cloud/log/test_stackdriver_task_handler.py +++ b/tests/providers/google/cloud/log/test_stackdriver_task_handler.py @@ -16,19 +16,17 @@ # under the License. 
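# Not part of the patch: a rough sketch of the fixture shape the log-handler tests above
# (GCS, Elasticsearch) converge on, assuming the `create_task_instance` fixture from
# Airflow's tests/conftest.py. It yields a TaskInstance backed by real DagRun/TaskInstance
# rows and cleans those rows up afterwards; the dag_id/task_id values are placeholders.
import pytest

from airflow.utils.state import TaskInstanceState
from airflow.utils.timezone import datetime
from tests.test_utils.db import clear_db_dags, clear_db_runs


@pytest.fixture()
def ti(create_task_instance):
    ti = create_task_instance(
        dag_id="dag_for_testing_some_task_handler",
        task_id="task_for_testing_some_task_handler",
        execution_date=datetime(2020, 1, 1),
        state=TaskInstanceState.RUNNING,
    )
    ti.try_number = 1
    yield ti
    clear_db_runs()
    clear_db_dags()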
import logging -import unittest -from datetime import datetime from unittest import mock from urllib.parse import parse_qs, urlparse +import pytest from google.cloud.logging import Resource from google.cloud.logging_v2.types import ListLogEntriesRequest, ListLogEntriesResponse, LogEntry -from airflow.models import TaskInstance -from airflow.models.dag import DAG -from airflow.operators.dummy import DummyOperator from airflow.providers.google.cloud.log.stackdriver_task_handler import StackdriverTaskHandler -from airflow.utils.state import State +from airflow.utils import timezone +from airflow.utils.state import TaskInstanceState +from tests.test_utils.db import clear_db_dags, clear_db_runs def _create_list_log_entries_response_mock(messages, token): @@ -37,7 +35,9 @@ def _create_list_log_entries_response_mock(messages, token): ) -def _remove_stackdriver_handlers(): +@pytest.fixture() +def clean_stackdriver_handlers(): + yield for handler_ref in reversed(logging._handlerList[:]): handler = handler_ref() if not isinstance(handler, StackdriverTaskHandler): @@ -46,58 +46,67 @@ def _remove_stackdriver_handlers(): del handler -class TestStackdriverLoggingHandlerStandalone(unittest.TestCase): - @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id') - @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client') - def test_should_pass_message_to_client(self, mock_client, mock_get_creds_and_project_id): - self.addCleanup(_remove_stackdriver_handlers) +@pytest.mark.usefixtures("clean_stackdriver_handlers") +@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id') +@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client') +def test_should_pass_message_to_client(mock_client, mock_get_creds_and_project_id): + mock_get_creds_and_project_id.return_value = ('creds', 'project_id') - mock_get_creds_and_project_id.return_value = ('creds', 'project_id') + transport_type = mock.MagicMock() + stackdriver_task_handler = StackdriverTaskHandler(transport=transport_type, labels={"key": 'value'}) + logger = logging.getLogger("logger") + logger.addHandler(stackdriver_task_handler) - transport_type = mock.MagicMock() - stackdriver_task_handler = StackdriverTaskHandler(transport=transport_type, labels={"key": 'value'}) - logger = logging.getLogger("logger") - logger.addHandler(stackdriver_task_handler) + logger.info("test-message") + stackdriver_task_handler.flush() - logger.info("test-message") - stackdriver_task_handler.flush() + transport_type.assert_called_once_with(mock_client.return_value, 'airflow') + transport_type.return_value.send.assert_called_once_with( + mock.ANY, 'test-message', labels={"key": 'value'}, resource=Resource(type='global', labels={}) + ) + mock_client.assert_called_once_with(credentials='creds', client_info=mock.ANY, project="project_id") - transport_type.assert_called_once_with(mock_client.return_value, 'airflow') - transport_type.return_value.send.assert_called_once_with( - mock.ANY, 'test-message', labels={"key": 'value'}, resource=Resource(type='global', labels={}) - ) - mock_client.assert_called_once_with(credentials='creds', client_info=mock.ANY, project="project_id") +class TestStackdriverLoggingHandlerTask: + DAG_ID = "dag_for_testing_stackdriver_file_task_handler" + TASK_ID = "task_for_testing_stackdriver_task_handler" + + @pytest.fixture(autouse=True) + def task_instance(self, create_task_instance, 
clean_stackdriver_handlers): + self.ti = create_task_instance( + dag_id=self.DAG_ID, + task_id=self.TASK_ID, + execution_date=timezone.datetime(2016, 1, 1), + state=TaskInstanceState.RUNNING, + ) + self.ti.try_number = 1 + self.ti.raw = False + yield + clear_db_runs() + clear_db_dags() -class TestStackdriverLoggingHandlerTask(unittest.TestCase): - def setUp(self) -> None: + def _setup_handler(self, **handler_kwargs): self.transport_mock = mock.MagicMock() - self.stackdriver_task_handler = StackdriverTaskHandler(transport=self.transport_mock) + handler_kwargs = {"transport": self.transport_mock, **handler_kwargs} + stackdriver_task_handler = StackdriverTaskHandler(**handler_kwargs) self.logger = logging.getLogger("logger") - - date = datetime(2016, 1, 1) - self.dag = DAG('dag_for_testing_file_task_handler', start_date=date) - task = DummyOperator(task_id='task_for_testing_file_log_handler', dag=self.dag) - self.ti = TaskInstance(task=task, execution_date=date) - self.ti.try_number = 1 - self.ti.state = State.RUNNING - self.addCleanup(self.dag.clear) - self.addCleanup(_remove_stackdriver_handlers) + self.logger.addHandler(stackdriver_task_handler) + return stackdriver_task_handler @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id') @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client') def test_should_set_labels(self, mock_client, mock_get_creds_and_project_id): mock_get_creds_and_project_id.return_value = ('creds', 'project_id') - self.stackdriver_task_handler.set_context(self.ti) - self.logger.addHandler(self.stackdriver_task_handler) + stackdriver_task_handler = self._setup_handler() + stackdriver_task_handler.set_context(self.ti) self.logger.info("test-message") - self.stackdriver_task_handler.flush() + stackdriver_task_handler.flush() labels = { - 'task_id': 'task_for_testing_file_log_handler', - 'dag_id': 'dag_for_testing_file_task_handler', + 'task_id': self.TASK_ID, + 'dag_id': self.DAG_ID, 'execution_date': '2016-01-01T00:00:00+00:00', 'try_number': '1', } @@ -110,18 +119,18 @@ def test_should_set_labels(self, mock_client, mock_get_creds_and_project_id): @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client') def test_should_append_labels(self, mock_client, mock_get_creds_and_project_id): mock_get_creds_and_project_id.return_value = ('creds', 'project_id') - self.stackdriver_task_handler = StackdriverTaskHandler( - transport=self.transport_mock, labels={"product.googleapis.com/task_id": "test-value"} + + stackdriver_task_handler = self._setup_handler( + labels={"product.googleapis.com/task_id": "test-value"}, ) - self.stackdriver_task_handler.set_context(self.ti) - self.logger.addHandler(self.stackdriver_task_handler) + stackdriver_task_handler.set_context(self.ti) self.logger.info("test-message") - self.stackdriver_task_handler.flush() + stackdriver_task_handler.flush() labels = { - 'task_id': 'task_for_testing_file_log_handler', - 'dag_id': 'dag_for_testing_file_task_handler', + 'task_id': self.TASK_ID, + 'dag_id': self.DAG_ID, 'execution_date': '2016-01-01T00:00:00+00:00', 'try_number': '1', 'product.googleapis.com/task_id': 'test-value', @@ -139,15 +148,16 @@ def test_should_read_logs_for_all_try(self, mock_client, mock_get_creds_and_proj ) mock_get_creds_and_project_id.return_value = ('creds', 'project_id') - logs, metadata = self.stackdriver_task_handler.read(self.ti) + stackdriver_task_handler = self._setup_handler() + logs, metadata = 
stackdriver_task_handler.read(self.ti) mock_client.return_value.list_log_entries.assert_called_once_with( request=ListLogEntriesRequest( resource_names=["projects/project_id"], filter=( 'resource.type="global"\n' 'logName="projects/project_id/logs/airflow"\n' - 'labels.task_id="task_for_testing_file_log_handler"\n' - 'labels.dag_id="dag_for_testing_file_task_handler"\n' + 'labels.task_id="task_for_testing_stackdriver_task_handler"\n' + 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler"\n' 'labels.execution_date="2016-01-01T00:00:00+00:00"' ), order_by='timestamp asc', @@ -165,8 +175,11 @@ def test_should_read_logs_for_task_with_quote(self, mock_client, mock_get_creds_ [_create_list_log_entries_response_mock(["MSG1", "MSG2"], None)] ) mock_get_creds_and_project_id.return_value = ('creds', 'project_id') + self.ti.task_id = "K\"OT" - logs, metadata = self.stackdriver_task_handler.read(self.ti) + stackdriver_task_handler = self._setup_handler() + + logs, metadata = stackdriver_task_handler.read(self.ti) mock_client.return_value.list_log_entries.assert_called_once_with( request=ListLogEntriesRequest( resource_names=["projects/project_id"], @@ -174,7 +187,7 @@ def test_should_read_logs_for_task_with_quote(self, mock_client, mock_get_creds_ 'resource.type="global"\n' 'logName="projects/project_id/logs/airflow"\n' 'labels.task_id="K\\"OT"\n' - 'labels.dag_id="dag_for_testing_file_task_handler"\n' + 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler"\n' 'labels.execution_date="2016-01-01T00:00:00+00:00"' ), order_by='timestamp asc', @@ -192,16 +205,17 @@ def test_should_read_logs_for_single_try(self, mock_client, mock_get_creds_and_p [_create_list_log_entries_response_mock(["MSG1", "MSG2"], None)] ) mock_get_creds_and_project_id.return_value = ('creds', 'project_id') + stackdriver_task_handler = self._setup_handler() - logs, metadata = self.stackdriver_task_handler.read(self.ti, 3) + logs, metadata = stackdriver_task_handler.read(self.ti, 3) mock_client.return_value.list_log_entries.assert_called_once_with( request=ListLogEntriesRequest( resource_names=["projects/project_id"], filter=( 'resource.type="global"\n' 'logName="projects/project_id/logs/airflow"\n' - 'labels.task_id="task_for_testing_file_log_handler"\n' - 'labels.dag_id="dag_for_testing_file_task_handler"\n' + 'labels.task_id="task_for_testing_stackdriver_task_handler"\n' + 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler"\n' 'labels.execution_date="2016-01-01T00:00:00+00:00"\n' 'labels.try_number="3"' ), @@ -221,15 +235,17 @@ def test_should_read_logs_with_pagination(self, mock_client, mock_get_creds_and_ mock.MagicMock(pages=iter([_create_list_log_entries_response_mock(["MSG3", "MSG4"], None)])), ] mock_get_creds_and_project_id.return_value = ('creds', 'project_id') - logs, metadata1 = self.stackdriver_task_handler.read(self.ti, 3) + stackdriver_task_handler = self._setup_handler() + + logs, metadata1 = stackdriver_task_handler.read(self.ti, 3) mock_client.return_value.list_log_entries.assert_called_once_with( request=ListLogEntriesRequest( resource_names=["projects/project_id"], filter=( '''resource.type="global" logName="projects/project_id/logs/airflow" -labels.task_id="task_for_testing_file_log_handler" -labels.dag_id="dag_for_testing_file_task_handler" +labels.task_id="task_for_testing_stackdriver_task_handler" +labels.dag_id="dag_for_testing_stackdriver_file_task_handler" labels.execution_date="2016-01-01T00:00:00+00:00" labels.try_number="3"''' ), @@ -242,7 +258,7 @@ def 
test_should_read_logs_with_pagination(self, mock_client, mock_get_creds_and_ assert [{'end_of_log': False, 'next_page_token': 'TOKEN1'}] == metadata1 mock_client.return_value.list_log_entries.return_value.next_page_token = None - logs, metadata2 = self.stackdriver_task_handler.read(self.ti, 3, metadata1[0]) + logs, metadata2 = stackdriver_task_handler.read(self.ti, 3, metadata1[0]) mock_client.return_value.list_log_entries.assert_called_with( request=ListLogEntriesRequest( @@ -250,8 +266,8 @@ def test_should_read_logs_with_pagination(self, mock_client, mock_get_creds_and_ filter=( 'resource.type="global"\n' 'logName="projects/project_id/logs/airflow"\n' - 'labels.task_id="task_for_testing_file_log_handler"\n' - 'labels.dag_id="dag_for_testing_file_task_handler"\n' + 'labels.task_id="task_for_testing_stackdriver_task_handler"\n' + 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler"\n' 'labels.execution_date="2016-01-01T00:00:00+00:00"\n' 'labels.try_number="3"' ), @@ -272,7 +288,8 @@ def test_should_read_logs_with_download(self, mock_client, mock_get_creds_and_pr ] mock_get_creds_and_project_id.return_value = ('creds', 'project_id') - logs, metadata1 = self.stackdriver_task_handler.read(self.ti, 3, {'download_logs': True}) + stackdriver_task_handler = self._setup_handler() + logs, metadata1 = stackdriver_task_handler.read(self.ti, 3, {'download_logs': True}) assert [(('default-hostname', 'MSG1\nMSG2\nMSG3\nMSG4'),)] == logs assert [{'end_of_log': True}] == metadata1 @@ -289,15 +306,13 @@ def test_should_read_logs_with_custom_resources(self, mock_client, mock_get_cred "project_id": "project_id", }, ) - self.stackdriver_task_handler = StackdriverTaskHandler( - transport=self.transport_mock, resource=resource - ) + stackdriver_task_handler = self._setup_handler(resource=resource) entry = mock.MagicMock(json_payload={"message": "TEXT"}) page = mock.MagicMock(entries=[entry, entry], next_page_token=None) mock_client.return_value.list_log_entries.return_value.pages = (n for n in [page]) - logs, metadata = self.stackdriver_task_handler.read(self.ti) + logs, metadata = stackdriver_task_handler.read(self.ti) mock_client.return_value.list_log_entries.assert_called_once_with( request=ListLogEntriesRequest( resource_names=["projects/project_id"], @@ -307,8 +322,8 @@ def test_should_read_logs_with_custom_resources(self, mock_client, mock_get_cred 'resource.labels."environment.name"="test-instance"\n' 'resource.labels.location="europe-west-3"\n' 'resource.labels.project_id="project_id"\n' - 'labels.task_id="task_for_testing_file_log_handler"\n' - 'labels.dag_id="dag_for_testing_file_task_handler"\n' + 'labels.task_id="task_for_testing_stackdriver_task_handler"\n' + 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler"\n' 'labels.execution_date="2016-01-01T00:00:00+00:00"' ), order_by='timestamp asc', @@ -324,10 +339,7 @@ def test_should_read_logs_with_custom_resources(self, mock_client, mock_get_cred def test_should_use_credentials(self, mock_client, mock_get_creds_and_project_id): mock_get_creds_and_project_id.return_value = ('creds', 'project_id') - stackdriver_task_handler = StackdriverTaskHandler( - gcp_key_path="KEY_PATH", - ) - + stackdriver_task_handler = StackdriverTaskHandler(gcp_key_path="KEY_PATH") client = stackdriver_task_handler._client mock_get_creds_and_project_id.assert_called_once_with( @@ -348,10 +360,7 @@ def test_should_use_credentials(self, mock_client, mock_get_creds_and_project_id def test_should_return_valid_external_url(self, mock_client, 
mock_get_creds_and_project_id): mock_get_creds_and_project_id.return_value = ('creds', 'project_id') - stackdriver_task_handler = StackdriverTaskHandler( - gcp_key_path="KEY_PATH", - ) - + stackdriver_task_handler = StackdriverTaskHandler(gcp_key_path="KEY_PATH") url = stackdriver_task_handler.get_external_log_url(self.ti, self.ti.try_number) parsed_url = urlparse(url) @@ -367,7 +376,7 @@ def test_should_return_valid_external_url(self, mock_client, mock_get_creds_and_ 'resource.type="global"', 'logName="projects/project_id/logs/airflow"', f'labels.task_id="{self.ti.task_id}"', - f'labels.dag_id="{self.dag.dag_id}"', + f'labels.dag_id="{self.DAG_ID}"', f'labels.execution_date="{self.ti.execution_date.isoformat()}"', f'labels.try_number="{self.ti.try_number}"', ] diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index 1576b24055478..ebcdedee7ebba 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -17,17 +17,14 @@ # under the License. import unittest -from datetime import datetime from unittest import mock from unittest.mock import MagicMock import pytest from google.cloud.exceptions import Conflict -from parameterized import parameterized -from airflow import models from airflow.exceptions import AirflowException -from airflow.models import DAG, TaskFail, TaskInstance, XCom +from airflow.models import DAG from airflow.providers.google.cloud.operators.bigquery import ( BigQueryCheckOperator, BigQueryConsoleIndexableLink, @@ -51,8 +48,8 @@ BigQueryValueCheckOperator, ) from airflow.serialization.serialized_objects import SerializedDAG -from airflow.settings import Session -from airflow.utils.session import provide_session +from airflow.utils.timezone import datetime +from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags, clear_db_xcom TASK_ID = 'test-bq-generic-operator' TEST_DATASET = 'test-dataset' @@ -358,18 +355,12 @@ def test_execute(self, mock_hook): ) -class TestBigQueryOperator(unittest.TestCase): - def setUp(self): - self.dagbag = models.DagBag(dag_folder='/dev/null', include_examples=True) - self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} - self.dag = DAG(TEST_DAG_ID, default_args=self.args) - - def tearDown(self): - session = Session() - session.query(models.TaskInstance).filter_by(dag_id=TEST_DAG_ID).delete() - session.query(TaskFail).filter_by(dag_id=TEST_DAG_ID).delete() - session.commit() - session.close() +class TestBigQueryOperator: + def teardown_method(self): + clear_db_xcom() + clear_db_runs() + clear_db_serialized_dags() + clear_db_dags() @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') def test_execute(self, mock_hook): @@ -519,14 +510,15 @@ def test_execute_bad_type(self, mock_hook): operator.execute(MagicMock()) @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') - def test_bigquery_operator_defaults(self, mock_hook): - operator = BigQueryExecuteQueryOperator( + def test_bigquery_operator_defaults(self, mock_hook, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + BigQueryExecuteQueryOperator, + dag_id=TEST_DAG_ID, task_id=TASK_ID, sql='Select * from test_table', - dag=self.dag, - default_args=self.args, schema_update_options=None, ) + operator = ti.task operator.execute(MagicMock()) mock_hook.return_value.run_query.assert_called_once_with( @@ -549,17 +541,23 @@ def 
test_bigquery_operator_defaults(self, mock_hook): encryption_configuration=None, ) assert isinstance(operator.sql, str) - ti = TaskInstance(task=operator, execution_date=DEFAULT_DATE) ti.render_templates() assert isinstance(ti.task.sql, str) - def test_bigquery_operator_extra_serialized_field_when_single_query(self): - with self.dag: - BigQueryExecuteQueryOperator( - task_id=TASK_ID, - sql='SELECT * FROM test_table', - ) - serialized_dag = SerializedDAG.to_dict(self.dag) + @pytest.mark.need_serialized_dag + def test_bigquery_operator_extra_serialized_field_when_single_query( + self, + dag_maker, + create_task_instance_of_operator, + ): + ti = create_task_instance_of_operator( + BigQueryExecuteQueryOperator, + dag_id=TEST_DAG_ID, + execution_date=DEFAULT_DATE, + task_id=TASK_ID, + sql='SELECT * FROM test_table', + ) + serialized_dag = dag_maker.get_serialized_data() assert "sql" in serialized_dag["dag"]["tasks"][0] dag = SerializedDAG.from_dict(serialized_dag) @@ -578,24 +576,25 @@ def test_bigquery_operator_extra_serialized_field_when_single_query(self): # Check DeSerialized version of operator link assert isinstance(list(simple_task.operator_extra_links)[0], BigQueryConsoleLink) - ti = TaskInstance(task=simple_task, execution_date=DEFAULT_DATE) ti.xcom_push('job_id', 12345) - # check for positive case url = simple_task.get_extra_links(DEFAULT_DATE, BigQueryConsoleLink.name) assert url == 'https://console.cloud.google.com/bigquery?j=12345' - # check for negative case - url2 = simple_task.get_extra_links(datetime(2017, 1, 2), BigQueryConsoleLink.name) - assert url2 == '' - - def test_bigquery_operator_extra_serialized_field_when_multiple_queries(self): - with self.dag: - BigQueryExecuteQueryOperator( - task_id=TASK_ID, - sql=['SELECT * FROM test_table', 'SELECT * FROM test_table2'], - ) - serialized_dag = SerializedDAG.to_dict(self.dag) + @pytest.mark.need_serialized_dag + def test_bigquery_operator_extra_serialized_field_when_multiple_queries( + self, + dag_maker, + create_task_instance_of_operator, + ): + ti = create_task_instance_of_operator( + BigQueryExecuteQueryOperator, + dag_id=TEST_DAG_ID, + execution_date=DEFAULT_DATE, + task_id=TASK_ID, + sql=['SELECT * FROM test_table', 'SELECT * FROM test_table2'], + ) + serialized_dag = dag_maker.get_serialized_data() assert "sql" in serialized_dag["dag"]["tasks"][0] dag = SerializedDAG.from_dict(serialized_dag) @@ -615,7 +614,6 @@ def test_bigquery_operator_extra_serialized_field_when_multiple_queries(self): # Check DeSerialized version of operator link assert isinstance(list(simple_task.operator_extra_links)[0], BigQueryConsoleIndexableLink) - ti = TaskInstance(task=simple_task, execution_date=DEFAULT_DATE) job_id = ['123', '45'] ti.xcom_push(key='job_id', value=job_id) @@ -629,34 +627,31 @@ def test_bigquery_operator_extra_serialized_field_when_multiple_queries(self): DEFAULT_DATE, 'BigQuery Console #2' ) - @provide_session @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') - def test_bigquery_operator_extra_link_when_missing_job_id(self, mock_hook, session): - bigquery_task = BigQueryExecuteQueryOperator( + def test_bigquery_operator_extra_link_when_missing_job_id( + self, mock_hook, create_task_instance_of_operator + ): + bigquery_task = create_task_instance_of_operator( + BigQueryExecuteQueryOperator, + dag_id=TEST_DAG_ID, task_id=TASK_ID, sql='SELECT * FROM test_table', - dag=self.dag, - ) - self.dag.clear() - session.query(XCom).delete() + ).task assert '' == bigquery_task.get_extra_links(DEFAULT_DATE, 
BigQueryConsoleLink.name) - @provide_session @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') - def test_bigquery_operator_extra_link_when_single_query(self, mock_hook, session): - bigquery_task = BigQueryExecuteQueryOperator( + def test_bigquery_operator_extra_link_when_single_query( + self, mock_hook, create_task_instance_of_operator + ): + ti = create_task_instance_of_operator( + BigQueryExecuteQueryOperator, + dag_id=TEST_DAG_ID, + execution_date=DEFAULT_DATE, task_id=TASK_ID, sql='SELECT * FROM test_table', - dag=self.dag, - ) - self.dag.clear() - session.query(XCom).delete() - - ti = TaskInstance( - task=bigquery_task, - execution_date=DEFAULT_DATE, ) + bigquery_task = ti.task job_id = '12345' ti.xcom_push(key='job_id', value=job_id) @@ -665,23 +660,18 @@ def test_bigquery_operator_extra_link_when_single_query(self, mock_hook, session DEFAULT_DATE, BigQueryConsoleLink.name ) - assert '' == bigquery_task.get_extra_links(datetime(2019, 1, 1), BigQueryConsoleLink.name) - - @provide_session @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') - def test_bigquery_operator_extra_link_when_multiple_query(self, mock_hook, session): - bigquery_task = BigQueryExecuteQueryOperator( + def test_bigquery_operator_extra_link_when_multiple_query( + self, mock_hook, create_task_instance_of_operator + ): + ti = create_task_instance_of_operator( + BigQueryExecuteQueryOperator, + dag_id=TEST_DAG_ID, + execution_date=DEFAULT_DATE, task_id=TASK_ID, sql=['SELECT * FROM test_table', 'SELECT * FROM test_table2'], - dag=self.dag, - ) - self.dag.clear() - session.query(XCom).delete() - - ti = TaskInstance( - task=bigquery_task, - execution_date=DEFAULT_DATE, ) + bigquery_task = ti.task job_id = ['123', '45'] ti.xcom_push(key='job_id', value=job_id) @@ -775,8 +765,9 @@ def test_get_db_hook( mock_get_db_hook.assert_called_once() -class TestBigQueryConnIdDeprecationWarning(unittest.TestCase): - @parameterized.expand( +class TestBigQueryConnIdDeprecationWarning: + @pytest.mark.parametrize( + "operator_class, kwargs", [ (BigQueryCheckOperator, dict(sql='Select * from test_table')), (BigQueryValueCheckOperator, dict(sql='Select * from test_table', pass_value=95)), @@ -786,7 +777,7 @@ class TestBigQueryConnIdDeprecationWarning(unittest.TestCase): (BigQueryDeleteDatasetOperator, dict(dataset_id=TEST_DATASET)), (BigQueryCreateEmptyDatasetOperator, dict(dataset_id=TEST_DATASET)), (BigQueryDeleteTableOperator, dict(deletion_dataset_table=TEST_DATASET)), - ] + ], ) def test_bigquery_conn_id_deprecation_warning(self, operator_class, kwargs): bigquery_conn_id = 'google_cloud_default' @@ -1039,7 +1030,8 @@ def test_execute_no_force_rerun(self, mock_hook, mock_md5): @mock.patch('airflow.providers.google.cloud.operators.bigquery.hashlib.md5') @pytest.mark.parametrize( "test_dag_id, expected_job_id", - [("test-dag-id-1.1", "airflow_test_dag_id_1_1_test_job_id_2020_01_23T00_00_00_hash")], + [("test-dag-id-1.1", "airflow_test_dag_id_1_1_test_job_id_2020_01_23T00_00_00_00_00_hash")], + ids=["test-dag-id-1.1"], ) def test_job_id_validity(self, mock_md5, test_dag_id, expected_job_id): hash_ = "hash" diff --git a/tests/providers/google/cloud/operators/test_cloud_build.py b/tests/providers/google/cloud/operators/test_cloud_build.py index 3d0e0f009bd9c..4496b9cc7b683 100644 --- a/tests/providers/google/cloud/operators/test_cloud_build.py +++ b/tests/providers/google/cloud/operators/test_cloud_build.py @@ -17,17 +17,15 @@ # under the License. 
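(Editor's illustrative sketch, not part of the patch: most conversions in this file and the ones that follow lean on the suite's `create_task_instance_of_operator` pytest fixture, which builds the operator inside a DAG and returns a TaskInstance backed by a real DagRun row. The fixture itself lives in the shared test conftest and is not shown in this diff; the snippet below is only a rough usage sketch under that assumption, with `BashOperator` as an arbitrary stand-in operator.)

from airflow.operators.bash import BashOperator

def test_bash_command_is_templated(create_task_instance_of_operator):
    # The fixture takes the operator class, a dag_id/task_id, and the operator's
    # own kwargs, and returns a TaskInstance whose .task is the built operator.
    ti = create_task_instance_of_operator(
        BashOperator,
        dag_id="example_templating_dag",
        task_id="example_task",
        bash_command="echo {{ dag.dag_id }}",
    )
    ti.render_templates()
    # Because the TaskInstance is attached to a real DagRun, template rendering
    # (and the xcom_push calls in the tests above) behave as they would at runtime.
    assert ti.task.bash_command == "echo example_templating_dag"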
"""Tests for Google Cloud Build operators """ -import tempfile from copy import deepcopy -from datetime import datetime from unittest import TestCase, mock import pytest from parameterized import parameterized from airflow.exceptions import AirflowException -from airflow.models import DAG, TaskInstance from airflow.providers.google.cloud.operators.cloud_build import BuildProcessor, CloudBuildCreateBuildOperator +from airflow.utils.timezone import datetime TEST_CREATE_BODY = { "source": {"storageSource": {"bucket": "cloud-build-examples", "object": "node-docker-example.tar.gz"}}, @@ -209,23 +207,27 @@ def test_repo_source_replace(self, hook_mock): ) assert return_value == TEST_CREATE_BODY - def test_load_templated_yaml(self): - dag = DAG(dag_id='example_cloudbuild_operator', start_date=TEST_DEFAULT_DATE) - with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w+t') as build: - build.writelines( - """ - steps: - - name: 'ubuntu' - args: ['echo', 'Hello {{ params.name }}!'] - """ - ) - build.seek(0) - body_path = build.name - operator = CloudBuildCreateBuildOperator( - body=body_path, task_id="task-id", dag=dag, params={'name': 'airflow'} - ) - operator.prepare_template() - ti = TaskInstance(operator, TEST_DEFAULT_DATE) - ti.render_templates() - expected_body = {'steps': [{'name': 'ubuntu', 'args': ['echo', 'Hello airflow!']}]} - assert expected_body == operator.body + +def test_load_templated_yaml(tmp_path, session, dag_maker): + body_path = tmp_path.joinpath("test_load_templated.yaml") + body_path.write_text( + """ + steps: + - name: 'ubuntu' + args: ['echo', 'Hello {{ params.name }}!'] + """ + ) + + with dag_maker(dag_id='example_cloudbuild_operator', start_date=TEST_DEFAULT_DATE, session=session): + operator = CloudBuildCreateBuildOperator( + body=str(body_path), task_id="task-id", params={'name': 'airflow'} + ) + dagrun = dag_maker.create_dagrun() + (ti,) = dagrun.task_instances + ti.refresh_from_task(operator) + + operator.prepare_template() + ti.render_templates() + + expected_body = {'steps': [{'name': 'ubuntu', 'args': ['echo', 'Hello airflow!']}]} + assert expected_body == operator.body diff --git a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py index 3ea08bb6df384..f9a73e192de9d 100644 --- a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py +++ b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py @@ -28,7 +28,6 @@ from parameterized import parameterized from airflow.exceptions import AirflowException -from airflow.models import DAG, TaskInstance from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import ( ACCESS_KEY_ID, AWS_ACCESS_KEY, @@ -246,7 +245,7 @@ def test_verify_success(self, body): assert validated -class TestGcpStorageTransferJobCreateOperator(unittest.TestCase): +class TestGcpStorageTransferJobCreateOperator: @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' ) @@ -325,25 +324,23 @@ def test_job_create_multiple(self, aws_hook, gcp_hook): @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' ) - def test_templates(self, _): - dag_id = 'test_dag_id' - - self.dag = DAG(dag_id, default_args={'start_date': DEFAULT_DATE}) - op = CloudDataTransferServiceCreateJobOperator( + def test_templates(self, _, create_task_instance_of_operator): + dag_id = 
'TestGcpStorageTransferJobCreateOperator_test_templates' + ti = create_task_instance_of_operator( + CloudDataTransferServiceCreateJobOperator, + dag_id=dag_id, body={"description": "{{ dag.dag_id }}"}, gcp_conn_id='{{ dag.dag_id }}', aws_conn_id='{{ dag.dag_id }}', task_id='task-id', - dag=self.dag, ) - ti = TaskInstance(op, DEFAULT_DATE) ti.render_templates() - assert dag_id == getattr(op, 'body')[DESCRIPTION] - assert dag_id == getattr(op, 'gcp_conn_id') - assert dag_id == getattr(op, 'aws_conn_id') + assert dag_id == getattr(ti.task, 'body')[DESCRIPTION] + assert dag_id == getattr(ti.task, 'gcp_conn_id') + assert dag_id == getattr(ti.task, 'aws_conn_id') -class TestGcpStorageTransferJobUpdateOperator(unittest.TestCase): +class TestGcpStorageTransferJobUpdateOperator: @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' ) @@ -373,23 +370,21 @@ def test_job_update(self, mock_hook): @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' ) - def test_templates(self, _): - dag_id = 'test_dag_id' - args = {'start_date': DEFAULT_DATE} - self.dag = DAG(dag_id, default_args=args) - op = CloudDataTransferServiceUpdateJobOperator( + def test_templates(self, _, create_task_instance_of_operator): + dag_id = 'TestGcpStorageTransferJobUpdateOperator_test_templates' + ti = create_task_instance_of_operator( + CloudDataTransferServiceUpdateJobOperator, + dag_id=dag_id, job_name='{{ dag.dag_id }}', body={'transferJob': {"name": "{{ dag.dag_id }}"}}, task_id='task-id', - dag=self.dag, ) - ti = TaskInstance(op, DEFAULT_DATE) ti.render_templates() - assert dag_id == getattr(op, 'body')['transferJob']['name'] - assert dag_id == getattr(op, 'job_name') + assert dag_id == getattr(ti.task, 'body')['transferJob']['name'] + assert dag_id == getattr(ti.task, 'job_name') -class TestGcpStorageTransferJobDeleteOperator(unittest.TestCase): +class TestGcpStorageTransferJobDeleteOperator: @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' ) @@ -416,22 +411,20 @@ def test_job_delete(self, mock_hook): @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' ) - def test_job_delete_with_templates(self, _): - dag_id = 'test_dag_id' - args = {'start_date': DEFAULT_DATE} - self.dag = DAG(dag_id, default_args=args) - op = CloudDataTransferServiceDeleteJobOperator( + def test_job_delete_with_templates(self, _, create_task_instance_of_operator): + dag_id = 'test_job_delete_with_templates' + ti = create_task_instance_of_operator( + CloudDataTransferServiceDeleteJobOperator, + dag_id=dag_id, job_name='{{ dag.dag_id }}', gcp_conn_id='{{ dag.dag_id }}', api_version='{{ dag.dag_id }}', task_id=TASK_ID, - dag=self.dag, ) - ti = TaskInstance(op, DEFAULT_DATE) ti.render_templates() - assert dag_id == getattr(op, 'job_name') - assert dag_id == getattr(op, 'gcp_conn_id') - assert dag_id == getattr(op, 'api_version') + assert dag_id == ti.task.job_name + assert dag_id == ti.task.gcp_conn_id + assert dag_id == ti.task.api_version @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' @@ -445,7 +438,7 @@ def test_job_delete_should_throw_ex_when_name_none(self, mock_hook): mock_hook.assert_not_called() -class TestGpcStorageTransferOperationsGetOperator(unittest.TestCase): +class TestGpcStorageTransferOperationsGetOperator: @mock.patch( 
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' ) @@ -471,16 +464,16 @@ def test_operation_get(self, mock_hook): @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' ) - def test_operation_get_with_templates(self, _): - dag_id = 'test_dag_id' - args = {'start_date': DEFAULT_DATE} - self.dag = DAG(dag_id, default_args=args) - op = CloudDataTransferServiceGetOperationOperator( - operation_name='{{ dag.dag_id }}', task_id='task-id', dag=self.dag + def test_operation_get_with_templates(self, _, create_task_instance_of_operator): + dag_id = 'test_operation_get_with_templates' + ti = create_task_instance_of_operator( + CloudDataTransferServiceGetOperationOperator, + dag_id=dag_id, + operation_name='{{ dag.dag_id }}', + task_id='task-id', ) - ti = TaskInstance(op, DEFAULT_DATE) ti.render_templates() - assert dag_id == getattr(op, 'operation_name') + assert dag_id == ti.task.operation_name @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' @@ -494,7 +487,7 @@ def test_operation_get_should_throw_ex_when_operation_name_none(self, mock_hook) mock_hook.assert_not_called() -class TestGcpStorageTransferOperationListOperator(unittest.TestCase): +class TestGcpStorageTransferOperationListOperator: @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' ) @@ -520,25 +513,22 @@ def test_operation_list(self, mock_hook): @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' ) - def test_templates(self, _): - dag_id = 'test_dag_id' - args = {'start_date': DEFAULT_DATE} - self.dag = DAG(dag_id, default_args=args) - op = CloudDataTransferServiceListOperationsOperator( + def test_templates(self, _, create_task_instance_of_operator): + dag_id = 'TestGcpStorageTransferOperationListOperator_test_templates' + ti = create_task_instance_of_operator( + CloudDataTransferServiceListOperationsOperator, + dag_id=dag_id, request_filter={"job_names": ['{{ dag.dag_id }}']}, gcp_conn_id='{{ dag.dag_id }}', task_id='task-id', - dag=self.dag, ) - ti = TaskInstance(op, DEFAULT_DATE) ti.render_templates() - assert dag_id == getattr(op, 'filter')['job_names'][0] + assert dag_id == ti.task.filter['job_names'][0] + assert dag_id == ti.task.gcp_conn_id - assert dag_id == getattr(op, 'gcp_conn_id') - -class TestGcpStorageTransferOperationsPauseOperator(unittest.TestCase): +class TestGcpStorageTransferOperationsPauseOperator: @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' ) @@ -562,22 +552,20 @@ def test_operation_pause(self, mock_hook): @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' ) - def test_operation_pause_with_templates(self, _): - dag_id = 'test_dag_id' - args = {'start_date': DEFAULT_DATE} - self.dag = DAG(dag_id, default_args=args) - op = CloudDataTransferServicePauseOperationOperator( + def test_operation_pause_with_templates(self, _, create_task_instance_of_operator): + dag_id = 'test_operation_pause_with_templates' + ti = create_task_instance_of_operator( + CloudDataTransferServicePauseOperationOperator, + dag_id=dag_id, operation_name='{{ dag.dag_id }}', gcp_conn_id='{{ dag.dag_id }}', api_version='{{ dag.dag_id }}', task_id=TASK_ID, - dag=self.dag, ) - ti = TaskInstance(op, DEFAULT_DATE) 
ti.render_templates() - assert dag_id == getattr(op, 'operation_name') - assert dag_id == getattr(op, 'gcp_conn_id') - assert dag_id == getattr(op, 'api_version') + assert dag_id == ti.task.operation_name + assert dag_id == ti.task.gcp_conn_id + assert dag_id == ti.task.api_version @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' @@ -591,7 +579,7 @@ def test_operation_pause_should_throw_ex_when_name_none(self, mock_hook): mock_hook.assert_not_called() -class TestGcpStorageTransferOperationsResumeOperator(unittest.TestCase): +class TestGcpStorageTransferOperationsResumeOperator: @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' ) @@ -618,22 +606,20 @@ def test_operation_resume(self, mock_hook): @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' ) - def test_operation_resume_with_templates(self, _): - dag_id = 'test_dag_id' - args = {'start_date': DEFAULT_DATE} - self.dag = DAG(dag_id, default_args=args) - op = CloudDataTransferServiceResumeOperationOperator( + def test_operation_resume_with_templates(self, _, create_task_instance_of_operator): + dag_id = 'test_operation_resume_with_templates' + ti = create_task_instance_of_operator( + CloudDataTransferServiceResumeOperationOperator, + dag_id=dag_id, operation_name='{{ dag.dag_id }}', gcp_conn_id='{{ dag.dag_id }}', api_version='{{ dag.dag_id }}', task_id=TASK_ID, - dag=self.dag, ) - ti = TaskInstance(op, DEFAULT_DATE) ti.render_templates() - assert dag_id == getattr(op, 'operation_name') - assert dag_id == getattr(op, 'gcp_conn_id') - assert dag_id == getattr(op, 'api_version') + assert dag_id == ti.task.operation_name + assert dag_id == ti.task.gcp_conn_id + assert dag_id == ti.task.api_version @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' @@ -647,7 +633,7 @@ def test_operation_resume_should_throw_ex_when_name_none(self, mock_hook): mock_hook.assert_not_called() -class TestGcpStorageTransferOperationsCancelOperator(unittest.TestCase): +class TestGcpStorageTransferOperationsCancelOperator: @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' ) @@ -674,22 +660,20 @@ def test_operation_cancel(self, mock_hook): @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' ) - def test_operation_cancel_with_templates(self, _): - dag_id = 'test_dag_id' - args = {'start_date': DEFAULT_DATE} - self.dag = DAG(dag_id, default_args=args) - op = CloudDataTransferServiceCancelOperationOperator( + def test_operation_cancel_with_templates(self, _, create_task_instance_of_operator): + dag_id = 'test_operation_cancel_with_templates' + ti = create_task_instance_of_operator( + CloudDataTransferServiceCancelOperationOperator, + dag_id=dag_id, operation_name='{{ dag.dag_id }}', gcp_conn_id='{{ dag.dag_id }}', api_version='{{ dag.dag_id }}', task_id=TASK_ID, - dag=self.dag, ) - ti = TaskInstance(op, DEFAULT_DATE) ti.render_templates() - assert dag_id == getattr(op, 'operation_name') - assert dag_id == getattr(op, 'gcp_conn_id') - assert dag_id == getattr(op, 'api_version') + assert dag_id == ti.task.operation_name + assert dag_id == ti.task.gcp_conn_id + assert dag_id == ti.task.api_version @mock.patch( 
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' @@ -703,7 +687,7 @@ def test_operation_cancel_should_throw_ex_when_name_none(self, mock_hook): mock_hook.assert_not_called() -class TestS3ToGoogleCloudStorageTransferOperator(unittest.TestCase): +class TestS3ToGoogleCloudStorageTransferOperator: def test_constructor(self): operator = CloudDataTransferServiceS3ToGCSOperator( task_id=TASK_ID, @@ -727,28 +711,24 @@ def test_constructor(self): @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' ) - def test_templates(self, _): - dag_id = 'test_dag_id' - args = {'start_date': DEFAULT_DATE} - self.dag = DAG(dag_id, default_args=args) - op = CloudDataTransferServiceS3ToGCSOperator( + def test_templates(self, _, create_task_instance_of_operator): + dag_id = 'TestS3ToGoogleCloudStorageTransferOperator_test_templates' + ti = create_task_instance_of_operator( + CloudDataTransferServiceS3ToGCSOperator, + dag_id=dag_id, s3_bucket='{{ dag.dag_id }}', gcs_bucket='{{ dag.dag_id }}', description='{{ dag.dag_id }}', object_conditions={'exclude_prefixes': ['{{ dag.dag_id }}']}, gcp_conn_id='{{ dag.dag_id }}', task_id=TASK_ID, - dag=self.dag, ) - ti = TaskInstance(op, DEFAULT_DATE) ti.render_templates() - assert dag_id == getattr(op, 's3_bucket') - assert dag_id == getattr(op, 'gcs_bucket') - assert dag_id == getattr(op, 'description') - - assert dag_id == getattr(op, 'object_conditions')['exclude_prefixes'][0] - - assert dag_id == getattr(op, 'gcp_conn_id') + assert dag_id == ti.task.s3_bucket + assert dag_id == ti.task.gcs_bucket + assert dag_id == ti.task.description + assert dag_id == ti.task.object_conditions['exclude_prefixes'][0] + assert dag_id == ti.task.gcp_conn_id @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' @@ -860,7 +840,7 @@ def test_execute_should_throw_ex_when_delete_job_without_wait(self, mock_aws_hoo mock_transfer_hook.assert_not_called() -class TestGoogleCloudStorageToGoogleCloudStorageTransferOperator(unittest.TestCase): +class TestGoogleCloudStorageToGoogleCloudStorageTransferOperator: def test_constructor(self): operator = CloudDataTransferServiceGCSToGCSOperator( task_id=TASK_ID, @@ -884,28 +864,24 @@ def test_constructor(self): @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' ) - def test_templates(self, _): - dag_id = 'test_dag_id' - args = {'start_date': DEFAULT_DATE} - self.dag = DAG(dag_id, default_args=args) - op = CloudDataTransferServiceGCSToGCSOperator( + def test_templates(self, _, create_task_instance_of_operator): + dag_id = 'TestGoogleCloudStorageToGoogleCloudStorageTransferOperator_test_templates' + ti = create_task_instance_of_operator( + CloudDataTransferServiceGCSToGCSOperator, + dag_id=dag_id, source_bucket='{{ dag.dag_id }}', destination_bucket='{{ dag.dag_id }}', description='{{ dag.dag_id }}', object_conditions={'exclude_prefixes': ['{{ dag.dag_id }}']}, gcp_conn_id='{{ dag.dag_id }}', task_id=TASK_ID, - dag=self.dag, ) - ti = TaskInstance(op, DEFAULT_DATE) ti.render_templates() - assert dag_id == getattr(op, 'source_bucket') - assert dag_id == getattr(op, 'destination_bucket') - assert dag_id == getattr(op, 'description') - - assert dag_id == getattr(op, 'object_conditions')['exclude_prefixes'][0] - - assert dag_id == getattr(op, 'gcp_conn_id') + assert dag_id == ti.task.source_bucket + assert dag_id == 
ti.task.destination_bucket + assert dag_id == ti.task.description + assert dag_id == ti.task.object_conditions['exclude_prefixes'][0] + assert dag_id == ti.task.gcp_conn_id @mock.patch( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook' diff --git a/tests/providers/google/cloud/operators/test_compute.py b/tests/providers/google/cloud/operators/test_compute.py index d8a853f35349b..2e4b169a9ba34 100644 --- a/tests/providers/google/cloud/operators/test_compute.py +++ b/tests/providers/google/cloud/operators/test_compute.py @@ -18,7 +18,6 @@ import ast -import unittest from copy import deepcopy from unittest import mock @@ -27,7 +26,6 @@ from googleapiclient.errors import HttpError from airflow.exceptions import AirflowException -from airflow.models import DAG, TaskInstance from airflow.providers.google.cloud.operators.compute import ( ComputeEngineCopyInstanceTemplateOperator, ComputeEngineInstanceGroupUpdateManagerTemplateOperator, @@ -48,7 +46,7 @@ DEFAULT_DATE = timezone.datetime(2017, 1, 1) -class TestGceInstanceStart(unittest.TestCase): +class TestGceInstanceStart: @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_instance_start(self, mock_hook): mock_hook.return_value.start_instance.return_value = True @@ -69,26 +67,24 @@ def test_instance_start(self, mock_hook): # Setting all the operator's input parameters as template dag_ids # (could be anything else) just to test if the templating works for all fields @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') - def test_instance_start_with_templates(self, _): - dag_id = 'test_dag_id' - args = {'start_date': DEFAULT_DATE} - self.dag = DAG(dag_id, default_args=args) - op = ComputeEngineStartInstanceOperator( + def test_instance_start_with_templates(self, _, create_task_instance_of_operator): + dag_id = 'test_instance_start_with_templates' + ti = create_task_instance_of_operator( + ComputeEngineStartInstanceOperator, + dag_id=dag_id, project_id='{{ dag.dag_id }}', zone='{{ dag.dag_id }}', resource_id='{{ dag.dag_id }}', gcp_conn_id='{{ dag.dag_id }}', api_version='{{ dag.dag_id }}', task_id='id', - dag=self.dag, ) - ti = TaskInstance(op, DEFAULT_DATE) ti.render_templates() - assert dag_id == getattr(op, 'project_id') - assert dag_id == getattr(op, 'zone') - assert dag_id == getattr(op, 'resource_id') - assert dag_id == getattr(op, 'gcp_conn_id') - assert dag_id == getattr(op, 'api_version') + assert dag_id == ti.task.project_id + assert dag_id == ti.task.zone + assert dag_id == ti.task.resource_id + assert dag_id == ti.task.gcp_conn_id + assert dag_id == ti.task.api_version @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_start_should_throw_ex_when_missing_project_id(self, mock_hook): @@ -129,7 +125,7 @@ def test_start_should_throw_ex_when_missing_resource_id(self, mock_hook): mock_hook.assert_not_called() -class TestGceInstanceStop(unittest.TestCase): +class TestGceInstanceStop: @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_instance_stop(self, mock_hook): op = ComputeEngineStopInstanceOperator( @@ -148,26 +144,24 @@ def test_instance_stop(self, mock_hook): # Setting all the operator's input parameters as templated dag_ids # (could be anything else) just to test if the templating works for all fields @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') - def test_instance_stop_with_templates(self, _): - 
dag_id = 'test_dag_id' - args = {'start_date': DEFAULT_DATE} - self.dag = DAG(dag_id, default_args=args) - op = ComputeEngineStopInstanceOperator( + def test_instance_stop_with_templates(self, _, create_task_instance_of_operator): + dag_id = 'test_instance_stop_with_templates' + ti = create_task_instance_of_operator( + ComputeEngineStopInstanceOperator, + dag_id=dag_id, project_id='{{ dag.dag_id }}', zone='{{ dag.dag_id }}', resource_id='{{ dag.dag_id }}', gcp_conn_id='{{ dag.dag_id }}', api_version='{{ dag.dag_id }}', task_id='id', - dag=self.dag, ) - ti = TaskInstance(op, DEFAULT_DATE) ti.render_templates() - assert dag_id == getattr(op, 'project_id') - assert dag_id == getattr(op, 'zone') - assert dag_id == getattr(op, 'resource_id') - assert dag_id == getattr(op, 'gcp_conn_id') - assert dag_id == getattr(op, 'api_version') + assert dag_id == ti.task.project_id + assert dag_id == ti.task.zone + assert dag_id == ti.task.resource_id + assert dag_id == ti.task.gcp_conn_id + assert dag_id == ti.task.api_version @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_stop_should_throw_ex_when_missing_project_id(self, mock_hook): @@ -216,7 +210,7 @@ def test_stop_should_throw_ex_when_missing_resource_id(self, mock_hook): mock_hook.assert_not_called() -class TestGceInstanceSetMachineType(unittest.TestCase): +class TestGceInstanceSetMachineType: @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_set_machine_type(self, mock_hook): mock_hook.return_value.set_machine_type.return_value = True @@ -240,11 +234,11 @@ def test_set_machine_type(self, mock_hook): # Setting all the operator's input parameters as templated dag_ids # (could be anything else) just to test if the templating works for all fields @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') - def test_set_machine_type_with_templates(self, _): - dag_id = 'test_dag_id' - args = {'start_date': DEFAULT_DATE} - self.dag = DAG(dag_id, default_args=args) - op = ComputeEngineSetMachineTypeOperator( + def test_set_machine_type_with_templates(self, _, create_task_instance_of_operator): + dag_id = 'test_set_machine_type_with_templates' + ti = create_task_instance_of_operator( + ComputeEngineSetMachineTypeOperator, + dag_id=dag_id, project_id='{{ dag.dag_id }}', zone='{{ dag.dag_id }}', resource_id='{{ dag.dag_id }}', @@ -252,15 +246,13 @@ def test_set_machine_type_with_templates(self, _): gcp_conn_id='{{ dag.dag_id }}', api_version='{{ dag.dag_id }}', task_id='id', - dag=self.dag, ) - ti = TaskInstance(op, DEFAULT_DATE) ti.render_templates() - assert dag_id == getattr(op, 'project_id') - assert dag_id == getattr(op, 'zone') - assert dag_id == getattr(op, 'resource_id') - assert dag_id == getattr(op, 'gcp_conn_id') - assert dag_id == getattr(op, 'api_version') + assert dag_id == ti.task.project_id + assert dag_id == ti.task.zone + assert dag_id == ti.task.resource_id + assert dag_id == ti.task.gcp_conn_id + assert dag_id == ti.task.api_version @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_set_machine_type_should_throw_ex_when_missing_project_id(self, mock_hook): @@ -474,7 +466,7 @@ def test_set_machine_type_should_handle_and_trim_gce_error( GCE_INSTANCE_TEMPLATE_BODY_GET_NEW['name'] = GCE_INSTANCE_TEMPLATE_NEW_NAME -class TestGceInstanceTemplateCopy(unittest.TestCase): +class TestGceInstanceTemplateCopy: @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def 
test_successful_copy_template(self, mock_hook): mock_hook.return_value.get_instance_template.side_effect = [ @@ -873,7 +865,7 @@ def test_missing_name(self, mock_hook): } -class TestGceInstanceGroupManagerUpdate(unittest.TestCase): +class TestGceInstanceGroupManagerUpdate: @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_successful_instance_group_update(self, mock_hook): mock_hook.return_value.get_instance_group_manager.return_value = deepcopy( diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index 459865eb1979f..f8500aa9b0080 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -17,7 +17,6 @@ import inspect import unittest -from datetime import datetime from unittest import mock from unittest.mock import MagicMock, Mock, call @@ -26,7 +25,7 @@ from google.api_core.retry import Retry from airflow import AirflowException -from airflow.models import DAG, DagBag, TaskInstance +from airflow.models import DAG, DagBag from airflow.providers.google.cloud.operators.dataproc import ( ClusterGenerator, DataprocClusterLink, @@ -47,6 +46,7 @@ DataprocUpdateClusterOperator, ) from airflow.serialization.serialized_objects import SerializedDAG +from airflow.utils.timezone import datetime from airflow.version import version as airflow_version from tests.test_utils.db import clear_db_runs, clear_db_xcom @@ -570,57 +570,49 @@ def test_execute_if_cluster_exists_in_deleting_state( region=GCP_LOCATION, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME ) - def test_operator_extra_links(self): - op = DataprocCreateClusterOperator( - task_id=TASK_ID, - region=GCP_LOCATION, - project_id=GCP_PROJECT, - cluster_name=CLUSTER_NAME, - delete_on_error=True, - gcp_conn_id=GCP_CONN_ID, - dag=self.dag, - ) - - serialized_dag = SerializedDAG.to_dict(self.dag) - deserialized_dag = SerializedDAG.from_dict(serialized_dag) - deserialized_task = deserialized_dag.task_dict[TASK_ID] - # Assert operator links for serialized DAG - self.assertEqual( - serialized_dag["dag"]["tasks"][0]["_operator_extra_links"], - [{"airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink": {}}], - ) +@pytest.mark.need_serialized_dag +def test_create_cluster_operator_extra_links(dag_maker, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + DataprocCreateClusterOperator, + dag_id=TEST_DAG_ID, + execution_date=DEFAULT_DATE, + task_id=TASK_ID, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + cluster_name=CLUSTER_NAME, + delete_on_error=True, + gcp_conn_id=GCP_CONN_ID, + ) - # Assert operator link types are preserved during deserialization - self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocClusterLink) + serialized_dag = dag_maker.get_serialized_data() + deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_task = deserialized_dag.task_dict[TASK_ID] - ti = TaskInstance(task=op, execution_date=DEFAULT_DATE) + # Assert operator links for serialized DAG + assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [ + {"airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink": {}} + ] - # Assert operator link is empty when no XCom push occurred - self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), "") + # Assert operator link types are preserved during deserialization + assert isinstance(deserialized_task.operator_extra_links[0], 
DataprocClusterLink) - # Assert operator link is empty for deserialized task when no XCom push occurred - self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - "", - ) + # Assert operator link is empty when no XCom push occurred + assert ti.task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) == "" - ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED) + # Assert operator link is empty for deserialized task when no XCom push occurred + assert deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) == "" - # Assert operator links are preserved in deserialized tasks after execution - self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - DATAPROC_CLUSTER_LINK_EXPECTED, - ) + ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED) - # Assert operator links after execution - self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - DATAPROC_CLUSTER_LINK_EXPECTED, - ) + # Assert operator links are preserved in deserialized tasks after execution + assert ( + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) + == DATAPROC_CLUSTER_LINK_EXPECTED + ) - # Check negative case - self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), "") + # Assert operator links after execution + assert ti.task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED class TestDataprocClusterScaleOperator(DataprocClusterTestBase): @@ -672,59 +664,51 @@ def test_execute(self, mock_hook): execution_date=None, ) - def test_operator_extra_links(self): - op = DataprocScaleClusterOperator( - task_id=TASK_ID, - cluster_name=CLUSTER_NAME, - project_id=GCP_PROJECT, - region=GCP_LOCATION, - num_workers=3, - num_preemptible_workers=2, - graceful_decommission_timeout="2m", - gcp_conn_id=GCP_CONN_ID, - dag=self.dag, - ) - - serialized_dag = SerializedDAG.to_dict(self.dag) - deserialized_dag = SerializedDAG.from_dict(serialized_dag) - deserialized_task = deserialized_dag.task_dict[TASK_ID] - # Assert operator links for serialized DAG - self.assertEqual( - serialized_dag["dag"]["tasks"][0]["_operator_extra_links"], - [{"airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink": {}}], - ) +@pytest.mark.need_serialized_dag +def test_scale_cluster_operator_extra_links(dag_maker, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + DataprocScaleClusterOperator, + dag_id=TEST_DAG_ID, + execution_date=DEFAULT_DATE, + task_id=TASK_ID, + cluster_name=CLUSTER_NAME, + project_id=GCP_PROJECT, + region=GCP_LOCATION, + num_workers=3, + num_preemptible_workers=2, + graceful_decommission_timeout="2m", + gcp_conn_id=GCP_CONN_ID, + ) - # Assert operator link types are preserved during deserialization - self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocClusterLink) + serialized_dag = dag_maker.get_serialized_data() + deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_task = deserialized_dag.task_dict[TASK_ID] - ti = TaskInstance(task=op, execution_date=DEFAULT_DATE) + # Assert operator links for serialized DAG + assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [ + {"airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink": {}} + ] - # Assert operator link is empty when no XCom push occurred - self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), "") + # Assert 
operator link types are preserved during deserialization + assert isinstance(deserialized_task.operator_extra_links[0], DataprocClusterLink) - # Assert operator link is empty for deserialized task when no XCom push occurred - self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - "", - ) + # Assert operator link is empty when no XCom push occurred + assert ti.task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) == "" - ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED) + # Assert operator link is empty for deserialized task when no XCom push occurred + assert deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) == "" - # Assert operator links are preserved in deserialized tasks after execution - self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - DATAPROC_CLUSTER_LINK_EXPECTED, - ) + ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED) - # Assert operator links after execution - self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - DATAPROC_CLUSTER_LINK_EXPECTED, - ) + # Assert operator links are preserved in deserialized tasks after execution + assert ( + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) + == DATAPROC_CLUSTER_LINK_EXPECTED + ) - # Check negative case - self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), "") + # Assert operator links after execution + assert ti.task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED class TestDataprocClusterDeleteOperator(unittest.TestCase): @@ -882,54 +866,6 @@ def test_on_kill(self, mock_hook): project_id=GCP_PROJECT, region=GCP_LOCATION, job_id=job_id ) - @mock.patch(DATAPROC_PATH.format("DataprocHook")) - def test_operator_extra_links(self, mock_hook): - mock_hook.return_value.project_id = GCP_PROJECT - op = DataprocSubmitJobOperator( - task_id=TASK_ID, - region=GCP_LOCATION, - project_id=GCP_PROJECT, - job={}, - gcp_conn_id=GCP_CONN_ID, - dag=self.dag, - ) - - serialized_dag = SerializedDAG.to_dict(self.dag) - deserialized_dag = SerializedDAG.from_dict(serialized_dag) - deserialized_task = deserialized_dag.task_dict[TASK_ID] - - # Assert operator links for serialized_dag - self.assertEqual( - serialized_dag["dag"]["tasks"][0]["_operator_extra_links"], - [{"airflow.providers.google.cloud.operators.dataproc.DataprocJobLink": {}}], - ) - - # Assert operator link types are preserved during deserialization - self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocJobLink) - - ti = TaskInstance(task=op, execution_date=DEFAULT_DATE) - - # Assert operator link is empty when no XCom push occurred - self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), "") - - # Assert operator link is empty for deserialized task when no XCom push occurred - self.assertEqual(deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), "") - - ti.xcom_push(key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED) - - # Assert operator links are preserved in deserialized tasks - self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), - DATAPROC_JOB_LINK_EXPECTED, - ) - # Assert operator links after execution - self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), - DATAPROC_JOB_LINK_EXPECTED, - ) - # Check for negative case - self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), 
DataprocJobLink.name), "") - @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_location_deprecation_warning(self, mock_hook): xcom_push_call = call.ti.xcom_push( @@ -1009,6 +945,48 @@ def test_location_deprecation_warning(self, mock_hook): op.execute(context=self.mock_context) +@pytest.mark.need_serialized_dag +@mock.patch(DATAPROC_PATH.format("DataprocHook")) +def test_submit_job_operator_extra_links(mock_hook, dag_maker, create_task_instance_of_operator): + mock_hook.return_value.project_id = GCP_PROJECT + ti = create_task_instance_of_operator( + DataprocSubmitJobOperator, + dag_id=TEST_DAG_ID, + execution_date=DEFAULT_DATE, + task_id=TASK_ID, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + job={}, + gcp_conn_id=GCP_CONN_ID, + ) + + serialized_dag = dag_maker.get_serialized_data() + deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_task = deserialized_dag.task_dict[TASK_ID] + + # Assert operator links for serialized_dag + assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [ + {"airflow.providers.google.cloud.operators.dataproc.DataprocJobLink": {}} + ] + + # Assert operator link types are preserved during deserialization + assert isinstance(deserialized_task.operator_extra_links[0], DataprocJobLink) + + # Assert operator link is empty when no XCom push occurred + assert ti.task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name) == "" + + # Assert operator link is empty for deserialized task when no XCom push occurred + assert deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name) == "" + + ti.xcom_push(key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED) + + # Assert operator links are preserved in deserialized tasks + assert deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name) == DATAPROC_JOB_LINK_EXPECTED + + # Assert operator links after execution + assert ti.task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name) == DATAPROC_JOB_LINK_EXPECTED + + class TestDataprocUpdateClusterOperator(DataprocClusterTestBase): @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook): @@ -1059,58 +1037,6 @@ def test_execute(self, mock_hook): execution_date=None, ) - def test_operator_extra_links(self): - op = DataprocUpdateClusterOperator( - task_id=TASK_ID, - region=GCP_LOCATION, - cluster_name=CLUSTER_NAME, - cluster=CLUSTER, - update_mask=UPDATE_MASK, - graceful_decommission_timeout={"graceful_decommission_timeout": "600s"}, - project_id=GCP_PROJECT, - gcp_conn_id=GCP_CONN_ID, - dag=self.dag, - ) - - serialized_dag = SerializedDAG.to_dict(self.dag) - deserialized_dag = SerializedDAG.from_dict(serialized_dag) - deserialized_task = deserialized_dag.task_dict[TASK_ID] - - # Assert operator links for serialized_dag - self.assertEqual( - serialized_dag["dag"]["tasks"][0]["_operator_extra_links"], - [{"airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink": {}}], - ) - - # Assert operator link types are preserved during deserialization - self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocClusterLink) - - ti = TaskInstance(task=op, execution_date=DEFAULT_DATE) - - # Assert operator link is empty when no XCom push occurred - self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), "") - - # Assert operator link is empty for deserialized task when no XCom push occurred - self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - "", - ) - - ti.xcom_push(key="cluster_conf", 
value=DATAPROC_CLUSTER_CONF_EXPECTED) - - # Assert operator links are preserved in deserialized tasks - self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - DATAPROC_CLUSTER_LINK_EXPECTED, - ) - # Assert operator links after execution - self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - DATAPROC_CLUSTER_LINK_EXPECTED, - ) - # Check for negative case - self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), "") - @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_location_deprecation_warning(self, mock_hook): self.extra_links_manager_mock.attach_mock(mock_hook, 'hook') @@ -1186,6 +1112,52 @@ def test_location_deprecation_warning(self, mock_hook): op.execute(context=self.mock_context) +@pytest.mark.need_serialized_dag +def test_update_cluster_operator_extra_links(dag_maker, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + DataprocUpdateClusterOperator, + dag_id=TEST_DAG_ID, + execution_date=DEFAULT_DATE, + task_id=TASK_ID, + region=GCP_LOCATION, + cluster_name=CLUSTER_NAME, + cluster=CLUSTER, + update_mask=UPDATE_MASK, + graceful_decommission_timeout={"graceful_decommission_timeout": "600s"}, + project_id=GCP_PROJECT, + gcp_conn_id=GCP_CONN_ID, + ) + + serialized_dag = dag_maker.get_serialized_data() + deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_task = deserialized_dag.task_dict[TASK_ID] + + # Assert operator links for serialized_dag + assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [ + {"airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink": {}} + ] + + # Assert operator link types are preserved during deserialization + assert isinstance(deserialized_task.operator_extra_links[0], DataprocClusterLink) + + # Assert operator link is empty when no XCom push occurred + assert ti.task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) == "" + + # Assert operator link is empty for deserialized task when no XCom push occurred + assert deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) == "" + + ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED) + + # Assert operator links are preserved in deserialized tasks + assert ( + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) + == DATAPROC_CLUSTER_LINK_EXPECTED + ) + + # Assert operator links after execution + assert ti.task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED + + class TestDataprocWorkflowTemplateInstantiateOperator(unittest.TestCase): @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook): @@ -1509,59 +1481,49 @@ def test_execute(self, mock_hook, mock_uuid): # Test whether xcom push occurs before polling for job self.extra_links_manager_mock.assert_has_calls(self.extra_links_expected_calls, any_order=False) - @mock.patch(DATAPROC_PATH.format("DataprocHook")) - def test_operator_extra_links(self, mock_hook): - mock_hook.return_value.project_id = GCP_PROJECT - op = DataprocSubmitSparkJobOperator( - task_id=TASK_ID, - region=GCP_LOCATION, - gcp_conn_id=GCP_CONN_ID, - main_class=self.main_class, - dataproc_jars=self.jars, - dag=self.dag, - ) - - serialized_dag = SerializedDAG.to_dict(self.dag) - deserialized_dag = SerializedDAG.from_dict(serialized_dag) - deserialized_task = deserialized_dag.task_dict[TASK_ID] +@pytest.mark.need_serialized_dag +@mock.patch(DATAPROC_PATH.format("DataprocHook")) +def 
test_submit_spark_job_operator_extra_links(mock_hook, dag_maker, create_task_instance_of_operator): + mock_hook.return_value.project_id = GCP_PROJECT - # Assert operator links for serialized DAG - self.assertEqual( - serialized_dag["dag"]["tasks"][0]["_operator_extra_links"], - [{"airflow.providers.google.cloud.operators.dataproc.DataprocJobLink": {}}], - ) + ti = create_task_instance_of_operator( + DataprocSubmitSparkJobOperator, + dag_id=TEST_DAG_ID, + execution_date=DEFAULT_DATE, + task_id=TASK_ID, + region=GCP_LOCATION, + gcp_conn_id=GCP_CONN_ID, + main_class="org.apache.spark.examples.SparkPi", + dataproc_jars=["file:///usr/lib/spark/examples/jars/spark-examples.jar"], + ) - # Assert operator link types are preserved during deserialization - self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocJobLink) + serialized_dag = dag_maker.get_serialized_data() + deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_task = deserialized_dag.task_dict[TASK_ID] - ti = TaskInstance(task=op, execution_date=DEFAULT_DATE) + # Assert operator links for serialized DAG + assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [ + {"airflow.providers.google.cloud.operators.dataproc.DataprocJobLink": {}} + ] - # Assert operator link is empty when no XCom push occurred - self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), "") + # Assert operator link types are preserved during deserialization + assert isinstance(deserialized_task.operator_extra_links[0], DataprocJobLink) - # Assert operator link is empty for deserialized task when no XCom push occurred - self.assertEqual(deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), "") + # Assert operator link is empty when no XCom push occurred + assert ti.task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name) == "" - ti.xcom_push(key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED) + # Assert operator link is empty for deserialized task when no XCom push occurred + assert deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name) == "" - # Assert operator links after task execution - self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), - DATAPROC_JOB_LINK_EXPECTED, - ) + ti.xcom_push(key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED) - # Assert operator links are preserved in deserialized tasks - self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), - DATAPROC_JOB_LINK_EXPECTED, - ) + # Assert operator links after task execution + assert ti.task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name) == DATAPROC_JOB_LINK_EXPECTED - # Assert for negative case - self.assertEqual( - deserialized_task.get_extra_links(datetime(2020, 7, 20), DataprocJobLink.name), - "", - ) + # Assert operator links are preserved in deserialized tasks + link = deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name) + assert link == DATAPROC_JOB_LINK_EXPECTED class TestDataProcHadoopOperator(unittest.TestCase): diff --git a/tests/providers/google/cloud/operators/test_mlengine.py b/tests/providers/google/cloud/operators/test_mlengine.py index 7a226e325ea05..d316f9b6be519 100644 --- a/tests/providers/google/cloud/operators/test_mlengine.py +++ b/tests/providers/google/cloud/operators/test_mlengine.py @@ -16,7 +16,6 @@ # under the License. 
import copy -import datetime import unittest from unittest.mock import ANY, MagicMock, patch @@ -25,7 +24,6 @@ from googleapiclient.errors import HttpError from airflow.exceptions import AirflowException -from airflow.models import TaskInstance from airflow.models.dag import DAG from airflow.providers.google.cloud.operators.mlengine import ( AIPlatformConsoleLink, @@ -43,9 +41,10 @@ MLEngineTrainingCancelJobOperator, ) from airflow.serialization.serialized_objects import SerializedDAG +from airflow.utils import timezone from airflow.utils.dates import days_ago -DEFAULT_DATE = datetime.datetime(2017, 6, 6) +DEFAULT_DATE = timezone.datetime(2017, 6, 6) TEST_DAG_ID = "test-mlengine-operators" TEST_PROJECT_ID = "test-project-id" @@ -311,7 +310,7 @@ def test_failed_job_error(self, mock_hook): assert 'A failure message' == str(ctx.value) -class TestMLEngineStartTrainingJobOperator(unittest.TestCase): +class TestMLEngineStartTrainingJobOperator: TRAINING_DEFAULT_ARGS = { 'project_id': 'test-project', 'job_id': 'test_training', @@ -322,7 +321,6 @@ class TestMLEngineStartTrainingJobOperator(unittest.TestCase): 'scale_tier': 'STANDARD_1', 'labels': {'some': 'labels'}, 'task_id': 'test-training', - 'start_date': days_ago(1), } TRAINING_INPUT = { 'jobId': 'test_training', @@ -336,9 +334,6 @@ class TestMLEngineStartTrainingJobOperator(unittest.TestCase): }, } - def setUp(self): - self.dag = DAG(TEST_DAG_ID, default_args=self.TRAINING_DEFAULT_ARGS) - @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') def test_success_create_training_job(self, mock_hook): success_response = self.TRAINING_INPUT.copy() @@ -571,12 +566,12 @@ def test_failed_job_error(self, mock_hook): assert 'A failure message' == str(ctx.value) @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') - def test_console_extra_link(self, mock_hook): - training_op = MLEngineStartTrainingJobOperator(**self.TRAINING_DEFAULT_ARGS) - - ti = TaskInstance( - task=training_op, + def test_console_extra_link(self, mock_hook, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + MLEngineStartTrainingJobOperator, + dag_id="test_console_extra_link", execution_date=DEFAULT_DATE, + **self.TRAINING_DEFAULT_ARGS, ) job_id = self.TRAINING_DEFAULT_ARGS['job_id'] @@ -589,15 +584,18 @@ def test_console_extra_link(self, mock_hook): assert ( f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}" - == training_op.get_extra_links(DEFAULT_DATE, AIPlatformConsoleLink.name) + == ti.task.get_extra_links(DEFAULT_DATE, AIPlatformConsoleLink.name) ) - assert '' == training_op.get_extra_links(datetime.datetime(2019, 1, 1), AIPlatformConsoleLink.name) - - def test_console_extra_link_serialized_field(self): - with self.dag: - training_op = MLEngineStartTrainingJobOperator(**self.TRAINING_DEFAULT_ARGS) - serialized_dag = SerializedDAG.to_dict(self.dag) + @pytest.mark.need_serialized_dag + def test_console_extra_link_serialized_field(self, dag_maker, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + MLEngineStartTrainingJobOperator, + dag_id="test_console_extra_link_serialized_field", + execution_date=DEFAULT_DATE, + **self.TRAINING_DEFAULT_ARGS, + ) + serialized_dag = dag_maker.get_serialized_data() dag = SerializedDAG.from_dict(serialized_dag) simple_task = dag.task_dict[self.TRAINING_DEFAULT_ARGS['task_id']] @@ -616,10 +614,6 @@ def test_console_extra_link_serialized_field(self): "project_id": project_id, } - ti = TaskInstance( - task=training_op, - 
execution_date=DEFAULT_DATE, - ) ti.xcom_push(key='gcp_metadata', value=gcp_metadata) assert ( @@ -627,8 +621,6 @@ def test_console_extra_link_serialized_field(self): == simple_task.get_extra_links(DEFAULT_DATE, AIPlatformConsoleLink.name) ) - assert '' == simple_task.get_extra_links(datetime.datetime(2019, 1, 1), AIPlatformConsoleLink.name) - class TestMLEngineTrainingCancelJobOperator(unittest.TestCase): diff --git a/tests/providers/http/sensors/test_http.py b/tests/providers/http/sensors/test_http.py index 23ac2fdf2f69f..3fc61bb5295a5 100644 --- a/tests/providers/http/sensors/test_http.py +++ b/tests/providers/http/sensors/test_http.py @@ -23,7 +23,6 @@ import requests from airflow.exceptions import AirflowException, AirflowSensorTimeout -from airflow.models import TaskInstance from airflow.models.dag import DAG from airflow.providers.http.operators.http import SimpleHttpOperator from airflow.providers.http.sensors.http import HttpSensor @@ -34,13 +33,9 @@ TEST_DAG_ID = 'unit_test_dag' -class TestHttpSensor(unittest.TestCase): - def setUp(self): - args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} - self.dag = DAG(TEST_DAG_ID, default_args=args) - +class TestHttpSensor: @patch("airflow.providers.http.hooks.http.requests.Session.send") - def test_poke_exception(self, mock_session_send): + def test_poke_exception(self, mock_session_send, create_task_of_operator): """ Exception occurs in poke function should not be ignored. """ @@ -51,7 +46,9 @@ def test_poke_exception(self, mock_session_send): def resp_check(_): raise AirflowException('AirflowException raised here!') - task = HttpSensor( + task = create_task_of_operator( + HttpSensor, + dag_id='http_sensor_poke_exception', task_id='http_sensor_poke_exception', http_conn_id='http_default', endpoint='', @@ -64,7 +61,11 @@ def resp_check(_): task.execute(context={}) @patch("airflow.providers.http.hooks.http.requests.Session.send") - def test_poke_continues_for_http_500_with_extra_options_check_response_false(self, mock_session_send): + def test_poke_continues_for_http_500_with_extra_options_check_response_false( + self, + mock_session_send, + create_task_of_operator, + ): def resp_check(_): return False @@ -74,8 +75,9 @@ def resp_check(_): response._content = b'Internal Server Error' mock_session_send.return_value = response - task = HttpSensor( - dag=self.dag, + task = create_task_of_operator( + HttpSensor, + dag_id='http_sensor_poke_for_code_500', task_id='http_sensor_poke_for_code_500', http_conn_id='http_default', endpoint='', @@ -87,16 +89,17 @@ def resp_check(_): poke_interval=1, ) - with self.assertRaises(AirflowSensorTimeout): + with pytest.raises(AirflowSensorTimeout): task.execute(context={}) @patch("airflow.providers.http.hooks.http.requests.Session.send") - def test_head_method(self, mock_session_send): + def test_head_method(self, mock_session_send, create_task_of_operator): def resp_check(_): return True - task = HttpSensor( - dag=self.dag, + task = create_task_of_operator( + HttpSensor, + dag_id='http_sensor_head_method', task_id='http_sensor_head_method', http_conn_id='http_default', endpoint='', @@ -109,8 +112,7 @@ def resp_check(_): task.execute(context={}) - args, kwargs = mock_session_send.call_args - received_request = args[0] + received_request = mock_session_send.call_args[0][0] prep_request = requests.Request('HEAD', 'https://www.httpbin.org', {}).prepare() @@ -118,7 +120,7 @@ def resp_check(_): assert prep_request.method, received_request.method @patch("airflow.providers.http.hooks.http.requests.Session.send") - 
def test_poke_context(self, mock_session_send): + def test_poke_context(self, mock_session_send, create_task_instance_of_operator): response = requests.Response() response.status_code = 200 mock_session_send.return_value = response @@ -128,7 +130,10 @@ def resp_check(_, execution_date): return True raise AirflowException('AirflowException raised here!') - task = HttpSensor( + task_instance = create_task_instance_of_operator( + HttpSensor, + dag_id='http_sensor_poke_exception', + execution_date=DEFAULT_DATE, task_id='http_sensor_poke_exception', http_conn_id='http_default', endpoint='', @@ -136,14 +141,12 @@ def resp_check(_, execution_date): response_check=resp_check, timeout=5, poke_interval=1, - dag=self.dag, ) - task_instance = TaskInstance(task=task, execution_date=DEFAULT_DATE) - task.execute(task_instance.get_template_context()) + task_instance.task.execute(task_instance.get_template_context()) @patch("airflow.providers.http.hooks.http.requests.Session.send") - def test_logging_head_error_request(self, mock_session_send): + def test_logging_head_error_request(self, mock_session_send, create_task_of_operator): def resp_check(_): return True @@ -153,9 +156,10 @@ def resp_check(_): response._content = b"This endpoint doesn't exist" mock_session_send.return_value = response - task = HttpSensor( - dag=self.dag, - task_id='http_sensor_head_method', + task = create_task_of_operator( + HttpSensor, + dag_id='http_sensor_head_error', + task_id='http_sensor_head_error', http_conn_id='http_default', endpoint='', request_params={}, diff --git a/tests/providers/microsoft/azure/log/test_wasb_task_handler.py b/tests/providers/microsoft/azure/log/test_wasb_task_handler.py index c7582a13fd9bc..595d2e50d2fb1 100644 --- a/tests/providers/microsoft/azure/log/test_wasb_task_handler.py +++ b/tests/providers/microsoft/azure/log/test_wasb_task_handler.py @@ -15,23 +15,38 @@ # specific language governing permissions and limitations # under the License. 
-import unittest -from datetime import datetime from unittest import mock +import pytest from azure.common import AzureHttpError -from airflow.models import DAG, TaskInstance -from airflow.operators.dummy import DummyOperator from airflow.providers.microsoft.azure.hooks.wasb import WasbHook from airflow.providers.microsoft.azure.log.wasb_task_handler import WasbTaskHandler from airflow.utils.state import State +from airflow.utils.timezone import datetime from tests.test_utils.config import conf_vars +from tests.test_utils.db import clear_db_dags, clear_db_runs -class TestWasbTaskHandler(unittest.TestCase): - def setUp(self): - super().setUp() +class TestWasbTaskHandler: + @pytest.fixture(autouse=True) + def ti(self, create_task_instance): + date = datetime(2020, 8, 10) + ti = create_task_instance( + dag_id='dag_for_testing_wasb_task_handler', + task_id='task_for_testing_wasb_log_handler', + execution_date=date, + start_date=date, + dagrun_state=State.RUNNING, + state=State.RUNNING, + ) + ti.try_number = 1 + ti.raw = False + yield ti + clear_db_runs() + clear_db_dags() + + def setup_method(self): self.wasb_log_folder = 'wasb://container/remote/log/location' self.remote_log_location = 'remote/log/location/1.log' self.local_log_location = 'local/log/location' @@ -45,14 +60,6 @@ def setUp(self): delete_local_copy=True, ) - date = datetime(2020, 8, 10) - self.dag = DAG('dag_for_testing_file_task_handler', start_date=date) - task = DummyOperator(task_id='task_for_testing_file_log_handler', dag=self.dag) - self.ti = TaskInstance(task=task, execution_date=date) - self.ti.try_number = 1 - self.ti.state = State.RUNNING - self.addCleanup(self.dag.clear) - @conf_vars({('logging', 'remote_log_conn_id'): 'wasb_default'}) @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") def test_hook(self, mock_service): @@ -77,13 +84,13 @@ def test_hook_raises(self): exc_info=True, ) - def test_set_context_raw(self): - self.ti.raw = True - self.wasb_task_handler.set_context(self.ti) + def test_set_context_raw(self, ti): + ti.raw = True + self.wasb_task_handler.set_context(ti) assert not self.wasb_task_handler.upload_on_close - def test_set_context_not_raw(self): - self.wasb_task_handler.set_context(self.ti) + def test_set_context_not_raw(self, ti): + self.wasb_task_handler.set_context(ti) assert self.wasb_task_handler.upload_on_close @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook") @@ -96,10 +103,10 @@ def test_wasb_log_exists(self, mock_hook): ) @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook") - def test_wasb_read(self, mock_hook): + def test_wasb_read(self, mock_hook, ti): mock_hook.return_value.read_file.return_value = 'Log line' assert self.wasb_task_handler.wasb_read(self.remote_log_location) == "Log line" - assert self.wasb_task_handler.read(self.ti) == ( + assert self.wasb_task_handler.read(ti) == ( [ [ ( diff --git a/tests/providers/microsoft/azure/operators/test_adx.py b/tests/providers/microsoft/azure/operators/test_adx.py index d8b080b419a99..e5bb2b8e495b0 100644 --- a/tests/providers/microsoft/azure/operators/test_adx.py +++ b/tests/providers/microsoft/azure/operators/test_adx.py @@ -22,7 +22,7 @@ from azure.kusto.data._models import KustoResultTable -from airflow.models import DAG, TaskInstance +from airflow.models import DAG from airflow.providers.microsoft.azure.hooks.adx import AzureDataExplorerHook from airflow.providers.microsoft.azure.operators.adx import AzureDataExplorerQueryOperator from airflow.utils import timezone @@ -81,10 
+81,20 @@ def test_run_query(self, mock_conn, mock_run_query): MOCK_DATA['query'], MOCK_DATA['database'], MOCK_DATA['options'] ) - @mock.patch.object(AzureDataExplorerHook, 'run_query', return_value=MockResponse()) - @mock.patch.object(AzureDataExplorerHook, 'get_conn') - def test_xcom_push_and_pull(self, mock_conn, mock_run_query): - ti = TaskInstance(task=self.operator, execution_date=timezone.utcnow()) - ti.run() - assert ti.xcom_pull(task_ids=MOCK_DATA['task_id']) == str(MOCK_RESULT) +@mock.patch.object(AzureDataExplorerHook, 'run_query', return_value=MockResponse()) +@mock.patch.object(AzureDataExplorerHook, 'get_conn') +def test_azure_data_explorer_query_operator_xcom_push_and_pull( + mock_conn, + mock_run_query, + create_task_instance_of_operator, +): + ti = create_task_instance_of_operator( + AzureDataExplorerQueryOperator, + dag_id="test_azure_data_explorer_query_operator_xcom_push_and_pull", + execution_date=timezone.utcnow(), + **MOCK_DATA, + ) + ti.run() + + assert ti.xcom_pull(task_ids=MOCK_DATA["task_id"]) == str(MOCK_RESULT) diff --git a/tests/providers/qubole/operators/test_qubole.py b/tests/providers/qubole/operators/test_qubole.py index d5a812d814259..82fa57c96275f 100644 --- a/tests/providers/qubole/operators/test_qubole.py +++ b/tests/providers/qubole/operators/test_qubole.py @@ -17,11 +17,12 @@ # under the License. # -from unittest import TestCase, mock +from unittest import mock + +import pytest from airflow import settings from airflow.models import DAG, Connection -from airflow.models.taskinstance import TaskInstance from airflow.providers.qubole.hooks.qubole import QuboleHook from airflow.providers.qubole.operators.qubole import QDSLink, QuboleOperator from airflow.serialization.serialized_objects import SerializedDAG @@ -36,15 +37,15 @@ DEFAULT_DATE = datetime(2017, 1, 1) -class TestQuboleOperator(TestCase): - def setUp(self): +class TestQuboleOperator: + def setup_method(self): db.merge_conn(Connection(conn_id=DEFAULT_CONN, conn_type='HTTP')) db.merge_conn(Connection(conn_id=TEST_CONN, conn_type='HTTP', host='http://localhost/api')) - def tearDown(self): + def teardown_method(self): session = settings.Session() session.query(Connection).filter(Connection.conn_id == TEST_CONN).delete() - session.commit() + session.flush() session.close() def test_init_with_default_connection(self): @@ -60,19 +61,17 @@ def test_init_with_template_connection(self): assert task.task_id == TASK_ID assert task.qubole_conn_id == TEMPLATE_CONN - def test_init_with_template_cluster_label(self): - dag = DAG(DAG_ID, start_date=DEFAULT_DATE) - task = QuboleOperator( + def test_init_with_template_cluster_label(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + QuboleOperator, + dag_id="test_init_with_template_cluster_label", + execution_date=DEFAULT_DATE, task_id=TASK_ID, - dag=dag, cluster_label='{{ params.cluster_label }}', params={'cluster_label': 'default'}, ) - - ti = TaskInstance(task, DEFAULT_DATE) ti.render_templates() - - assert task.cluster_label == 'default' + assert ti.task.cluster_label == 'default' def test_get_hook(self): dag = DAG(DAG_ID, start_date=DEFAULT_DATE) @@ -121,61 +120,45 @@ def test_position_args_parameters(self): task.get_hook().create_cmd_args({'run_id': 'dummy'})[5] == "s3n://airflow/destination_hadoopcmd" ) - def test_get_redirect_url(self): - dag = DAG(DAG_ID, start_date=DEFAULT_DATE) - - with dag: - task = QuboleOperator( - task_id=TASK_ID, - qubole_conn_id=TEST_CONN, - command_type='shellcmd', - parameters="param1 param2", - 
dag=dag, - ) - - ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) + def test_get_redirect_url(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + QuboleOperator, + dag_id="test_get_redirect_url", + execution_date=DEFAULT_DATE, + task_id=TASK_ID, + qubole_conn_id=TEST_CONN, + command_type='shellcmd', + parameters="param1 param2", + ) ti.xcom_push('qbol_cmd_id', 12345) - # check for positive case - url = task.get_extra_links(DEFAULT_DATE, 'Go to QDS') + url = ti.task.get_extra_links(DEFAULT_DATE, 'Go to QDS') assert url == 'http://localhost/v2/analyze?command_id=12345' - # check for negative case - url2 = task.get_extra_links(datetime(2017, 1, 2), 'Go to QDS') - assert url2 == '' - - def test_extra_serialized_field(self): - dag = DAG(DAG_ID, start_date=DEFAULT_DATE) - with dag: - QuboleOperator( - task_id=TASK_ID, - command_type='shellcmd', - qubole_conn_id=TEST_CONN, - ) + @pytest.mark.need_serialized_dag + def test_extra_serialized_field(self, dag_maker, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + QuboleOperator, + dag_id="test_extra_serialized_field", + execution_date=DEFAULT_DATE, + task_id=TASK_ID, + command_type='shellcmd', + qubole_conn_id=TEST_CONN, + ) - serialized_dag = SerializedDAG.to_dict(dag) + serialized_dag = dag_maker.get_serialized_data() assert "qubole_conn_id" in serialized_dag["dag"]["tasks"][0] dag = SerializedDAG.from_dict(serialized_dag) simple_task = dag.task_dict[TASK_ID] assert getattr(simple_task, "qubole_conn_id") == TEST_CONN - ######################################################### - # Verify Operator Links work with Serialized Operator - ######################################################### assert isinstance(list(simple_task.operator_extra_links)[0], QDSLink) - ti = TaskInstance(task=simple_task, execution_date=DEFAULT_DATE) ti.xcom_push('qbol_cmd_id', 12345) - - # check for positive case url = simple_task.get_extra_links(DEFAULT_DATE, 'Go to QDS') assert url == 'http://localhost/v2/analyze?command_id=12345' - # check for negative case - url2 = simple_task.get_extra_links(datetime(2017, 1, 2), 'Go to QDS') - assert url2 == '' - def test_parameter_pool_passed(self): test_pool = 'test_pool' op = QuboleOperator(task_id=TASK_ID, pool=test_pool) diff --git a/tests/providers/sftp/operators/test_sftp.py b/tests/providers/sftp/operators/test_sftp.py index d286184544358..6aa9bb0169866 100644 --- a/tests/providers/sftp/operators/test_sftp.py +++ b/tests/providers/sftp/operators/test_sftp.py @@ -15,40 +15,31 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- import os -import unittest from base64 import b64encode from unittest import mock import pytest from airflow.exceptions import AirflowException -from airflow.models import DAG, TaskInstance +from airflow.models import DAG from airflow.providers.sftp.operators.sftp import SFTPOperation, SFTPOperator from airflow.providers.ssh.operators.ssh import SSHOperator from airflow.utils import timezone from airflow.utils.timezone import datetime from tests.test_utils.config import conf_vars -TEST_DAG_ID = 'unit_tests_sftp_op' DEFAULT_DATE = datetime(2017, 1, 1) TEST_CONN_ID = "conn_id_for_testing" -class TestSFTPOperator(unittest.TestCase): - def setUp(self): +class TestSFTPOperator: + def setup_method(self): from airflow.providers.ssh.hooks.ssh import SSHHook hook = SSHHook(ssh_conn_id='ssh_default') hook.no_host_key_check = True - dag = DAG( - TEST_DAG_ID + 'test_schedule_dag_once', - schedule_interval="@once", - start_date=DEFAULT_DATE, - ) self.hook = hook - self.dag = dag self.test_dir = "/tmp" self.test_local_dir = "/tmp/tmp2" self.test_remote_dir = "/tmp/tmp1" @@ -61,8 +52,22 @@ def setUp(self): # Remote Filepath with Intermediate Directory self.test_remote_filepath_int_dir = f'{self.test_remote_dir}/{self.test_remote_filename}' + def teardown_method(self): + if os.path.exists(self.test_local_filepath): + os.remove(self.test_local_filepath) + if os.path.exists(self.test_local_filepath_int_dir): + os.remove(self.test_local_filepath_int_dir) + if os.path.exists(self.test_local_dir): + os.rmdir(self.test_local_dir) + if os.path.exists(self.test_remote_filepath): + os.remove(self.test_remote_filepath) + if os.path.exists(self.test_remote_filepath_int_dir): + os.remove(self.test_remote_filepath_int_dir) + if os.path.exists(self.test_remote_dir): + os.rmdir(self.test_remote_dir) + @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) - def test_pickle_file_transfer_put(self): + def test_pickle_file_transfer_put(self, dag_maker): test_local_file_content = ( b"This is local file content \n which is multiline " b"continuing....with other character\nanother line here \n this is last line" @@ -71,38 +76,31 @@ def test_pickle_file_transfer_put(self): with open(self.test_local_filepath, 'wb') as file: file.write(test_local_file_content) - # put test file to remote - put_test_task = SFTPOperator( - task_id="put_test_task", - ssh_hook=self.hook, - local_filepath=self.test_local_filepath, - remote_filepath=self.test_remote_filepath, - operation=SFTPOperation.PUT, - create_intermediate_dirs=True, - dag=self.dag, - ) - assert put_test_task is not None - ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow()) - ti2.run() + with dag_maker(dag_id="unit_tests_sftp_op_pickle_file_transfer_put", start_date=DEFAULT_DATE): + SFTPOperator( # Put test file to remote. + task_id="put_test_task", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + create_intermediate_dirs=True, + ) + SSHOperator( # Check the remote file content. 
+ task_id="check_file_task", + ssh_hook=self.hook, + command=f"cat {self.test_remote_filepath}", + do_xcom_push=True, + ) - # check the remote file content - check_file_task = SSHOperator( - task_id="check_file_task", - ssh_hook=self.hook, - command=f"cat {self.test_remote_filepath}", - do_xcom_push=True, - dag=self.dag, - ) - assert check_file_task is not None - ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow()) - ti3.run() - assert ( - ti3.xcom_pull(task_ids=check_file_task.task_id, key='return_value').strip() - == test_local_file_content - ) + tis = {ti.task_id: ti for ti in dag_maker.create_dagrun().task_instances} + tis["put_test_task"].run() + tis["check_file_task"].run() + + pulled = tis["check_file_task"].xcom_pull(task_ids="check_file_task", key='return_value') + assert pulled.strip() == test_local_file_content @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) - def test_file_transfer_no_intermediate_dir_error_put(self): + def test_file_transfer_no_intermediate_dir_error_put(self, create_task_instance_of_operator): test_local_file_content = ( b"This is local file content \n which is multiline " b"continuing....with other character\nanother line here \n this is last line" @@ -111,26 +109,25 @@ def test_file_transfer_no_intermediate_dir_error_put(self): with open(self.test_local_filepath, 'wb') as file: file.write(test_local_file_content) - # Try to put test file to remote - # This should raise an error with "No such file" as the directory - # does not exist + # Try to put test file to remote. This should raise an error with + # "No such file" as the directory does not exist. + ti2 = create_task_instance_of_operator( + SFTPOperator, + dag_id="unit_tests_sftp_op_file_transfer_no_intermediate_dir_error_put", + execution_date=timezone.utcnow(), + task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath_int_dir, + operation=SFTPOperation.PUT, + create_intermediate_dirs=False, + ) with pytest.raises(Exception) as ctx: - put_test_task = SFTPOperator( - task_id="test_sftp", - ssh_hook=self.hook, - local_filepath=self.test_local_filepath, - remote_filepath=self.test_remote_filepath_int_dir, - operation=SFTPOperation.PUT, - create_intermediate_dirs=False, - dag=self.dag, - ) - assert put_test_task is not None - ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow()) ti2.run() assert 'No such file' in str(ctx.value) @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) - def test_file_transfer_with_intermediate_dir_put(self): + def test_file_transfer_with_intermediate_dir_put(self, dag_maker): test_local_file_content = ( b"This is local file content \n which is multiline " b"continuing....with other character\nanother line here \n this is last line" @@ -139,37 +136,32 @@ def test_file_transfer_with_intermediate_dir_put(self): with open(self.test_local_filepath, 'wb') as file: file.write(test_local_file_content) - # put test file to remote - put_test_task = SFTPOperator( - task_id="test_sftp", - ssh_hook=self.hook, - local_filepath=self.test_local_filepath, - remote_filepath=self.test_remote_filepath_int_dir, - operation=SFTPOperation.PUT, - create_intermediate_dirs=True, - dag=self.dag, - ) - assert put_test_task is not None - ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow()) - ti2.run() + with dag_maker(dag_id="unit_tests_sftp_op_file_transfer_with_intermediate_dir_put"): + SFTPOperator( # Put test file to remote. 
+ task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath_int_dir, + operation=SFTPOperation.PUT, + create_intermediate_dirs=True, + ) + SSHOperator( # Check the remote file content. + task_id="test_check_file", + ssh_hook=self.hook, + command=f"cat {self.test_remote_filepath_int_dir}", + do_xcom_push=True, + ) - # check the remote file content - check_file_task = SSHOperator( - task_id="test_check_file", - ssh_hook=self.hook, - command=f"cat {self.test_remote_filepath_int_dir}", - do_xcom_push=True, - dag=self.dag, - ) - assert check_file_task is not None - ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow()) - ti3.run() - assert ( - ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip() == test_local_file_content - ) + dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow()) + tis = {ti.task_id: ti for ti in dagrun.task_instances} + tis["test_sftp"].run() + tis["test_check_file"].run() + + pulled = tis["test_check_file"].xcom_pull(task_ids='test_check_file', key='return_value') + assert pulled.strip() == test_local_file_content @conf_vars({('core', 'enable_xcom_pickling'): 'False'}) - def test_json_file_transfer_put(self): + def test_json_file_transfer_put(self, dag_maker): test_local_file_content = ( b"This is local file content \n which is multiline " b"continuing....with other character\nanother line here \n this is last line" @@ -178,180 +170,148 @@ def test_json_file_transfer_put(self): with open(self.test_local_filepath, 'wb') as file: file.write(test_local_file_content) - # put test file to remote - put_test_task = SFTPOperator( - task_id="put_test_task", - ssh_hook=self.hook, - local_filepath=self.test_local_filepath, - remote_filepath=self.test_remote_filepath, - operation=SFTPOperation.PUT, - dag=self.dag, - ) - assert put_test_task is not None - ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow()) - ti2.run() + with dag_maker(dag_id="unit_tests_sftp_op_json_file_transfer_put"): + SFTPOperator( # Put test file to remote. + task_id="put_test_task", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + ) + SSHOperator( # Check the remote file content. 
+ task_id="check_file_task", + ssh_hook=self.hook, + command=f"cat {self.test_remote_filepath}", + do_xcom_push=True, + ) - # check the remote file content - check_file_task = SSHOperator( - task_id="check_file_task", - ssh_hook=self.hook, - command=f"cat {self.test_remote_filepath}", - do_xcom_push=True, - dag=self.dag, - ) - assert check_file_task is not None - ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow()) - ti3.run() - assert ti3.xcom_pull(task_ids=check_file_task.task_id, key='return_value').strip() == b64encode( - test_local_file_content - ).decode('utf-8') + dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow()) + tis = {ti.task_id: ti for ti in dagrun.task_instances} + tis["put_test_task"].run() + tis["check_file_task"].run() + + pulled = tis["check_file_task"].xcom_pull(task_ids="check_file_task", key='return_value') + assert pulled.strip() == b64encode(test_local_file_content).decode('utf-8') @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) - def test_pickle_file_transfer_get(self): + def test_pickle_file_transfer_get(self, dag_maker): test_remote_file_content = ( "This is remote file content \n which is also multiline " "another line here \n this is last line. EOF" ) - # create a test file remotely - create_file_task = SSHOperator( - task_id="test_create_file", - ssh_hook=self.hook, - command=f"echo '{test_remote_file_content}' > {self.test_remote_filepath}", - do_xcom_push=True, - dag=self.dag, - ) - assert create_file_task is not None - ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) - ti1.run() + with dag_maker(dag_id="unit_tests_sftp_op_pickle_file_transfer_get"): + SSHOperator( # Create a test file on remote. + task_id="test_create_file", + ssh_hook=self.hook, + command=f"echo '{test_remote_file_content}' > {self.test_remote_filepath}", + do_xcom_push=True, + ) + SFTPOperator( # Get remote file to local. + task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.GET, + ) - # get remote file to local - get_test_task = SFTPOperator( - task_id="test_sftp", - ssh_hook=self.hook, - local_filepath=self.test_local_filepath, - remote_filepath=self.test_remote_filepath, - operation=SFTPOperation.GET, - dag=self.dag, - ) - assert get_test_task is not None - ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow()) - ti2.run() + for ti in dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances: + ti.run() - # test the received content - content_received = None + # Test the received content. with open(self.test_local_filepath) as file: content_received = file.read() assert content_received.strip() == test_remote_file_content @conf_vars({('core', 'enable_xcom_pickling'): 'False'}) - def test_json_file_transfer_get(self): + def test_json_file_transfer_get(self, dag_maker): test_remote_file_content = ( "This is remote file content \n which is also multiline " "another line here \n this is last line. EOF" ) - # create a test file remotely - create_file_task = SSHOperator( - task_id="test_create_file", - ssh_hook=self.hook, - command=f"echo '{test_remote_file_content}' > {self.test_remote_filepath}", - do_xcom_push=True, - dag=self.dag, - ) - assert create_file_task is not None - ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) - ti1.run() + with dag_maker(dag_id="unit_tests_sftp_op_json_file_transfer_get"): + SSHOperator( # Create a test file on remote. 
+ task_id="test_create_file", + ssh_hook=self.hook, + command=f"echo '{test_remote_file_content}' > {self.test_remote_filepath}", + do_xcom_push=True, + ) + SFTPOperator( # Get remote file to local. + task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.GET, + ) - # get remote file to local - get_test_task = SFTPOperator( - task_id="test_sftp", - ssh_hook=self.hook, - local_filepath=self.test_local_filepath, - remote_filepath=self.test_remote_filepath, - operation=SFTPOperation.GET, - dag=self.dag, - ) - assert get_test_task is not None - ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow()) - ti2.run() + for ti in dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances: + ti.run() - # test the received content + # Test the received content. content_received = None with open(self.test_local_filepath) as file: content_received = file.read() assert content_received.strip() == test_remote_file_content.encode('utf-8').decode('utf-8') @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) - def test_file_transfer_no_intermediate_dir_error_get(self): + def test_file_transfer_no_intermediate_dir_error_get(self, dag_maker): test_remote_file_content = ( "This is remote file content \n which is also multiline " "another line here \n this is last line. EOF" ) - # create a test file remotely - create_file_task = SSHOperator( - task_id="test_create_file", - ssh_hook=self.hook, - command=f"echo '{test_remote_file_content}' > {self.test_remote_filepath}", - do_xcom_push=True, - dag=self.dag, - ) - assert create_file_task is not None - ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) - ti1.run() - - # Try to GET test file from remote - # This should raise an error with "No such file" as the directory - # does not exist - with pytest.raises(Exception) as ctx: - get_test_task = SFTPOperator( + with dag_maker(dag_id="unit_tests_sftp_op_file_transfer_no_intermediate_dir_error_get"): + SSHOperator( # Create a test file on remote. + task_id="test_create_file", + ssh_hook=self.hook, + command=f"echo '{test_remote_file_content}' > {self.test_remote_filepath}", + do_xcom_push=True, + ) + SFTPOperator( # Try to GET test file from remote. task_id="test_sftp", ssh_hook=self.hook, local_filepath=self.test_local_filepath_int_dir, remote_filepath=self.test_remote_filepath, operation=SFTPOperation.GET, - dag=self.dag, ) - assert get_test_task is not None - ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow()) + + ti1, ti2 = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances + ti1.run() + + # This should raise an error with "No such file" as the directory + # does not exist. + with pytest.raises(Exception) as ctx: ti2.run() assert 'No such file' in str(ctx.value) @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) - def test_file_transfer_with_intermediate_dir_error_get(self): + def test_file_transfer_with_intermediate_dir_error_get(self, dag_maker): test_remote_file_content = ( "This is remote file content \n which is also multiline " "another line here \n this is last line. 
EOF" ) - # create a test file remotely - create_file_task = SSHOperator( - task_id="test_create_file", - ssh_hook=self.hook, - command=f"echo '{test_remote_file_content}' > {self.test_remote_filepath}", - do_xcom_push=True, - dag=self.dag, - ) - assert create_file_task is not None - ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) - ti1.run() + with dag_maker(dag_id="unit_tests_sftp_op_file_transfer_with_intermediate_dir_error_get"): + SSHOperator( # Create a test file on remote. + task_id="test_create_file", + ssh_hook=self.hook, + command=f"echo '{test_remote_file_content}' > {self.test_remote_filepath}", + do_xcom_push=True, + ) + SFTPOperator( # Get remote file to local. + task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath_int_dir, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.GET, + create_intermediate_dirs=True, + ) - # get remote file to local - get_test_task = SFTPOperator( - task_id="test_sftp", - ssh_hook=self.hook, - local_filepath=self.test_local_filepath_int_dir, - remote_filepath=self.test_remote_filepath, - operation=SFTPOperation.GET, - create_intermediate_dirs=True, - dag=self.dag, - ) - assert get_test_task is not None - ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow()) - ti2.run() + for ti in dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances: + ti.run() - # test the received content + # Test the received content. content_received = None with open(self.test_local_filepath_int_dir) as file: content_received = file.read() @@ -359,6 +319,7 @@ def test_file_transfer_with_intermediate_dir_error_get(self): @mock.patch.dict('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"}) def test_arg_checking(self): + dag = DAG(dag_id="unit_tests_sftp_op_arg_checking", default_args={"start_date": DEFAULT_DATE}) # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided with pytest.raises(AirflowException, match="Cannot operate without ssh_hook or ssh_conn_id."): task_0 = SFTPOperator( @@ -366,7 +327,7 @@ def test_arg_checking(self): local_filepath=self.test_local_filepath, remote_filepath=self.test_remote_filepath, operation=SFTPOperation.PUT, - dag=self.dag, + dag=dag, ) task_0.execute(None) @@ -378,7 +339,7 @@ def test_arg_checking(self): local_filepath=self.test_local_filepath, remote_filepath=self.test_remote_filepath, operation=SFTPOperation.PUT, - dag=self.dag, + dag=dag, ) try: task_1.execute(None) @@ -392,7 +353,7 @@ def test_arg_checking(self): local_filepath=self.test_local_filepath, remote_filepath=self.test_remote_filepath, operation=SFTPOperation.PUT, - dag=self.dag, + dag=dag, ) try: task_2.execute(None) @@ -408,40 +369,10 @@ def test_arg_checking(self): local_filepath=self.test_local_filepath, remote_filepath=self.test_remote_filepath, operation=SFTPOperation.PUT, - dag=self.dag, + dag=dag, ) try: task_3.execute(None) except Exception: pass assert task_3.ssh_hook.ssh_conn_id == self.hook.ssh_conn_id - - def delete_local_resource(self): - if os.path.exists(self.test_local_filepath): - os.remove(self.test_local_filepath) - if os.path.exists(self.test_local_filepath_int_dir): - os.remove(self.test_local_filepath_int_dir) - if os.path.exists(self.test_local_dir): - os.rmdir(self.test_local_dir) - - def delete_remote_resource(self): - if os.path.exists(self.test_remote_filepath): - # check the remote file content - remove_file_task = SSHOperator( - task_id="test_check_file", - ssh_hook=self.hook, - 
command=f"rm {self.test_remote_filepath}", - do_xcom_push=True, - dag=self.dag, - ) - assert remove_file_task is not None - ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow()) - ti3.run() - if os.path.exists(self.test_remote_filepath_int_dir): - os.remove(self.test_remote_filepath_int_dir) - if os.path.exists(self.test_remote_dir): - os.rmdir(self.test_remote_dir) - - def tearDown(self): - self.delete_local_resource() - self.delete_remote_resource() diff --git a/tests/providers/ssh/operators/test_ssh.py b/tests/providers/ssh/operators/test_ssh.py index 7083354301b70..05bf7103852e7 100644 --- a/tests/providers/ssh/operators/test_ssh.py +++ b/tests/providers/ssh/operators/test_ssh.py @@ -20,12 +20,10 @@ from base64 import b64encode import pytest -from parameterized import parameterized from airflow.exceptions import AirflowException -from airflow.models import DAG, TaskInstance +from airflow.models import DAG from airflow.providers.ssh.operators.ssh import SSHOperator -from airflow.utils import timezone from airflow.utils.timezone import datetime from tests.test_utils.config import conf_vars @@ -37,118 +35,97 @@ COMMAND_WITH_SUDO = "sudo " + COMMAND -class TestSSHOperator(unittest.TestCase): - def setUp(self): +class TestSSHOperator: + def setup_method(self): from airflow.providers.ssh.hooks.ssh import SSHHook hook = SSHHook(ssh_conn_id='ssh_default') hook.no_host_key_check = True - - dag = DAG( - TEST_DAG_ID + 'test_schedule_dag_once', - schedule_interval="@once", - start_date=DEFAULT_DATE, - ) self.hook = hook - self.dag = dag def test_hook_created_correctly(self): timeout = 20 ssh_id = "ssh_default" - task = SSHOperator( - task_id="test", command=COMMAND, dag=self.dag, timeout=timeout, ssh_conn_id="ssh_default" - ) - assert task is not None - + with DAG('unit_tests_ssh_test_op_arg_checking', default_args={'start_date': DEFAULT_DATE}): + task = SSHOperator(task_id="test", command=COMMAND, timeout=timeout, ssh_conn_id="ssh_default") task.execute(None) - assert timeout == task.ssh_hook.timeout assert ssh_id == task.ssh_hook.ssh_conn_id @conf_vars({('core', 'enable_xcom_pickling'): 'False'}) - def test_json_command_execution(self): - task = SSHOperator( + def test_json_command_execution(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + SSHOperator, + dag_id="unit_tests_ssh_test_op_json_command_execution", task_id="test", ssh_hook=self.hook, command=COMMAND, do_xcom_push=True, - dag=self.dag, ) - - assert task is not None - - ti = TaskInstance(task=task, execution_date=timezone.utcnow()) ti.run() assert ti.duration is not None assert ti.xcom_pull(task_ids='test', key='return_value') == b64encode(b'airflow').decode('utf-8') @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) - def test_pickle_command_execution(self): - task = SSHOperator( + def test_pickle_command_execution(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + SSHOperator, + dag_id="unit_tests_ssh_test_op_pickle_command_execution", task_id="test", ssh_hook=self.hook, command=COMMAND, do_xcom_push=True, - dag=self.dag, ) - - assert task is not None - - ti = TaskInstance(task=task, execution_date=timezone.utcnow()) ti.run() assert ti.duration is not None assert ti.xcom_pull(task_ids='test', key='return_value') == b'airflow' - def test_command_execution_with_env(self): - task = SSHOperator( + @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) + def test_command_execution_with_env(self, create_task_instance_of_operator): + ti = 
create_task_instance_of_operator( + SSHOperator, + dag_id="unit_tests_ssh_test_op_command_execution_with_env", task_id="test", ssh_hook=self.hook, command=COMMAND, do_xcom_push=True, - dag=self.dag, environment={'TEST': 'value'}, ) + ti.run() + assert ti.duration is not None + assert ti.xcom_pull(task_ids='test', key='return_value') == b'airflow' - assert task is not None - - with conf_vars({('core', 'enable_xcom_pickling'): 'True'}): - ti = TaskInstance(task=task, execution_date=timezone.utcnow()) - ti.run() - assert ti.duration is not None - assert ti.xcom_pull(task_ids='test', key='return_value') == b'airflow' - - def test_no_output_command(self): - task = SSHOperator( + @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) + def test_no_output_command(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + SSHOperator, + dag_id="unit_tests_ssh_test_op_no_output_command", task_id="test", ssh_hook=self.hook, command="sleep 1", do_xcom_push=True, - dag=self.dag, ) - - assert task is not None - - with conf_vars({('core', 'enable_xcom_pickling'): 'True'}): - ti = TaskInstance(task=task, execution_date=timezone.utcnow()) - ti.run() - assert ti.duration is not None - assert ti.xcom_pull(task_ids='test', key='return_value') == b'' + ti.run() + assert ti.duration is not None + assert ti.xcom_pull(task_ids='test', key='return_value') == b'' @unittest.mock.patch('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"}) def test_arg_checking(self): - # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided + dag = DAG('unit_tests_ssh_test_op_arg_checking', default_args={'start_date': DEFAULT_DATE}) + + # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided. + task_0 = SSHOperator(task_id="test", command=COMMAND, timeout=TIMEOUT, dag=dag) with pytest.raises(AirflowException, match="Cannot operate without ssh_hook or ssh_conn_id."): - task_0 = SSHOperator(task_id="test", command=COMMAND, timeout=TIMEOUT, dag=self.dag) task_0.execute(None) - # if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook + # If ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook. task_1 = SSHOperator( task_id="test_1", - ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook + ssh_hook="string_rather_than_SSHHook", # Invalid ssh_hook. ssh_conn_id=TEST_CONN_ID, command=COMMAND, timeout=TIMEOUT, - dag=self.dag, + dag=dag, ) try: task_1.execute(None) @@ -158,10 +135,10 @@ def test_arg_checking(self): task_2 = SSHOperator( task_id="test_2", - ssh_conn_id=TEST_CONN_ID, # no ssh_hook provided + ssh_conn_id=TEST_CONN_ID, # No ssh_hook provided. command=COMMAND, timeout=TIMEOUT, - dag=self.dag, + dag=dag, ) try: task_2.execute(None) @@ -169,41 +146,42 @@ def test_arg_checking(self): pass assert task_2.ssh_hook.ssh_conn_id == TEST_CONN_ID - # if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id + # If both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id. 
task_3 = SSHOperator( task_id="test_3", ssh_hook=self.hook, ssh_conn_id=TEST_CONN_ID, command=COMMAND, timeout=TIMEOUT, - dag=self.dag, + dag=dag, ) - try: - task_3.execute(None) - except Exception: - pass + task_3.execute(None) assert task_3.ssh_hook.ssh_conn_id == self.hook.ssh_conn_id - @parameterized.expand( + @pytest.mark.parametrize( + "command, get_pty_in, get_pty_out", [ (COMMAND, False, False), (COMMAND, True, True), (COMMAND_WITH_SUDO, False, True), (COMMAND_WITH_SUDO, True, True), (None, True, True), - ] + ], ) def test_get_pyt_set_correctly(self, command, get_pty_in, get_pty_out): + dag = DAG('unit_tests_ssh_test_op_arg_checking', default_args={'start_date': DEFAULT_DATE}) task = SSHOperator( task_id="test", ssh_hook=self.hook, command=command, timeout=TIMEOUT, get_pty=get_pty_in, - dag=self.dag, + dag=dag, ) - try: + if command is None: + with pytest.raises(AirflowException) as ctx: + task.execute(None) + assert str(ctx.value) == "SSH operator error: SSH command not specified. Aborting." + else: task.execute(None) - except Exception: - pass assert task.get_pty == get_pty_out diff --git a/tests/sensors/test_base.py b/tests/sensors/test_base.py index dd3bf29a0ad49..b519f8df13e53 100644 --- a/tests/sensors/test_base.py +++ b/tests/sensors/test_base.py @@ -17,7 +17,6 @@ # under the License. -import unittest from datetime import timedelta from unittest.mock import Mock, patch @@ -25,15 +24,13 @@ from freezegun import freeze_time from airflow.exceptions import AirflowException, AirflowRescheduleException, AirflowSensorTimeout -from airflow.models import DagBag, TaskInstance, TaskReschedule -from airflow.models.dag import DAG +from airflow.models import TaskReschedule from airflow.operators.dummy import DummyOperator from airflow.sensors.base import BaseSensorOperator, poke_mode_only from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep from airflow.utils import timezone from airflow.utils.state import State from airflow.utils.timezone import datetime -from airflow.utils.types import DagRunType from tests.test_utils import db DEFAULT_DATE = datetime(2015, 1, 1) @@ -52,51 +49,50 @@ def poke(self, context): return self.return_value -class TestBaseSensor(unittest.TestCase): +class TestBaseSensor: @staticmethod def clean_db(): db.clear_db_runs() db.clear_db_task_reschedule() db.clear_db_xcom() - def setUp(self): - args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} - self.dag = DAG(TEST_DAG_ID, default_args=args) + @pytest.fixture(autouse=True) + def _auto_clean(self, dag_maker): + """(auto use)""" self.clean_db() - def tearDown(self) -> None: + yield + self.clean_db() - def _make_dag_run(self): - return self.dag.create_dagrun( - run_type=DagRunType.MANUAL, - start_date=timezone.utcnow(), - execution_date=DEFAULT_DATE, - state=State.RUNNING, - ) + @pytest.fixture + def make_sensor(self, dag_maker): + """Create a DummySensor and associated DagRun""" - def _make_sensor(self, return_value, task_id=SENSOR_OP, **kwargs): - poke_interval = 'poke_interval' - timeout = 'timeout' + def _make_sensor(return_value, task_id=SENSOR_OP, **kwargs): + poke_interval = 'poke_interval' + timeout = 'timeout' - if poke_interval not in kwargs: - kwargs[poke_interval] = 0 - if timeout not in kwargs: - kwargs[timeout] = 0 + if poke_interval not in kwargs: + kwargs[poke_interval] = 0 + if timeout not in kwargs: + kwargs[timeout] = 0 - sensor = DummySensor(task_id=task_id, return_value=return_value, dag=self.dag, **kwargs) + with dag_maker(TEST_DAG_ID): + sensor = 
DummySensor(task_id=task_id, return_value=return_value, **kwargs) - dummy_op = DummyOperator(task_id=DUMMY_OP, dag=self.dag) - dummy_op.set_upstream(sensor) - return sensor + dummy_op = DummyOperator(task_id=DUMMY_OP) + sensor >> dummy_op + return sensor, dag_maker.create_dagrun() + + return _make_sensor @classmethod - def _run(cls, task): - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + def _run(cls, task, **kwargs): + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True, **kwargs) - def test_ok(self): - sensor = self._make_sensor(True) - dr = self._make_dag_run() + def test_ok(self, make_sensor): + sensor, dr = make_sensor(True) self._run(sensor) tis = dr.get_task_instances() @@ -107,9 +103,8 @@ def test_ok(self): if ti.task_id == DUMMY_OP: assert ti.state == State.NONE - def test_fail(self): - sensor = self._make_sensor(False) - dr = self._make_dag_run() + def test_fail(self, make_sensor): + sensor, dr = make_sensor(False) with pytest.raises(AirflowSensorTimeout): self._run(sensor) @@ -121,9 +116,8 @@ def test_fail(self): if ti.task_id == DUMMY_OP: assert ti.state == State.NONE - def test_soft_fail(self): - sensor = self._make_sensor(False, soft_fail=True) - dr = self._make_dag_run() + def test_soft_fail(self, make_sensor): + sensor, dr = make_sensor(False, soft_fail=True) self._run(sensor) tis = dr.get_task_instances() @@ -134,11 +128,10 @@ def test_soft_fail(self): if ti.task_id == DUMMY_OP: assert ti.state == State.NONE - def test_soft_fail_with_retries(self): - sensor = self._make_sensor( + def test_soft_fail_with_retries(self, make_sensor): + sensor, dr = make_sensor( return_value=False, soft_fail=True, retries=1, retry_delay=timedelta(milliseconds=1) ) - dr = self._make_dag_run() # first run times out and task instance is skipped self._run(sensor) @@ -150,10 +143,9 @@ def test_soft_fail_with_retries(self): if ti.task_id == DUMMY_OP: assert ti.state == State.NONE - def test_ok_with_reschedule(self): - sensor = self._make_sensor(return_value=None, poke_interval=10, timeout=25, mode='reschedule') + def test_ok_with_reschedule(self, make_sensor): + sensor, dr = make_sensor(return_value=None, poke_interval=10, timeout=25, mode='reschedule') sensor.poke = Mock(side_effect=[False, False, True]) - dr = self._make_dag_run() # first poke returns False and task is re-scheduled date1 = timezone.utcnow() @@ -209,9 +201,8 @@ def test_ok_with_reschedule(self): if ti.task_id == DUMMY_OP: assert ti.state == State.NONE - def test_fail_with_reschedule(self): - sensor = self._make_sensor(return_value=False, poke_interval=10, timeout=5, mode='reschedule') - dr = self._make_dag_run() + def test_fail_with_reschedule(self, make_sensor): + sensor, dr = make_sensor(return_value=False, poke_interval=10, timeout=5, mode='reschedule') # first poke returns False and task is re-scheduled date1 = timezone.utcnow() @@ -238,11 +229,10 @@ def test_fail_with_reschedule(self): if ti.task_id == DUMMY_OP: assert ti.state == State.NONE - def test_soft_fail_with_reschedule(self): - sensor = self._make_sensor( + def test_soft_fail_with_reschedule(self, make_sensor): + sensor, dr = make_sensor( return_value=False, poke_interval=10, timeout=5, soft_fail=True, mode='reschedule' ) - dr = self._make_dag_run() # first poke returns False and task is re-scheduled date1 = timezone.utcnow() @@ -268,8 +258,8 @@ def test_soft_fail_with_reschedule(self): if ti.task_id == DUMMY_OP: assert ti.state == State.NONE - def test_ok_with_reschedule_and_retry(self): - sensor = 
self._make_sensor( + def test_ok_with_reschedule_and_retry(self, make_sensor): + sensor, dr = make_sensor( return_value=None, poke_interval=10, timeout=5, @@ -278,7 +268,6 @@ def test_ok_with_reschedule_and_retry(self): mode='reschedule', ) sensor.poke = Mock(side_effect=[False, False, False, True]) - dr = self._make_dag_run() # first poke returns False and task is re-scheduled date1 = timezone.utcnow() @@ -345,28 +334,27 @@ def test_ok_with_reschedule_and_retry(self): assert ti.state == State.NONE def test_should_include_ready_to_reschedule_dep_in_reschedule_mode(self): - sensor = self._make_sensor(True, mode='reschedule') + sensor = DummySensor(task_id='a', return_value=True, mode='reschedule') deps = sensor.deps assert ReadyToRescheduleDep() in deps - def test_should_not_include_ready_to_reschedule_dep_in_poke_mode(self): - sensor = self._make_sensor(True) + def test_should_not_include_ready_to_reschedule_dep_in_poke_mode(self, make_sensor): + sensor = DummySensor(task_id='a', return_value=False, mode='poke') deps = sensor.deps assert ReadyToRescheduleDep() not in deps def test_invalid_mode(self): with pytest.raises(AirflowException): - self._make_sensor(return_value=True, mode='foo') + DummySensor(task_id='a', mode='foo') - def test_ok_with_custom_reschedule_exception(self): - sensor = self._make_sensor(return_value=None, mode='reschedule') + def test_ok_with_custom_reschedule_exception(self, make_sensor): + sensor, dr = make_sensor(return_value=None, mode='reschedule') date1 = timezone.utcnow() date2 = date1 + timedelta(seconds=60) date3 = date1 + timedelta(seconds=120) sensor.poke = Mock( side_effect=[AirflowRescheduleException(date2), AirflowRescheduleException(date3), True] ) - dr = self._make_dag_run() # first poke returns False and task is re-scheduled with freeze_time(date1): @@ -413,16 +401,14 @@ def test_ok_with_custom_reschedule_exception(self): if ti.task_id == DUMMY_OP: assert ti.state == State.NONE - def test_reschedule_with_test_mode(self): - sensor = self._make_sensor(return_value=None, poke_interval=10, timeout=25, mode='reschedule') + def test_reschedule_with_test_mode(self, make_sensor): + sensor, dr = make_sensor(return_value=None, poke_interval=10, timeout=25, mode='reschedule') sensor.poke = Mock(side_effect=[False]) - dr = self._make_dag_run() # poke returns False and AirflowRescheduleException is raised date1 = timezone.utcnow() with freeze_time(date1): - for info in self.dag.iter_dagrun_infos_between(DEFAULT_DATE, DEFAULT_DATE): - TaskInstance(sensor, info.logical_date).run(ignore_ti_state=True, test_mode=True) + self._run(sensor, test_mode=True) tis = dr.get_task_instances() assert len(tis) == 2 for ti in tis: @@ -440,7 +426,7 @@ def test_sensor_with_invalid_poke_interval(self): non_number_poke_interval = "abcd" positive_poke_interval = 10 with pytest.raises(AirflowException): - self._make_sensor( + DummySensor( task_id='test_sensor_task_1', return_value=None, poke_interval=negative_poke_interval, @@ -448,14 +434,14 @@ def test_sensor_with_invalid_poke_interval(self): ) with pytest.raises(AirflowException): - self._make_sensor( + DummySensor( task_id='test_sensor_task_2', return_value=None, poke_interval=non_number_poke_interval, timeout=25, ) - self._make_sensor( + DummySensor( task_id='test_sensor_task_3', return_value=None, poke_interval=positive_poke_interval, timeout=25 ) @@ -464,21 +450,23 @@ def test_sensor_with_invalid_timeout(self): non_number_timeout = "abcd" positive_timeout = 25 with pytest.raises(AirflowException): - self._make_sensor( + 
DummySensor( task_id='test_sensor_task_1', return_value=None, poke_interval=10, timeout=negative_timeout ) with pytest.raises(AirflowException): - self._make_sensor( + DummySensor( task_id='test_sensor_task_2', return_value=None, poke_interval=10, timeout=non_number_timeout ) - self._make_sensor( + DummySensor( task_id='test_sensor_task_3', return_value=None, poke_interval=10, timeout=positive_timeout ) def test_sensor_with_exponential_backoff_off(self): - sensor = self._make_sensor(return_value=None, poke_interval=5, timeout=60, exponential_backoff=False) + sensor = DummySensor( + task_id=SENSOR_OP, return_value=None, poke_interval=5, timeout=60, exponential_backoff=False + ) started_at = timezone.utcnow() - timedelta(seconds=10) @@ -490,7 +478,9 @@ def run_duration(): def test_sensor_with_exponential_backoff_on(self): - sensor = self._make_sensor(return_value=None, poke_interval=5, timeout=60, exponential_backoff=True) + sensor = DummySensor( + task_id=SENSOR_OP, return_value=None, poke_interval=5, timeout=60, exponential_backoff=True + ) with patch('airflow.utils.timezone.utcnow') as mock_utctime: mock_utctime.return_value = DEFAULT_DATE @@ -508,7 +498,7 @@ def run_duration(): assert interval2 >= sensor.poke_interval assert interval2 > interval1 - def test_reschedule_and_retry_timeout(self): + def test_reschedule_and_retry_timeout(self, make_sensor): """ Test mode="reschedule", retries and timeout configurations interact correctly. @@ -534,7 +524,7 @@ def test_reschedule_and_retry_timeout(self): 00:26 Returns False try_number=3, max_tries=4, state=up_for_reschedule 00:31 Raises AirflowSensorTimeout, try_number=4, max_tries=4, state=failed """ - sensor = self._make_sensor( + sensor, dr = make_sensor( return_value=None, poke_interval=5, timeout=10, @@ -544,18 +534,17 @@ def test_reschedule_and_retry_timeout(self): ) sensor.poke = Mock(side_effect=[False, RuntimeError, False, False, False, False, False, False]) - dr = self._make_dag_run() def assert_ti_state(try_number, max_tries, state): tis = dr.get_task_instances() - self.assertEqual(len(tis), 2) + assert len(tis) == 2 for ti in tis: if ti.task_id == SENSOR_OP: - self.assertEqual(ti.try_number, try_number) - self.assertEqual(ti.max_tries, max_tries) - self.assertEqual(ti.state, state) + assert ti.try_number == try_number + assert ti.max_tries == max_tries + assert ti.state == state break else: self.fail("sensor not found") @@ -568,9 +557,8 @@ def assert_ti_state(try_number, max_tries, state): # second poke raises RuntimeError and task instance retries date2 = date1 + timedelta(seconds=sensor.poke_interval) - with freeze_time(date2): - with self.assertRaises(RuntimeError): - self._run(sensor) + with freeze_time(date2), pytest.raises(RuntimeError): + self._run(sensor) assert_ti_state(2, 2, State.UP_FOR_RETRY) # third poke returns False and task is rescheduled again @@ -581,9 +569,8 @@ def assert_ti_state(try_number, max_tries, state): # fourth poke times out and raises AirflowSensorTimeout date4 = date3 + timedelta(seconds=sensor.poke_interval) - with freeze_time(date4): - with self.assertRaises(AirflowSensorTimeout): - self._run(sensor) + with freeze_time(date4), pytest.raises(AirflowSensorTimeout): + self._run(sensor) assert_ti_state(3, 2, State.FAILED) # Clear the failed sensor @@ -599,9 +586,8 @@ def assert_ti_state(try_number, max_tries, state): # Last poke times out and raises AirflowSensorTimeout date8 = date_i + timedelta(seconds=sensor.poke_interval) - with freeze_time(date8): - with self.assertRaises(AirflowSensorTimeout): - 
self._run(sensor) + with freeze_time(date8), pytest.raises(AirflowSensorTimeout): + self._run(sensor) assert_ti_state(4, 4, State.FAILED) @@ -622,15 +608,10 @@ def change_mode(self, mode): self.mode = mode -class TestPokeModeOnly(unittest.TestCase): - def setUp(self): - self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True) - self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} - self.dag = DAG(TEST_DAG_ID, default_args=self.args) - +class TestPokeModeOnly: def test_poke_mode_only_allows_poke_mode(self): try: - sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=False, dag=self.dag) + sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=False) except ValueError: self.fail("__init__ failed with mode='poke'.") try: @@ -643,15 +624,15 @@ def test_poke_mode_only_allows_poke_mode(self): self.fail("class method failed without changing mode from 'poke'.") def test_poke_mode_only_bad_class_method(self): - sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=False, dag=self.dag) + sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=False) with pytest.raises(ValueError): sensor.change_mode('reschedule') def test_poke_mode_only_bad_init(self): with pytest.raises(ValueError): - DummyPokeOnlySensor(task_id='foo', mode='reschedule', poke_changes_mode=False, dag=self.dag) + DummyPokeOnlySensor(task_id='foo', mode='reschedule', poke_changes_mode=False) def test_poke_mode_only_bad_poke(self): - sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=True, dag=self.dag) + sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=True) with pytest.raises(ValueError): sensor.poke({}) diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index 5b50fc22bd4ca..7387e1ce8e3d1 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -23,15 +23,17 @@ from airflow import exceptions, settings from airflow.exceptions import AirflowException, AirflowSensorTimeout -from airflow.models import DagBag, TaskInstance +from airflow.models import DagBag, DagRun, TaskInstance from airflow.models.dag import DAG from airflow.operators.bash import BashOperator from airflow.operators.dummy import DummyOperator from airflow.sensors.external_task import ExternalTaskMarker, ExternalTaskSensor from airflow.sensors.time_sensor import TimeSensor from airflow.serialization.serialized_objects import SerializedBaseOperator -from airflow.utils.state import State +from airflow.utils.session import provide_session +from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.timezone import datetime +from airflow.utils.types import DagRunType from tests.test_utils.db import clear_db_runs DEFAULT_DATE = datetime(2015, 1, 1) @@ -170,18 +172,6 @@ def test_external_dag_sensor(self): ) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - def test_templated_sensor(self): - with self.dag: - sensor = ExternalTaskSensor( - task_id='templated_task', external_dag_id='dag_{{ ds }}', external_task_id='task_{{ ds }}' - ) - - instance = TaskInstance(sensor, DEFAULT_DATE) - instance.render_templates() - - assert sensor.external_dag_id == f"dag_{DEFAULT_DATE.date()}" - assert sensor.external_task_id == f"task_{DEFAULT_DATE.date()}" - def test_external_task_sensor_fn_multiple_execution_dates(self): bash_command_code = """ {% set s=execution_date.time().second %} @@ -417,6 
+407,21 @@ def test_external_task_sensor_waits_for_dag_check_existence(self): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) +def test_external_task_sensor_templated(dag_maker): + with dag_maker(): + ExternalTaskSensor( + task_id='templated_task', + external_dag_id='dag_{{ ds }}', + external_task_id='task_{{ ds }}', + ) + + (instance,) = dag_maker.create_dagrun(execution_date=DEFAULT_DATE).task_instances + instance.render_templates() + + assert instance.task.external_dag_id == f"dag_{DEFAULT_DATE.date()}" + assert instance.task.external_task_id == f"task_{DEFAULT_DATE.date()}" + + class TestExternalTaskMarker(unittest.TestCase): def test_serialized_fields(self): assert {"recursion_depth"}.issubset(ExternalTaskMarker.get_serialized_fields()) @@ -454,6 +459,8 @@ def dag_bag_ext(): | dag_3: ---> task_a_3 >> task_b_3 """ + clear_db_runs() + dag_bag = DagBag(dag_folder=DEV_NULL, include_examples=False) dag_0 = DAG("dag_0", start_date=DEFAULT_DATE, schedule_interval=None) @@ -491,7 +498,9 @@ def dag_bag_ext(): for dag in [dag_0, dag_1, dag_2, dag_3]: dag_bag.bag_dag(dag=dag, root_dag=dag) - return dag_bag + yield dag_bag + + clear_db_runs() @pytest.fixture @@ -511,37 +520,40 @@ def dag_bag_parent_child(): child_dag_1 task_1 task_1 """ + clear_db_runs() + dag_bag = DagBag(dag_folder=DEV_NULL, include_examples=False) day_1 = DEFAULT_DATE - dag_0 = DAG("parent_dag_0", start_date=day_1, schedule_interval=None) - task_0 = ExternalTaskMarker( - task_id="task_0", - external_dag_id="child_dag_1", - external_task_id="task_1", - execution_date=day_1.isoformat(), - recursion_depth=3, - dag=dag_0, - ) + with DAG("parent_dag_0", start_date=day_1, schedule_interval=None) as dag_0: + task_0 = ExternalTaskMarker( + task_id="task_0", + external_dag_id="child_dag_1", + external_task_id="task_1", + execution_date=day_1.isoformat(), + recursion_depth=3, + ) - dag_1 = DAG("child_dag_1", start_date=day_1, schedule_interval=None) - _ = ExternalTaskSensor( - task_id="task_1", - external_dag_id=dag_0.dag_id, - external_task_id=task_0.task_id, - execution_date_fn=lambda execution_date: day_1 if execution_date == day_1 else [], - mode='reschedule', - dag=dag_1, - ) + with DAG("child_dag_1", start_date=day_1, schedule_interval=None) as dag_1: + ExternalTaskSensor( + task_id="task_1", + external_dag_id=dag_0.dag_id, + external_task_id=task_0.task_id, + execution_date_fn=lambda execution_date: day_1 if execution_date == day_1 else [], + mode='reschedule', + ) for dag in [dag_0, dag_1]: dag_bag.bag_dag(dag=dag, root_dag=dag) - return dag_bag + yield dag_bag + + clear_db_runs() -def run_tasks(dag_bag, execution_date=DEFAULT_DATE): +@provide_session +def run_tasks(dag_bag, execution_date=DEFAULT_DATE, session=None): """ Run all tasks in the DAGs in the given dag_bag. Return the TaskInstance objects as a dict keyed by task_id. 
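# A minimal sketch, not part of this patch, of the dag_maker pattern the hunks above converge on:
# task instances come from an explicitly created DagRun instead of being built with an
# execution_date. The dag id and task id below are placeholders; `dag_maker` is the pytest
# fixture these tests already use.
from airflow.operators.dummy import DummyOperator
from airflow.utils.state import State


def test_dag_run_backed_task_instance(dag_maker):
    with dag_maker("example_dag"):
        op = DummyOperator(task_id="example_task")

    dagrun = dag_maker.create_dagrun()  # the DagRun every TaskInstance now requires
    (ti,) = dagrun.task_instances  # TIs are created together with the DagRun
    ti.refresh_from_task(op)  # re-attach the concrete task object before running
    ti.run()
    assert ti.state == State.SUCCESS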
@@ -549,10 +561,17 @@ def run_tasks(dag_bag, execution_date=DEFAULT_DATE): tis = {} for dag in dag_bag.dags.values(): - for task in dag.tasks: - ti = TaskInstance(task=task, execution_date=execution_date) - tis[task.task_id] = ti - ti.run() + dagrun = dag.create_dagrun( + state=State.RUNNING, + execution_date=execution_date, + start_date=execution_date, + run_type=DagRunType.MANUAL, + session=session, + ) + for ti in dagrun.task_instances: + ti.refresh_from_task(dag.get_task(ti.task_id)) + tis[ti.task_id] = ti + ti.run(session=session) assert_ti_state_equal(ti, State.SUCCESS) return tis @@ -566,12 +585,27 @@ def assert_ti_state_equal(task_instance, state): assert task_instance.state == state -def clear_tasks(dag_bag, dag, task, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, dry_run=False): +@provide_session +def clear_tasks( + dag_bag, + dag, + task, + session, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE, + dry_run=False, +): """ Clear the task and its downstream tasks recursively for the dag in the given dagbag. """ partial: DAG = dag.partial_subset(task_ids_or_regex=[task.task_id], include_downstream=True) - return partial.clear(start_date=start_date, end_date=end_date, dag_bag=dag_bag, dry_run=dry_run) + return partial.clear( + start_date=start_date, + end_date=end_date, + dag_bag=dag_bag, + dry_run=dry_run, + session=session, + ) def test_external_task_marker_transitive(dag_bag_ext): @@ -588,13 +622,11 @@ def test_external_task_marker_transitive(dag_bag_ext): assert_ti_state_equal(ti_b_3, State.NONE) -def test_external_task_marker_clear_activate(dag_bag_parent_child): +@provide_session +def test_external_task_marker_clear_activate(dag_bag_parent_child, session): """ Test clearing tasks across DAGs and make sure the right DagRuns are activated. """ - from airflow.utils.session import create_session - from airflow.utils.types import DagRunType - dag_bag = dag_bag_parent_child day_1 = DEFAULT_DATE day_2 = DEFAULT_DATE + timedelta(days=1) @@ -602,33 +634,23 @@ def test_external_task_marker_clear_activate(dag_bag_parent_child): run_tasks(dag_bag, execution_date=day_1) run_tasks(dag_bag, execution_date=day_2) - with create_session() as session: - for dag in dag_bag.dags.values(): - for execution_date in [day_1, day_2]: - dagrun = dag.create_dagrun( - State.RUNNING, execution_date, run_type=DagRunType.MANUAL, session=session - ) - dagrun.set_state(State.SUCCESS) - session.add(dagrun) - - session.commit() - # Assert that dagruns of all the affected dags are set to SUCCESS before tasks are cleared. for dag in dag_bag.dags.values(): for execution_date in [day_1, day_2]: - dagrun = dag.get_dagrun(execution_date=execution_date) - assert dagrun.state == State.SUCCESS + dagrun = dag.get_dagrun(execution_date=execution_date, session=session) + dagrun.set_state(State.SUCCESS) + session.flush() dag_0 = dag_bag.get_dag("parent_dag_0") task_0 = dag_0.get_task("task_0") - clear_tasks(dag_bag, dag_0, task_0, start_date=day_1, end_date=day_2) + clear_tasks(dag_bag, dag_0, task_0, start_date=day_1, end_date=day_2, session=session) # Assert that dagruns of all the affected dags are set to QUEUED after tasks are cleared. # Unaffected dagruns should be left as SUCCESS. 
- dagrun_0_1 = dag_bag.get_dag('parent_dag_0').get_dagrun(execution_date=day_1) - dagrun_0_2 = dag_bag.get_dag('parent_dag_0').get_dagrun(execution_date=day_2) - dagrun_1_1 = dag_bag.get_dag('child_dag_1').get_dagrun(execution_date=day_1) - dagrun_1_2 = dag_bag.get_dag('child_dag_1').get_dagrun(execution_date=day_2) + dagrun_0_1 = dag_bag.get_dag('parent_dag_0').get_dagrun(execution_date=day_1, session=session) + dagrun_0_2 = dag_bag.get_dag('parent_dag_0').get_dagrun(execution_date=day_2, session=session) + dagrun_1_1 = dag_bag.get_dag('child_dag_1').get_dagrun(execution_date=day_1, session=session) + dagrun_1_2 = dag_bag.get_dag('child_dag_1').get_dagrun(execution_date=day_2, session=session) assert dagrun_0_1.state == State.QUEUED assert dagrun_0_2.state == State.QUEUED @@ -845,6 +867,7 @@ def dag_bag_head_tail(): +------+ +------+ +------+ """ dag_bag = DagBag(dag_folder=DEV_NULL, include_examples=False) + with DAG("head_tail", start_date=DEFAULT_DATE, schedule_interval="@daily") as dag: head = ExternalTaskSensor( task_id='head', @@ -864,22 +887,40 @@ def dag_bag_head_tail(): dag_bag.bag_dag(dag=dag, root_dag=dag) - yield dag_bag - + return dag_bag -def test_clear_overlapping_external_task_marker(dag_bag_head_tail): - dag = dag_bag_head_tail.get_dag("head_tail") - # Mark first head task success. - first = TaskInstance(task=dag.get_task("head"), execution_date=DEFAULT_DATE) - first.run(mark_success=True) +@provide_session +def test_clear_overlapping_external_task_marker(dag_bag_head_tail, session): + dag: DAG = dag_bag_head_tail.get_dag('head_tail') - for delta in range(10): + # "Run" 10 times. + for delta in range(0, 10): execution_date = DEFAULT_DATE + timedelta(days=delta) - run_tasks(dag_bag_head_tail, execution_date=execution_date) + dagrun = DagRun( + dag_id=dag.dag_id, + state=DagRunState.SUCCESS, + execution_date=execution_date, + run_type=DagRunType.MANUAL, + run_id=f"test_{delta}", + ) + session.add(dagrun) + for task in dag.tasks: + ti = TaskInstance(task=task) + dagrun.task_instances.append(ti) + ti.state = TaskInstanceState.SUCCESS + session.flush() # The next two lines are doing the same thing. Clearing the first "head" with "Future" # selected is the same as not selecting "Future". They should take similar amount of # time too because dag.clear() uses visited_external_tis to keep track of visited ExternalTaskMarker. 
- assert dag.clear(start_date=DEFAULT_DATE, dag_bag=dag_bag_head_tail) == 30 - assert dag.clear(start_date=DEFAULT_DATE, end_date=execution_date, dag_bag=dag_bag_head_tail) == 30 + assert dag.clear(start_date=DEFAULT_DATE, dag_bag=dag_bag_head_tail, session=session) == 30 + assert ( + dag.clear( + start_date=DEFAULT_DATE, + end_date=execution_date, + dag_bag=dag_bag_head_tail, + session=session, + ) + == 30 + ) diff --git a/tests/sensors/test_smart_sensor_operator.py b/tests/sensors/test_smart_sensor_operator.py index 1da0a00f16e8c..6a4e5e9f0c45f 100644 --- a/tests/sensors/test_smart_sensor_operator.py +++ b/tests/sensors/test_smart_sensor_operator.py @@ -45,7 +45,10 @@ class DummySmartSensor(SmartSensorOperator): def __init__( - self, shard_max=conf.getint('smart_sensor', 'shard_code_upper_limit'), shard_min=0, **kwargs + self, + shard_max=conf.getint('smart_sensor', 'shard_code_upper_limit'), + shard_min=0, + **kwargs, ): super().__init__(shard_min=shard_min, shard_max=shard_max, **kwargs) @@ -156,6 +159,7 @@ def _run(cls, task): def test_load_sensor_works(self): # Mock two sensor tasks return True and one return False # The hashcode for si1 and si2 should be same. Test dedup on these two instances + self._make_sensor_dag_run() si1 = self._make_sensor_instance(1, True) si2 = self._make_sensor_instance(2, True) si3 = self._make_sensor_instance(3, False) @@ -286,6 +290,7 @@ def test_smart_operator_timeout(self): assert sensor_instance.state == State.FAILED def test_register_in_sensor_service(self): + self._make_sensor_dag_run() si1 = self._make_sensor_instance(1, True) si1.run(ignore_all_deps=True) assert si1.state == State.SENSING @@ -296,7 +301,9 @@ def test_register_in_sensor_service(self): sensor_instance = ( session.query(SI) .filter( - SI.dag_id == si1.dag_id, SI.task_id == si1.task_id, SI.execution_date == si1.execution_date + SI.dag_id == si1.dag_id, + SI.task_id == si1.task_id, + SI.execution_date == si1.execution_date, ) .first() ) diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index b8c7e855e6be8..cafbe8d14ba58 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -23,7 +23,7 @@ import importlib.util import multiprocessing import os -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta from glob import glob from unittest import mock @@ -35,13 +35,15 @@ from airflow.exceptions import SerializationError from airflow.hooks.base import BaseHook from airflow.kubernetes.pod_generator import PodGenerator -from airflow.models import DAG, Connection, DagBag, TaskInstance +from airflow.models import DAG, Connection, DagBag from airflow.models.baseoperator import BaseOperator, BaseOperatorLink +from airflow.models.xcom import XCom from airflow.operators.bash import BashOperator from airflow.security import permissions from airflow.serialization.json_schema import load_dag_schema_dict from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG from airflow.timetables.simple import NullTimetable, OnceTimetable +from airflow.utils import timezone from tests.test_utils.mock_operators import CustomOperator, CustomOpLink, GoogleLink from tests.test_utils.timetables import CustomSerializationTimetable, cron_timetable, delta_timetable @@ -770,7 +772,7 @@ def test_extra_serialized_field_and_operator_links(self): the Operator in ``BaseOperator.operator_extra_links``, it has the correct extra link. 
""" - test_date = datetime(2019, 8, 1) + test_date = timezone.DateTime(2019, 8, 1, tzinfo=timezone.utc) dag = DAG(dag_id='simple_dag', start_date=test_date) CustomOperator(task_id='simple_task', dag=dag, bash_command="true") @@ -792,8 +794,13 @@ def test_extra_serialized_field_and_operator_links(self): # Test all the extra_links are set assert set(simple_task.extra_links) == {'Google Custom', 'airflow', 'github', 'google'} - ti = TaskInstance(task=simple_task, execution_date=test_date) - ti.xcom_push('search_query', "dummy_value_1") + XCom.set( + key='search_query', + value="dummy_value_1", + task_id=simple_task.task_id, + dag_id=simple_task.dag_id, + execution_date=test_date, + ) # Test Deserialized inbuilt link custom_inbuilt_link = simple_task.get_extra_links(test_date, CustomOpLink.name) @@ -850,7 +857,7 @@ def test_extra_serialized_field_and_multiple_operator_links(self): the Operator in ``BaseOperator.operator_extra_links``, it has the correct extra link. """ - test_date = datetime(2019, 8, 1) + test_date = timezone.DateTime(2019, 8, 1, tzinfo=timezone.utc) dag = DAG(dag_id='simple_dag', start_date=test_date) CustomOperator(task_id='simple_task', dag=dag, bash_command=["echo", "true"]) @@ -879,8 +886,13 @@ def test_extra_serialized_field_and_multiple_operator_links(self): 'google', } - ti = TaskInstance(task=simple_task, execution_date=test_date) - ti.xcom_push('search_query', ["dummy_value_1", "dummy_value_2"]) + XCom.set( + key='search_query', + value=["dummy_value_1", "dummy_value_2"], + task_id=simple_task.task_id, + dag_id=simple_task.dag_id, + execution_date=test_date, + ) # Test Deserialized inbuilt link #1 custom_inbuilt_link = simple_task.get_extra_links(test_date, "BigQuery Console #1") diff --git a/tests/test_utils/mock_executor.py b/tests/test_utils/mock_executor.py index 104995e915910..37f49cfbd9576 100644 --- a/tests/test_utils/mock_executor.py +++ b/tests/test_utils/mock_executor.py @@ -86,7 +86,7 @@ def change_state(self, key, state, info=None): # a list of all events for testing self.sorted_tasks.append((key, (state, info))) - def mock_task_fail(self, dag_id, task_id, date, try_number=1): + def mock_task_fail(self, dag_id, task_id, run_id: str, try_number=1): """ Set the mock outcome of running this particular task instances to FAILED. @@ -94,4 +94,5 @@ def mock_task_fail(self, dag_id, task_id, date, try_number=1): If the task identified by the tuple ``(dag_id, task_id, date, try_number)`` is run by this executor it's state will be FAILED. 
""" - self.mock_task_results[TaskInstanceKey(dag_id, task_id, date, try_number)] = State.FAILED + assert isinstance(run_id, str) + self.mock_task_results[TaskInstanceKey(dag_id, task_id, run_id, try_number)] = State.FAILED diff --git a/tests/test_utils/mock_operators.py b/tests/test_utils/mock_operators.py index 67176343b893c..155d17fe3be0d 100644 --- a/tests/test_utils/mock_operators.py +++ b/tests/test_utils/mock_operators.py @@ -20,8 +20,8 @@ import attr -from airflow.models import TaskInstance from airflow.models.baseoperator import BaseOperator, BaseOperatorLink +from airflow.models.xcom import XCom from airflow.providers.apache.hive.operators.hive import HiveOperator @@ -83,8 +83,9 @@ def name(self) -> str: return f'BigQuery Console #{self.index + 1}' def get_link(self, operator, dttm): - ti = TaskInstance(task=operator, execution_date=dttm) - search_queries = ti.xcom_pull(task_ids=operator.task_id, key='search_query') + search_queries = XCom.get_one( + task_id=operator.task_id, dag_id=operator.dag_id, execution_date=dttm, key='search_query' + ) if not search_queries: return None if len(search_queries) < self.index: @@ -97,8 +98,9 @@ class CustomOpLink(BaseOperatorLink): name = 'Google Custom' def get_link(self, operator, dttm): - ti = TaskInstance(task=operator, execution_date=dttm) - search_query = ti.xcom_pull(task_ids=operator.task_id, key='search_query') + search_query = XCom.get_one( + task_id=operator.task_id, dag_id=operator.dag_id, execution_date=dttm, key='search_query' + ) return f'http://google.com/custom_base_link?search={search_query}' diff --git a/tests/ti_deps/deps/test_dagrun_exists_dep.py b/tests/ti_deps/deps/test_dagrun_exists_dep.py index 4ec17c87be29c..2e36acc58b228 100644 --- a/tests/ti_deps/deps/test_dagrun_exists_dep.py +++ b/tests/ti_deps/deps/test_dagrun_exists_dep.py @@ -32,7 +32,8 @@ def test_dagrun_doesnt_exist(self, mock_dagrun_find): Task instances without dagruns should fail this dep """ dag = DAG('test_dag', max_active_runs=2) - ti = Mock(task=Mock(dag=dag), get_dagrun=Mock(return_value=None)) + dagrun = DagRun(state=State.NONE) + ti = Mock(task=Mock(dag=dag), get_dagrun=Mock(return_value=dagrun)) assert not DagrunRunningDep().is_met(ti=ti) def test_dagrun_exists(self): diff --git a/tests/ti_deps/deps/test_dagrun_id_dep.py b/tests/ti_deps/deps/test_dagrun_id_dep.py index 1192719bc8b8f..e416dd53c80fc 100644 --- a/tests/ti_deps/deps/test_dagrun_id_dep.py +++ b/tests/ti_deps/deps/test_dagrun_id_dep.py @@ -49,10 +49,3 @@ def test_dagrun_id_is_not_backfill(self): dagrun.run_id = None ti = Mock(get_dagrun=Mock(return_value=dagrun)) assert DagrunIdDep().is_met(ti=ti) - - def test_dagrun_is_none(self): - """ - Task instances which don't yet have an associated dagrun. - """ - ti = Mock(get_dagrun=Mock(return_value=None)) - assert DagrunIdDep().is_met(ti=ti) diff --git a/tests/ti_deps/deps/test_not_previously_skipped_dep.py b/tests/ti_deps/deps/test_not_previously_skipped_dep.py index a600df28ea0a5..411260c9b076d 100644 --- a/tests/ti_deps/deps/test_not_previously_skipped_dep.py +++ b/tests/ti_deps/deps/test_not_previously_skipped_dep.py @@ -17,113 +17,144 @@ # under the License. 
import pendulum +import pytest -from airflow.models import DAG, DagRun, TaskInstance +from airflow.models import DagRun, TaskInstance from airflow.operators.dummy import DummyOperator from airflow.operators.python import BranchPythonOperator from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep -from airflow.utils.session import create_session from airflow.utils.state import State from airflow.utils.types import DagRunType -def test_no_parent(): +@pytest.fixture(autouse=True, scope="function") +def clean_db(session): + yield + session.query(DagRun).delete() + session.query(TaskInstance).delete() + + +def test_no_parent(session, dag_maker): """ A simple DAG with a single task. NotPreviouslySkippedDep is met. """ start_date = pendulum.datetime(2020, 1, 1) - dag = DAG("test_test_no_parent_dag", schedule_interval=None, start_date=start_date) - op1 = DummyOperator(task_id="op1", dag=dag) + with dag_maker( + "test_test_no_parent_dag", + schedule_interval=None, + start_date=start_date, + session=session, + ): + op1 = DummyOperator(task_id="op1") - ti1 = TaskInstance(op1, start_date) + (ti1,) = dag_maker.create_dagrun(execution_date=start_date).task_instances + ti1.refresh_from_task(op1) - with create_session() as session: - dep = NotPreviouslySkippedDep() - assert len(list(dep.get_dep_statuses(ti1, session, DepContext()))) == 0 - assert dep.is_met(ti1, session) - assert ti1.state != State.SKIPPED + dep = NotPreviouslySkippedDep() + assert len(list(dep.get_dep_statuses(ti1, session, DepContext()))) == 0 + assert dep.is_met(ti1, session) + assert ti1.state != State.SKIPPED -def test_no_skipmixin_parent(): +def test_no_skipmixin_parent(session, dag_maker): """ A simple DAG with no branching. Both op1 and op2 are DummyOperator. NotPreviouslySkippedDep is met. """ start_date = pendulum.datetime(2020, 1, 1) - dag = DAG("test_no_skipmixin_parent_dag", schedule_interval=None, start_date=start_date) - op1 = DummyOperator(task_id="op1", dag=dag) - op2 = DummyOperator(task_id="op2", dag=dag) - op1 >> op2 - - ti2 = TaskInstance(op2, start_date) - - with create_session() as session: - dep = NotPreviouslySkippedDep() - assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 0 - assert dep.is_met(ti2, session) - assert ti2.state != State.SKIPPED - - -def test_parent_follow_branch(): + with dag_maker( + "test_no_skipmixin_parent_dag", + schedule_interval=None, + start_date=start_date, + session=session, + ): + op1 = DummyOperator(task_id="op1") + op2 = DummyOperator(task_id="op2") + op1 >> op2 + + _, ti2 = dag_maker.create_dagrun().task_instances + ti2.refresh_from_task(op2) + + dep = NotPreviouslySkippedDep() + assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 0 + assert dep.is_met(ti2, session) + assert ti2.state != State.SKIPPED + + +def test_parent_follow_branch(session, dag_maker): """ A simple DAG with a BranchPythonOperator that follows op2. NotPreviouslySkippedDep is met. 
""" start_date = pendulum.datetime(2020, 1, 1) - dag = DAG("test_parent_follow_branch_dag", schedule_interval=None, start_date=start_date) - dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=start_date) - op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op2", dag=dag) - op2 = DummyOperator(task_id="op2", dag=dag) - op1 >> op2 - TaskInstance(op1, start_date).run() - ti2 = TaskInstance(op2, start_date) - - with create_session() as session: - dep = NotPreviouslySkippedDep() - assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 0 - assert dep.is_met(ti2, session) - assert ti2.state != State.SKIPPED - - -def test_parent_skip_branch(): + with dag_maker( + "test_parent_follow_branch_dag", + schedule_interval=None, + start_date=start_date, + session=session, + ): + op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op2") + op2 = DummyOperator(task_id="op2") + op1 >> op2 + + dagrun = dag_maker.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING) + ti, ti2 = dagrun.task_instances + ti.run() + + dep = NotPreviouslySkippedDep() + assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 0 + assert dep.is_met(ti2, session) + assert ti2.state != State.SKIPPED + + +def test_parent_skip_branch(session, dag_maker): """ A simple DAG with a BranchPythonOperator that does not follow op2. NotPreviouslySkippedDep is not met. """ - with create_session() as session: - session.query(DagRun).delete() - session.query(TaskInstance).delete() - start_date = pendulum.datetime(2020, 1, 1) - dag = DAG("test_parent_skip_branch_dag", schedule_interval=None, start_date=start_date) - dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=start_date) - op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op3", dag=dag) - op2 = DummyOperator(task_id="op2", dag=dag) - op3 = DummyOperator(task_id="op3", dag=dag) + start_date = pendulum.datetime(2020, 1, 1) + with dag_maker( + "test_parent_skip_branch_dag", + schedule_interval=None, + start_date=start_date, + session=session, + ): + op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op3") + op2 = DummyOperator(task_id="op2") + op3 = DummyOperator(task_id="op3") op1 >> [op2, op3] - TaskInstance(op1, start_date).run() - ti2 = TaskInstance(op2, start_date) - dep = NotPreviouslySkippedDep() - assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 1 - session.commit() - assert not dep.is_met(ti2, session) - assert ti2.state == State.SKIPPED + tis = { + ti.task_id: ti + for ti in dag_maker.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING).task_instances + } + tis["op1"].run() + + dep = NotPreviouslySkippedDep() + assert len(list(dep.get_dep_statuses(tis["op2"], session, DepContext()))) == 1 + assert not dep.is_met(tis["op2"], session) + assert tis["op2"].state == State.SKIPPED -def test_parent_not_executed(): +def test_parent_not_executed(session, dag_maker): """ A simple DAG with a BranchPythonOperator that does not follow op2. Parent task is not yet executed (no xcom data). NotPreviouslySkippedDep is met (no decision). 
""" start_date = pendulum.datetime(2020, 1, 1) - dag = DAG("test_parent_not_executed_dag", schedule_interval=None, start_date=start_date) - op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op3", dag=dag) - op2 = DummyOperator(task_id="op2", dag=dag) - op3 = DummyOperator(task_id="op3", dag=dag) - op1 >> [op2, op3] - - ti2 = TaskInstance(op2, start_date) - - with create_session() as session: - dep = NotPreviouslySkippedDep() - assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 0 - assert dep.is_met(ti2, session) - assert ti2.state == State.NONE + with dag_maker( + "test_parent_not_executed_dag", + schedule_interval=None, + start_date=start_date, + session=session, + ): + op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op3") + op2 = DummyOperator(task_id="op2") + op3 = DummyOperator(task_id="op3") + op1 >> [op2, op3] + + _, ti2, _ = dag_maker.create_dagrun().task_instances + ti2.refresh_from_task(op2) + + dep = NotPreviouslySkippedDep() + assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 0 + assert dep.is_met(ti2, session) + assert ti2.state == State.NONE diff --git a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py index 053c9522f19e1..470166db21c8d 100644 --- a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py +++ b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py @@ -32,14 +32,14 @@ class TestNotInReschedulePeriodDep(unittest.TestCase): def _get_task_instance(self, state): dag = DAG('test_dag') task = Mock(dag=dag) - ti = TaskInstance(task=task, state=state, execution_date=None) + ti = TaskInstance(task=task, state=state, run_id=None) return ti def _get_task_reschedule(self, reschedule_date): task = Mock(dag_id='test_dag', task_id='test_task') reschedule = TaskReschedule( task=task, - execution_date=None, + run_id=None, try_number=None, start_date=reschedule_date, end_date=reschedule_date, diff --git a/tests/ti_deps/deps/test_runnable_exec_date_dep.py b/tests/ti_deps/deps/test_runnable_exec_date_dep.py index ac81bf6c69f89..5a5ca3b669050 100644 --- a/tests/ti_deps/deps/test_runnable_exec_date_dep.py +++ b/tests/ti_deps/deps/test_runnable_exec_date_dep.py @@ -17,19 +17,24 @@ # under the License. 
-import unittest from unittest.mock import Mock, patch import pytest from freezegun import freeze_time from airflow import settings -from airflow.models import DAG, TaskInstance -from airflow.operators.dummy import DummyOperator +from airflow.models import DagRun, TaskInstance from airflow.ti_deps.deps.runnable_exec_date_dep import RunnableExecDateDep from airflow.utils.timezone import datetime +@pytest.fixture(autouse=True, scope="function") +def clean_db(session): + yield + session.query(DagRun).delete() + session.query(TaskInstance).delete() + + @freeze_time('2016-11-01') @pytest.mark.parametrize( "allow_trigger_in_future,schedule_interval,execution_date,is_met", @@ -42,50 +47,53 @@ (False, None, datetime(2016, 11, 1), True), ], ) -def test_exec_date_dep(allow_trigger_in_future, schedule_interval, execution_date, is_met): +def test_exec_date_dep( + dag_maker, + session, + create_dummy_dag, + allow_trigger_in_future, + schedule_interval, + execution_date, + is_met, +): """ If the dag's execution date is in the future but (allow_trigger_in_future=False or not schedule_interval) this dep should fail """ - with patch.object(settings, 'ALLOW_FUTURE_EXEC_DATES', allow_trigger_in_future): - dag = DAG( + create_dummy_dag( 'test_localtaskjob_heartbeat', start_date=datetime(2015, 1, 1), end_date=datetime(2016, 11, 5), schedule_interval=schedule_interval, + session=session, ) - - with dag: - op1 = DummyOperator(task_id='op1') - - ti = TaskInstance(task=op1, execution_date=execution_date) + (ti,) = dag_maker.create_dagrun(execution_date=execution_date).task_instances assert RunnableExecDateDep().is_met(ti=ti) == is_met -class TestRunnableExecDateDep(unittest.TestCase): +@freeze_time('2016-01-01') +def test_exec_date_after_end_date(session, dag_maker, create_dummy_dag): + """ + If the dag's execution date is in the future this dep should fail + """ + create_dummy_dag( + 'test_localtaskjob_heartbeat', + start_date=datetime(2015, 1, 1), + end_date=datetime(2016, 11, 5), + schedule_interval=None, + session=session, + ) + (ti,) = dag_maker.create_dagrun(execution_date=datetime(2016, 11, 2)).task_instances + assert not RunnableExecDateDep().is_met(ti=ti) + + +class TestRunnableExecDateDep: def _get_task_instance(self, execution_date, dag_end_date=None, task_end_date=None): dag = Mock(end_date=dag_end_date) + dagrun = Mock(execution_date=execution_date) task = Mock(dag=dag, end_date=task_end_date) - return TaskInstance(task=task, execution_date=execution_date) - - @freeze_time('2016-01-01') - def test_exec_date_after_end_date(self): - """ - If the dag's execution date is in the future this dep should fail - """ - dag = DAG( - 'test_localtaskjob_heartbeat', - start_date=datetime(2015, 1, 1), - end_date=datetime(2016, 11, 5), - schedule_interval=None, - ) - - with dag: - op1 = DummyOperator(task_id='op1') - - ti = TaskInstance(task=op1, execution_date=datetime(2016, 11, 2)) - assert not RunnableExecDateDep().is_met(ti=ti) + return Mock(task=task, get_dagrun=Mock(return_value=dagrun)) def test_exec_date_after_task_end_date(self): """ diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py index fee1e1fff96b7..bbdb84679cb1b 100644 --- a/tests/ti_deps/deps/test_trigger_rule_dep.py +++ b/tests/ti_deps/deps/test_trigger_rule_dep.py @@ -15,14 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
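# A hedged sketch, not part of this patch, of how the helper above stubs the execution date now
# that it lives on the DagRun rather than on the TaskInstance: the dep is expected to read it via
# ti.get_dagrun(). The date below is a placeholder.
from unittest.mock import Mock

from airflow.ti_deps.deps.runnable_exec_date_dep import RunnableExecDateDep
from airflow.utils.timezone import datetime

dagrun = Mock(execution_date=datetime(2016, 1, 1))
ti = Mock(
    task=Mock(end_date=None, dag=Mock(end_date=None)),
    get_dagrun=Mock(return_value=dagrun),
)
assert RunnableExecDateDep().is_met(ti=ti)  # past execution date and no end dates, so the dep passes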
- - -import unittest from datetime import datetime from unittest.mock import Mock +import pytest + from airflow import settings -from airflow.models import DAG, TaskInstance +from airflow.models import DAG from airflow.models.baseoperator import BaseOperator from airflow.operators.dummy import DummyOperator from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep @@ -34,32 +33,43 @@ from tests.test_utils.db import clear_db_runs -class TestTriggerRuleDep(unittest.TestCase): - def _get_task_instance(self, trigger_rule=TriggerRule.ALL_SUCCESS, state=None, upstream_task_ids=None): - task = BaseOperator(task_id='test_task', trigger_rule=trigger_rule, start_date=datetime(2015, 1, 1)) - if upstream_task_ids: - task._upstream_task_ids.update(upstream_task_ids) - return TaskInstance(task=task, state=state, execution_date=task.start_date) +@pytest.fixture +def get_task_instance(session, dag_maker): + def _get_task_instance(trigger_rule=TriggerRule.ALL_SUCCESS, state=None, upstream_task_ids=None): + with dag_maker(session=session): + task = BaseOperator( + task_id='test_task', trigger_rule=trigger_rule, start_date=datetime(2015, 1, 1) + ) + if upstream_task_ids: + task._upstream_task_ids.update(upstream_task_ids) + dr = dag_maker.create_dagrun() + ti = dr.task_instances[0] + ti.task = task + return ti + + return _get_task_instance - def test_no_upstream_tasks(self): + +class TestTriggerRuleDep: + def test_no_upstream_tasks(self, get_task_instance): """ If the TI has no upstream TIs then there is nothing to check and the dep is passed """ - ti = self._get_task_instance(TriggerRule.ALL_DONE, State.UP_FOR_RETRY) + ti = get_task_instance(TriggerRule.ALL_DONE, State.UP_FOR_RETRY) assert TriggerRuleDep().is_met(ti=ti) - def test_always_tr(self): + def test_always_tr(self, get_task_instance): """ The always trigger rule should always pass this dep """ - ti = self._get_task_instance(TriggerRule.ALWAYS, State.UP_FOR_RETRY) + ti = get_task_instance(TriggerRule.ALWAYS, State.UP_FOR_RETRY) assert TriggerRuleDep().is_met(ti=ti) - def test_one_success_tr_success(self): + def test_one_success_tr_success(self, get_task_instance): """ One-success trigger rule success """ - ti = self._get_task_instance(TriggerRule.ONE_SUCCESS, State.UP_FOR_RETRY) + ti = get_task_instance(TriggerRule.ONE_SUCCESS, State.UP_FOR_RETRY) dep_statuses = tuple( TriggerRuleDep()._evaluate_trigger_rule( ti=ti, @@ -74,11 +84,11 @@ def test_one_success_tr_success(self): ) assert len(dep_statuses) == 0 - def test_one_success_tr_failure(self): + def test_one_success_tr_failure(self, get_task_instance): """ One-success trigger rule failure """ - ti = self._get_task_instance(TriggerRule.ONE_SUCCESS) + ti = get_task_instance(TriggerRule.ONE_SUCCESS) dep_statuses = tuple( TriggerRuleDep()._evaluate_trigger_rule( ti=ti, @@ -94,11 +104,11 @@ def test_one_success_tr_failure(self): assert len(dep_statuses) == 1 assert not dep_statuses[0].passed - def test_one_failure_tr_failure(self): + def test_one_failure_tr_failure(self, get_task_instance): """ One-failure trigger rule failure """ - ti = self._get_task_instance(TriggerRule.ONE_FAILED) + ti = get_task_instance(TriggerRule.ONE_FAILED) dep_statuses = tuple( TriggerRuleDep()._evaluate_trigger_rule( ti=ti, @@ -114,11 +124,11 @@ def test_one_failure_tr_failure(self): assert len(dep_statuses) == 1 assert not dep_statuses[0].passed - def test_one_failure_tr_success(self): + def test_one_failure_tr_success(self, get_task_instance): """ One-failure trigger rule success """ - ti = 
self._get_task_instance(TriggerRule.ONE_FAILED) + ti = get_task_instance(TriggerRule.ONE_FAILED) dep_statuses = tuple( TriggerRuleDep()._evaluate_trigger_rule( ti=ti, @@ -147,11 +157,11 @@ def test_one_failure_tr_success(self): ) assert len(dep_statuses) == 0 - def test_all_success_tr_success(self): + def test_all_success_tr_success(self, get_task_instance): """ All-success trigger rule success """ - ti = self._get_task_instance(TriggerRule.ALL_SUCCESS, upstream_task_ids=["FakeTaskID"]) + ti = get_task_instance(TriggerRule.ALL_SUCCESS, upstream_task_ids=["FakeTaskID"]) dep_statuses = tuple( TriggerRuleDep()._evaluate_trigger_rule( ti=ti, @@ -166,13 +176,11 @@ def test_all_success_tr_success(self): ) assert len(dep_statuses) == 0 - def test_all_success_tr_failure(self): + def test_all_success_tr_failure(self, get_task_instance): """ All-success trigger rule failure """ - ti = self._get_task_instance( - TriggerRule.ALL_SUCCESS, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] - ) + ti = get_task_instance(TriggerRule.ALL_SUCCESS, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]) dep_statuses = tuple( TriggerRuleDep()._evaluate_trigger_rule( ti=ti, @@ -188,13 +196,11 @@ def test_all_success_tr_failure(self): assert len(dep_statuses) == 1 assert not dep_statuses[0].passed - def test_all_success_tr_skip(self): + def test_all_success_tr_skip(self, get_task_instance): """ All-success trigger rule fails when some upstream tasks are skipped. """ - ti = self._get_task_instance( - TriggerRule.ALL_SUCCESS, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] - ) + ti = get_task_instance(TriggerRule.ALL_SUCCESS, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]) dep_statuses = tuple( TriggerRuleDep()._evaluate_trigger_rule( ti=ti, @@ -210,14 +216,12 @@ def test_all_success_tr_skip(self): assert len(dep_statuses) == 1 assert not dep_statuses[0].passed - def test_all_success_tr_skip_flag_upstream(self): + def test_all_success_tr_skip_flag_upstream(self, get_task_instance): """ All-success trigger rule fails when some upstream tasks are skipped. The state of the ti should be set to SKIPPED when flag_upstream_failed is True. 
""" - ti = self._get_task_instance( - TriggerRule.ALL_SUCCESS, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] - ) + ti = get_task_instance(TriggerRule.ALL_SUCCESS, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]) dep_statuses = tuple( TriggerRuleDep()._evaluate_trigger_rule( ti=ti, @@ -234,13 +238,11 @@ def test_all_success_tr_skip_flag_upstream(self): assert not dep_statuses[0].passed assert ti.state == State.SKIPPED - def test_none_failed_tr_success(self): + def test_none_failed_tr_success(self, get_task_instance): """ All success including skip trigger rule success """ - ti = self._get_task_instance( - TriggerRule.NONE_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] - ) + ti = get_task_instance(TriggerRule.NONE_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]) dep_statuses = tuple( TriggerRuleDep()._evaluate_trigger_rule( ti=ti, @@ -255,13 +257,11 @@ def test_none_failed_tr_success(self): ) assert len(dep_statuses) == 0 - def test_none_failed_tr_skipped(self): + def test_none_failed_tr_skipped(self, get_task_instance): """ All success including all upstream skips trigger rule success """ - ti = self._get_task_instance( - TriggerRule.NONE_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] - ) + ti = get_task_instance(TriggerRule.NONE_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]) dep_statuses = tuple( TriggerRuleDep()._evaluate_trigger_rule( ti=ti, @@ -277,11 +277,11 @@ def test_none_failed_tr_skipped(self): assert len(dep_statuses) == 0 assert ti.state == State.NONE - def test_none_failed_tr_failure(self): + def test_none_failed_tr_failure(self, get_task_instance): """ All success including skip trigger rule failure """ - ti = self._get_task_instance( + ti = get_task_instance( TriggerRule.NONE_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"] ) dep_statuses = tuple( @@ -299,11 +299,11 @@ def test_none_failed_tr_failure(self): assert len(dep_statuses) == 1 assert not dep_statuses[0].passed - def test_none_failed_min_one_success_tr_success(self): + def test_none_failed_min_one_success_tr_success(self, get_task_instance): """ All success including skip trigger rule success """ - ti = self._get_task_instance( + ti = get_task_instance( TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] ) dep_statuses = tuple( @@ -320,11 +320,11 @@ def test_none_failed_min_one_success_tr_success(self): ) assert len(dep_statuses) == 0 - def test_none_failed_min_one_success_tr_skipped(self): + def test_none_failed_min_one_success_tr_skipped(self, get_task_instance): """ All success including all upstream skips trigger rule success """ - ti = self._get_task_instance( + ti = get_task_instance( TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] ) dep_statuses = tuple( @@ -342,11 +342,11 @@ def test_none_failed_min_one_success_tr_skipped(self): assert len(dep_statuses) == 0 assert ti.state == State.SKIPPED - def test_none_failed_min_one_success_tr_failure(self): + def test_none_failed_min_one_success_tr_failure(self, session, get_task_instance): """ All success including skip trigger rule failure """ - ti = self._get_task_instance( + ti = get_task_instance( TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"], ) @@ -365,13 +365,11 @@ def test_none_failed_min_one_success_tr_failure(self): assert len(dep_statuses) == 1 assert not dep_statuses[0].passed - def test_all_failed_tr_success(self): 
+ def test_all_failed_tr_success(self, get_task_instance): """ All-failed trigger rule success """ - ti = self._get_task_instance( - TriggerRule.ALL_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] - ) + ti = get_task_instance(TriggerRule.ALL_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]) dep_statuses = tuple( TriggerRuleDep()._evaluate_trigger_rule( ti=ti, @@ -386,13 +384,11 @@ def test_all_failed_tr_success(self): ) assert len(dep_statuses) == 0 - def test_all_failed_tr_failure(self): + def test_all_failed_tr_failure(self, get_task_instance): """ All-failed trigger rule failure """ - ti = self._get_task_instance( - TriggerRule.ALL_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] - ) + ti = get_task_instance(TriggerRule.ALL_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]) dep_statuses = tuple( TriggerRuleDep()._evaluate_trigger_rule( ti=ti, @@ -408,13 +404,11 @@ def test_all_failed_tr_failure(self): assert len(dep_statuses) == 1 assert not dep_statuses[0].passed - def test_all_done_tr_success(self): + def test_all_done_tr_success(self, get_task_instance): """ All-done trigger rule success """ - ti = self._get_task_instance( - TriggerRule.ALL_DONE, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] - ) + ti = get_task_instance(TriggerRule.ALL_DONE, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]) dep_statuses = tuple( TriggerRuleDep()._evaluate_trigger_rule( ti=ti, @@ -429,13 +423,11 @@ def test_all_done_tr_success(self): ) assert len(dep_statuses) == 0 - def test_all_done_tr_failure(self): + def test_all_done_tr_failure(self, get_task_instance): """ All-done trigger rule failure """ - ti = self._get_task_instance( - TriggerRule.ALL_DONE, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] - ) + ti = get_task_instance(TriggerRule.ALL_DONE, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]) dep_statuses = tuple( TriggerRuleDep()._evaluate_trigger_rule( ti=ti, @@ -451,12 +443,11 @@ def test_all_done_tr_failure(self): assert len(dep_statuses) == 1 assert not dep_statuses[0].passed - def test_none_skipped_tr_success(self): + def test_none_skipped_tr_success(self, get_task_instance): """ None-skipped trigger rule success """ - - ti = self._get_task_instance( + ti = get_task_instance( TriggerRule.NONE_SKIPPED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"] ) with create_session() as session: @@ -489,13 +480,11 @@ def test_none_skipped_tr_success(self): ) assert len(dep_statuses) == 0 - def test_none_skipped_tr_failure(self): + def test_none_skipped_tr_failure(self, get_task_instance): """ None-skipped trigger rule failure """ - ti = self._get_task_instance( - TriggerRule.NONE_SKIPPED, upstream_task_ids=["FakeTaskID", "SkippedTaskID"] - ) + ti = get_task_instance(TriggerRule.NONE_SKIPPED, upstream_task_ids=["FakeTaskID", "SkippedTaskID"]) with create_session() as session: dep_statuses = tuple( @@ -545,11 +534,11 @@ def test_none_skipped_tr_failure(self): assert len(dep_statuses) == 1 assert not dep_statuses[0].passed - def test_unknown_tr(self): + def test_unknown_tr(self, get_task_instance): """ Unknown trigger rules should cause this dep to fail """ - ti = self._get_task_instance() + ti = get_task_instance() ti.task.trigger_rule = "Unknown Trigger Rule" dep_statuses = tuple( TriggerRuleDep()._evaluate_trigger_rule( @@ -595,11 +584,16 @@ def test_get_states_count_upstream_ti(self): run_id='test_dagrun_with_pre_tis', state=State.RUNNING, execution_date=now, start_date=now ) - ti_op1 = 
TaskInstance(task=dag.get_task(op1.task_id), execution_date=dr.execution_date) - ti_op2 = TaskInstance(task=dag.get_task(op2.task_id), execution_date=dr.execution_date) - ti_op3 = TaskInstance(task=dag.get_task(op3.task_id), execution_date=dr.execution_date) - ti_op4 = TaskInstance(task=dag.get_task(op4.task_id), execution_date=dr.execution_date) - ti_op5 = TaskInstance(task=dag.get_task(op5.task_id), execution_date=dr.execution_date) + ti_op1 = dr.get_task_instance(op1.task_id, session) + ti_op2 = dr.get_task_instance(op2.task_id, session) + ti_op3 = dr.get_task_instance(op3.task_id, session) + ti_op4 = dr.get_task_instance(op4.task_id, session) + ti_op5 = dr.get_task_instance(op5.task_id, session) + ti_op1.task = op1 + ti_op2.task = op2 + ti_op3.task = op3 + ti_op4.task = op4 + ti_op5.task = op5 ti_op1.set_state(state=State.SUCCESS, session=session) ti_op2.set_state(state=State.FAILED, session=session) @@ -610,7 +604,7 @@ def test_get_states_count_upstream_ti(self): session.commit() # check handling with cases that tasks are triggered from backfill with no finished tasks - finished_tasks = DepContext().ensure_finished_tasks(ti_op2.task.dag, ti_op2.execution_date, session) + finished_tasks = DepContext().ensure_finished_tasks(ti_op2.dag_run, session) assert get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op2) == (1, 0, 0, 0, 1) finished_tasks = dr.get_task_instances(state=State.finished, session=session) assert get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op4) == (1, 0, 1, 0, 2) diff --git a/tests/utils/log/test_log_reader.py b/tests/utils/log/test_log_reader.py index 301eed71bd06e..67e7d5a162521 100644 --- a/tests/utils/log/test_log_reader.py +++ b/tests/utils/log/test_log_reader.py @@ -18,87 +18,87 @@ import copy import logging import os -import shutil import sys import tempfile -import unittest from unittest import mock -from airflow import DAG, settings +import pytest + +from airflow import settings from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG -from airflow.models import TaskInstance -from airflow.operators.dummy import DummyOperator from airflow.utils import timezone from airflow.utils.log.log_reader import TaskLogReader from airflow.utils.log.logging_mixin import ExternalLoggingMixin -from airflow.utils.session import create_session +from airflow.utils.state import TaskInstanceState from tests.test_utils.config import conf_vars -from tests.test_utils.db import clear_db_runs +from tests.test_utils.db import clear_db_dags, clear_db_runs -class TestLogView(unittest.TestCase): +class TestLogView: DAG_ID = "dag_log_reader" TASK_ID = "task_log_reader" DEFAULT_DATE = timezone.datetime(2017, 9, 1) - def setUp(self): - self.maxDiff = None - - # Make sure that the configure_logging is not cached - self.old_modules = dict(sys.modules) - - self.settings_folder = tempfile.mkdtemp() - self.log_dir = tempfile.mkdtemp() - - self._configure_loggers() - self._prepare_db() - self._prepare_log_files() - - def _prepare_log_files(self): - dir_path = f"{self.log_dir}/{self.DAG_ID}/{self.TASK_ID}/2017-09-01T00.00.00+00.00/" - os.makedirs(dir_path) - for try_number in range(1, 4): - with open(f"{dir_path}/{try_number}.log", "w+") as file: - file.write(f"try_number={try_number}.\n") - file.flush() - - def _prepare_db(self): - dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE) - dag.sync_to_db() - with create_session() as session: - op = DummyOperator(task_id=self.TASK_ID, dag=dag) - self.ti = TaskInstance(task=op, 
execution_date=self.DEFAULT_DATE) - self.ti.try_number = 3 - - session.merge(self.ti) + @pytest.fixture(autouse=True) + def log_dir(self): + with tempfile.TemporaryDirectory() as log_dir: + self.log_dir = log_dir + yield log_dir + del self.log_dir + + @pytest.fixture(autouse=True) + def settings_folder(self): + old_modules = dict(sys.modules) + with tempfile.TemporaryDirectory() as settings_folder: + self.settings_folder = settings_folder + sys.path.append(settings_folder) + yield settings_folder + sys.path.remove(settings_folder) + # Remove any new modules imported during the test run. This lets us + # import the same source files for more than one test. + for mod in [m for m in sys.modules if m not in old_modules]: + del sys.modules[mod] + del self.settings_folder - def _configure_loggers(self): + @pytest.fixture(autouse=True) + def configure_loggers(self, log_dir, settings_folder): logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG) - logging_config["handlers"]["task"]["base_log_folder"] = self.log_dir + logging_config["handlers"]["task"]["base_log_folder"] = log_dir logging_config["handlers"]["task"][ "filename_template" ] = "{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts | replace(':', '.') }}/{{ try_number }}.log" - settings_file = os.path.join(self.settings_folder, "airflow_local_settings.py") + settings_file = os.path.join(settings_folder, "airflow_local_settings.py") with open(settings_file, "w") as handle: new_logging_file = f"LOGGING_CONFIG = {logging_config}" handle.writelines(new_logging_file) - sys.path.append(self.settings_folder) with conf_vars({("logging", "logging_config_class"): "airflow_local_settings.LOGGING_CONFIG"}): settings.configure_logging() - - def tearDown(self): + yield logging.config.dictConfig(DEFAULT_LOGGING_CONFIG) - clear_db_runs() - # Remove any new modules imported during the test run. This lets us - # import the same source files for more than one test. - for mod in [m for m in sys.modules if m not in self.old_modules]: - del sys.modules[mod] - - sys.path.remove(self.settings_folder) - shutil.rmtree(self.settings_folder) - shutil.rmtree(self.log_dir) - super().tearDown() + @pytest.fixture(autouse=True) + def prepare_log_files(self, log_dir): + dir_path = f"{log_dir}/{self.DAG_ID}/{self.TASK_ID}/2017-09-01T00.00.00+00.00/" + os.makedirs(dir_path) + for try_number in range(1, 4): + with open(f"{dir_path}/{try_number}.log", "w+") as f: + f.write(f"try_number={try_number}.\n") + f.flush() + + @pytest.fixture(autouse=True) + def prepare_db(self, session, create_task_instance): + ti = create_task_instance( + dag_id=self.DAG_ID, + task_id=self.TASK_ID, + start_date=self.DEFAULT_DATE, + execution_date=self.DEFAULT_DATE, + state=TaskInstanceState.RUNNING, + ) + ti.try_number = 3 + self.ti = ti + yield + clear_db_runs() + clear_db_dags() def test_test_read_log_chunks_should_read_one_try(self): task_log_reader = TaskLogReader() diff --git a/tests/utils/test_dot_renderer.py b/tests/utils/test_dot_renderer.py index ca3ea01794135..3d9cb7a44b593 100644 --- a/tests/utils/test_dot_renderer.py +++ b/tests/utils/test_dot_renderer.py @@ -16,33 +16,34 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
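# A hedged sketch, not part of this patch, of the `create_task_instance` conftest fixture the log
# reader test above relies on: it builds the DAG, DagRun and TaskInstance in one call, so no row
# has to be merged into the session by hand. Identifiers and dates are placeholders.
from airflow.utils import timezone
from airflow.utils.state import TaskInstanceState


def test_log_reader_fixture_sketch(create_task_instance):
    ti = create_task_instance(
        dag_id="dag_log_reader",
        task_id="task_log_reader",
        start_date=timezone.datetime(2017, 9, 1),
        execution_date=timezone.datetime(2017, 9, 1),
        state=TaskInstanceState.RUNNING,
    )
    assert ti.run_id is not None  # the TaskInstance is always attached to a DagRun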
-import datetime -import unittest from unittest import mock -from airflow.models import TaskInstance from airflow.models.dag import DAG from airflow.operators.bash import BashOperator from airflow.operators.dummy import DummyOperator from airflow.operators.python import PythonOperator -from airflow.utils import dot_renderer +from airflow.utils import dot_renderer, timezone from airflow.utils.state import State from airflow.utils.task_group import TaskGroup +from tests.test_utils.db import clear_db_dags -START_DATE = datetime.datetime.now() +START_DATE = timezone.utcnow() -class TestDotRenderer(unittest.TestCase): - def test_should_render_dag(self): +class TestDotRenderer: + def setup_class(self): + clear_db_dags() - dag = DAG(dag_id="DAG_ID") - task_1 = BashOperator(dag=dag, start_date=START_DATE, task_id="first", bash_command="echo 1") - task_2 = BashOperator(dag=dag, start_date=START_DATE, task_id="second", bash_command="echo 1") - task_3 = PythonOperator( - dag=dag, start_date=START_DATE, task_id="third", python_callable=mock.MagicMock() - ) - task_1 >> task_2 - task_1 >> task_3 + def teardown_method(self): + clear_db_dags() + + def test_should_render_dag(self): + with DAG(dag_id="DAG_ID") as dag: + task_1 = BashOperator(start_date=START_DATE, task_id="first", bash_command="echo 1") + task_2 = BashOperator(start_date=START_DATE, task_id="second", bash_command="echo 1") + task_3 = PythonOperator(start_date=START_DATE, task_id="third", python_callable=mock.MagicMock()) + task_1 >> task_2 + task_1 >> task_3 dot = dot_renderer.render_dag(dag) source = dot.source @@ -56,21 +57,20 @@ def test_should_render_dag(self): assert 'fillcolor="#f0ede4"' in source assert 'fillcolor="#f0ede4"' in source - def test_should_render_dag_with_task_instances(self): - dag = DAG(dag_id="DAG_ID") - task_1 = BashOperator(dag=dag, start_date=START_DATE, task_id="first", bash_command="echo 1") - task_2 = BashOperator(dag=dag, start_date=START_DATE, task_id="second", bash_command="echo 1") - task_3 = PythonOperator( - dag=dag, start_date=START_DATE, task_id="third", python_callable=mock.MagicMock() - ) - task_1 >> task_2 - task_1 >> task_3 - tis = [ - TaskInstance(task_1, execution_date=START_DATE, state=State.SCHEDULED), - TaskInstance(task_2, execution_date=START_DATE, state=State.SUCCESS), - TaskInstance(task_3, execution_date=START_DATE, state=State.RUNNING), - ] - dot = dot_renderer.render_dag(dag, tis=tis) + def test_should_render_dag_with_task_instances(self, session, dag_maker): + with dag_maker(dag_id="DAG_ID", session=session) as dag: + task_1 = BashOperator(start_date=START_DATE, task_id="first", bash_command="echo 1") + task_2 = BashOperator(start_date=START_DATE, task_id="second", bash_command="echo 1") + task_3 = PythonOperator(start_date=START_DATE, task_id="third", python_callable=mock.MagicMock()) + task_1 >> task_2 + task_1 >> task_3 + + tis = {ti.task_id: ti for ti in dag_maker.create_dagrun(execution_date=START_DATE).task_instances} + tis["first"].state = State.SCHEDULED + tis["second"].state = State.SUCCESS + tis["third"].state = State.RUNNING + + dot = dot_renderer.render_dag(dag, tis=tis.values()) source = dot.source # Should render DAG title assert "label=DAG_ID" in source @@ -85,21 +85,20 @@ def test_should_render_dag_with_task_instances(self): 'third [color=black fillcolor=lime label=third shape=rectangle style="filled,rounded"]' in source ) - def test_should_render_dag_orientation(self): + def test_should_render_dag_orientation(self, session, dag_maker): orientation = "TB" - dag = 
DAG(dag_id="DAG_ID", orientation=orientation) - task_1 = BashOperator(dag=dag, start_date=START_DATE, task_id="first", bash_command="echo 1") - task_2 = BashOperator(dag=dag, start_date=START_DATE, task_id="second", bash_command="echo 1") - task_3 = PythonOperator( - dag=dag, start_date=START_DATE, task_id="third", python_callable=mock.MagicMock() - ) - task_1 >> task_2 - task_1 >> task_3 - tis = [ - TaskInstance(task_1, execution_date=START_DATE, state=State.SCHEDULED), - TaskInstance(task_2, execution_date=START_DATE, state=State.SUCCESS), - TaskInstance(task_3, execution_date=START_DATE, state=State.RUNNING), - ] + with dag_maker(dag_id="DAG_ID", orientation=orientation, session=session) as dag: + task_1 = BashOperator(start_date=START_DATE, task_id="first", bash_command="echo 1") + task_2 = BashOperator(start_date=START_DATE, task_id="second", bash_command="echo 1") + task_3 = PythonOperator(start_date=START_DATE, task_id="third", python_callable=mock.MagicMock()) + task_1 >> task_2 + task_1 >> task_3 + + tis = dag_maker.create_dagrun(execution_date=START_DATE).task_instances + tis[0].state = State.SCHEDULED + tis[1].state = State.SUCCESS + tis[2].state = State.RUNNING + dot = dot_renderer.render_dag(dag, tis=tis) source = dot.source # Should render DAG title with orientation diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index a2aa10b14e707..1bf94afe88bf4 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -16,32 +16,34 @@ # specific language governing permissions and limitations # under the License. import re -import unittest -from datetime import datetime import pytest -from parameterized import parameterized from airflow import AirflowException -from airflow.models import TaskInstance -from airflow.models.dag import DAG -from airflow.operators.dummy import DummyOperator -from airflow.utils import helpers +from airflow.utils import helpers, timezone from airflow.utils.helpers import build_airflow_url_with_query, merge_dicts, validate_group_key, validate_key from tests.test_utils.config import conf_vars +from tests.test_utils.db import clear_db_dags, clear_db_runs -class TestHelpers(unittest.TestCase): - def test_render_log_filename(self): +@pytest.fixture() +def clear_db(): + clear_db_runs() + clear_db_dags() + yield + clear_db_runs() + clear_db_dags() + + +class TestHelpers: + @pytest.mark.usefixtures("clear_db") + def test_render_log_filename(self, create_task_instance): try_number = 1 dag_id = 'test_render_log_filename_dag' task_id = 'test_render_log_filename_task' - execution_date = datetime(2016, 1, 1) - - dag = DAG(dag_id, start_date=execution_date) - task = DummyOperator(task_id=task_id, dag=dag) - ti = TaskInstance(task=task, execution_date=execution_date) + execution_date = timezone.datetime(2016, 1, 1) + ti = create_task_instance(dag_id=dag_id, task_id=task_id, execution_date=execution_date) filename_template = "{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log" ts = ti.get_template_context()['ts'] @@ -157,7 +159,8 @@ def test_build_airflow_url_with_query(self): with cached_app(testing=True).test_request_context(): assert build_airflow_url_with_query(query) == expected_url - @parameterized.expand( + @pytest.mark.parametrize( + "key_id, message, exception", [ (3, "The key has to be a string and is :3", TypeError), (None, "The key has to be a string and is :None", TypeError), @@ -178,7 +181,7 @@ def test_build_airflow_url_with_query(self): AirflowException, ), (' ' * 251, "The key has to be less than 250 
characters", AirflowException), - ] + ], ) def test_validate_key(self, key_id, message, exception): if message: @@ -187,7 +190,8 @@ def test_validate_key(self, key_id, message, exception): else: validate_key(key_id) - @parameterized.expand( + @pytest.mark.parametrize( + "key_id, message, exception", [ (3, "The key has to be a string and is :3", TypeError), (None, "The key has to be a string and is :None", TypeError), @@ -218,7 +222,7 @@ def test_validate_key(self, key_id, message, exception): AirflowException, ), (' ' * 201, "The key has to be less than 200 characters", AirflowException), - ] + ], ) def test_validate_group_key(self, key_id, message, exception): if message: diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index fad5f8b58714a..6d5403c3d9f5f 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -20,11 +20,11 @@ import logging.config import os import re -import unittest + +import pytest from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG from airflow.models import DAG, DagRun, TaskInstance -from airflow.operators.dummy import DummyOperator from airflow.operators.python import PythonOperator from airflow.utils.log.file_task_handler import FileTaskHandler from airflow.utils.log.logging_mixin import set_context @@ -38,22 +38,20 @@ FILE_TASK_HANDLER = 'task' -class TestFileTaskLogHandler(unittest.TestCase): +class TestFileTaskLogHandler: def clean_up(self): with create_session() as session: session.query(DagRun).delete() session.query(TaskInstance).delete() - def setUp(self): - super().setUp() + def setup_method(self): logging.config.dictConfig(DEFAULT_LOGGING_CONFIG) logging.root.disabled = False self.clean_up() # We use file task handler by default. - def tearDown(self): + def teardown_method(self): self.clean_up() - super().tearDown() def test_default_task_logging_setup(self): # file task handler is used by default. 
@@ -68,13 +66,17 @@ def task_callable(ti, **kwargs): ti.log.info("test") dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE) - dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE) + dagrun = dag.create_dagrun( + run_type=DagRunType.MANUAL, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + ) task = PythonOperator( task_id='task_for_testing_file_log_handler', dag=dag, python_callable=task_callable, ) - ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) + ti = TaskInstance(task=task, run_id=dagrun.run_id) logger = ti.log ti.log.disabled = False @@ -116,13 +118,17 @@ def task_callable(ti, **kwargs): ti.log.info("test") dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE) - dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE) + dagrun = dag.create_dagrun( + run_type=DagRunType.MANUAL, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + ) task = PythonOperator( task_id='task_for_testing_file_log_handler', dag=dag, python_callable=task_callable, ) - ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) + ti = TaskInstance(task=task, run_id=dagrun.run_id) logger = ti.log ti.log.disabled = False @@ -168,10 +174,16 @@ def task_callable(ti, **kwargs): dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE) task = PythonOperator( task_id='task_for_testing_file_log_handler', - dag=dag, python_callable=task_callable, + dag=dag, + ) + dagrun = dag.create_dagrun( + run_type=DagRunType.MANUAL, + state=State.RUNNING, + execution_date=DEFAULT_DATE, ) - ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) + ti = TaskInstance(task=task, run_id=dagrun.run_id) + ti.try_number = 2 ti.state = State.RUNNING @@ -206,28 +218,33 @@ def task_callable(ti, **kwargs): os.remove(log_filename) -class TestFilenameRendering(unittest.TestCase): - def setUp(self): - dag = DAG('dag_for_testing_filename_rendering', start_date=DEFAULT_DATE) - task = DummyOperator(task_id='task_for_testing_filename_rendering', dag=dag) - self.ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) +@pytest.fixture() +def filename_rendering_ti(session, create_task_instance): + return create_task_instance( + dag_id='dag_for_testing_filename_rendering', + task_id='task_for_testing_filename_rendering', + execution_date=DEFAULT_DATE, + session=session, + ) + - def test_python_formatting(self): +class TestFilenameRendering: + def test_python_formatting(self, filename_rendering_ti): expected_filename = ( 'dag_for_testing_filename_rendering/task_for_testing_filename_rendering/%s/42.log' % DEFAULT_DATE.isoformat() ) fth = FileTaskHandler('', '{dag_id}/{task_id}/{execution_date}/{try_number}.log') - rendered_filename = fth._render_filename(self.ti, 42) + rendered_filename = fth._render_filename(filename_rendering_ti, 42) assert expected_filename == rendered_filename - def test_jinja_rendering(self): + def test_jinja_rendering(self, filename_rendering_ti): expected_filename = ( 'dag_for_testing_filename_rendering/task_for_testing_filename_rendering/%s/42.log' % DEFAULT_DATE.isoformat() ) fth = FileTaskHandler('', '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log') - rendered_filename = fth._render_filename(self.ti, 42) + rendered_filename = fth._render_filename(filename_rendering_ti, 42) assert expected_filename == rendered_filename diff --git a/tests/utils/test_task_handler_with_custom_formatter.py b/tests/utils/test_task_handler_with_custom_formatter.py index 
27724c8c38d04..e2a3c77b8f81f 100644 --- a/tests/utils/test_task_handler_with_custom_formatter.py +++ b/tests/utils/test_task_handler_with_custom_formatter.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - import logging import unittest @@ -23,8 +22,11 @@ from airflow.models import DAG, TaskInstance from airflow.operators.dummy import DummyOperator from airflow.utils.log.logging_mixin import set_context +from airflow.utils.state import DagRunState from airflow.utils.timezone import datetime +from airflow.utils.types import DagRunType from tests.test_utils.config import conf_vars +from tests.test_utils.db import clear_db_runs DEFAULT_DATE = datetime(2019, 1, 1) TASK_LOGGER = 'airflow.task' @@ -35,7 +37,6 @@ class TestTaskHandlerWithCustomFormatter(unittest.TestCase): def setUp(self): - super().setUp() DEFAULT_LOGGING_CONFIG['handlers']['task'] = { 'class': TASK_HANDLER_CLASS, 'formatter': 'airflow', @@ -46,14 +47,19 @@ def setUp(self): logging.root.disabled = False def tearDown(self): - super().tearDown() + clear_db_runs() DEFAULT_LOGGING_CONFIG['handlers']['task'] = PREV_TASK_HANDLER @conf_vars({('logging', 'task_log_prefix_template'): "{{ti.dag_id}}-{{ti.task_id}}"}) def test_formatter(self): dag = DAG('test_dag', start_date=DEFAULT_DATE) task = DummyOperator(task_id='test_task', dag=dag) - ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) + dagrun = dag.create_dagrun( + DagRunState.RUNNING, + execution_date=DEFAULT_DATE, + run_type=DagRunType.MANUAL, + ) + ti = TaskInstance(task=task, run_id=dagrun.run_id) logger = ti.log ti.log.disabled = False diff --git a/tests/www/views/test_views_dagrun.py b/tests/www/views/test_views_dagrun.py index a57a59528feb4..919adff937bc7 100644 --- a/tests/www/views/test_views_dagrun.py +++ b/tests/www/views/test_views_dagrun.py @@ -45,20 +45,20 @@ def reset_dagrun(): @pytest.fixture() def running_dag_run(session): dag = DagBag().get_dag("example_bash_operator") - task0 = dag.get_task("runme_0") - task1 = dag.get_task("runme_1") execution_date = timezone.datetime(2016, 1, 9) - tis = [ - TaskInstance(task0, execution_date, state="success"), - TaskInstance(task1, execution_date, state="failed"), - ] - session.bulk_save_objects(tis) dr = dag.create_dagrun( state="running", execution_date=execution_date, run_id="test_dag_runs_action", session=session, ) + session.add(dr) + tis = [ + TaskInstance(dag.get_task("runme_0"), run_id=dr.run_id, state="success"), + TaskInstance(dag.get_task("runme_1"), run_id=dr.run_id, state="failed"), + ] + session.bulk_save_objects(tis) + session.flush() return dr @@ -132,5 +132,5 @@ def test_muldelete_dag_runs_action(session, admin_client, running_dag_run): follow_redirects=True, ) assert resp.status_code == 200 - assert session.query(TaskInstance).count() == 2 # Does not delete TIs. + assert session.query(TaskInstance).count() == 0 # Deletes associated TIs. 
assert session.query(DagRun).filter(DagRun.id == dag_run_id).count() == 0 diff --git a/tests/www/views/test_views_extra_links.py b/tests/www/views/test_views_extra_links.py index 705e38a763aa1..8d7f1ac632dab 100644 --- a/tests/www/views/test_views_extra_links.py +++ b/tests/www/views/test_views_extra_links.py @@ -22,10 +22,13 @@ import pytest -from airflow.models import DAG, TaskInstance +from airflow.models import DAG from airflow.models.baseoperator import BaseOperator, BaseOperatorLink from airflow.utils import dates, timezone from airflow.utils.session import create_session +from airflow.utils.state import DagRunState, TaskInstanceState +from airflow.utils.types import DagRunType +from tests.test_utils.db import clear_db_runs from tests.test_utils.mock_operators import Dummy2TestOperator, Dummy3TestOperator from tests.test_utils.www import check_content_in_response @@ -76,6 +79,19 @@ def dag(): return DAG("dag", start_date=DEFAULT_DATE) +@pytest.fixture(scope="module") +def create_dag_run(dag): + def _create_dag_run(*, execution_date, session): + return dag.create_dagrun( + state=DagRunState.RUNNING, + execution_date=execution_date, + run_type=DagRunType.MANUAL, + session=session, + ) + + return _create_dag_run + + @pytest.fixture(scope="module", autouse=True) def patched_app(app, dag): with mock.patch.object(app, "dag_bag") as mock_dag_bag: @@ -104,15 +120,13 @@ def init_blank_task_instances(): This really shouldn't be needed, but tests elsewhere leave the db dirty. """ - with create_session() as session: - session.query(TaskInstance).delete() + clear_db_runs() @pytest.fixture(autouse=True) def reset_task_instances(): yield - with create_session() as session: - session.query(TaskInstance).delete() + clear_db_runs() def test_extra_links_works(dag, task_1, viewer_client): @@ -151,17 +165,19 @@ def test_global_extra_links_works(dag, task_1, viewer_client): } -def test_extra_link_in_gantt_view(dag, viewer_client): +def test_extra_link_in_gantt_view(dag, create_dag_run, viewer_client): exec_date = dates.days_ago(2) start_date = timezone.datetime(2020, 4, 10, 2, 0, 0) end_date = exec_date + datetime.timedelta(seconds=30) with create_session() as session: - for task in dag.tasks: - ti = TaskInstance(task=task, execution_date=exec_date, state="success") + dag_run = create_dag_run(execution_date=exec_date, session=session) + for ti in dag_run.task_instances: + ti.refresh_from_task(dag.get_task(ti.task_id)) + ti.state = TaskInstanceState.SUCCESS ti.start_date = start_date ti.end_date = end_date - session.add(ti) + session.merge(ti) url = f'gantt?dag_id={dag.dag_id}&execution_date={exec_date}' resp = viewer_client.get(url, follow_redirects=True) diff --git a/tests/www/views/test_views_log.py b/tests/www/views/test_views_log.py index b8cfe6e8a7df5..af3d452e415cd 100644 --- a/tests/www/views/test_views_log.py +++ b/tests/www/views/test_views_log.py @@ -27,14 +27,15 @@ from airflow import settings from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG -from airflow.models import DAG, DagBag, TaskInstance -from airflow.operators.dummy import DummyOperator +from airflow.models import DagBag from airflow.utils import timezone from airflow.utils.log.logging_mixin import ExternalLoggingMixin from airflow.utils.session import create_session -from airflow.utils.state import State +from airflow.utils.state import DagRunState, TaskInstanceState +from airflow.utils.types import DagRunType from airflow.www.app import create_app from tests.test_utils.config import conf_vars +from 
tests.test_utils.db import clear_db_dags, clear_db_runs from tests.test_utils.decorators import dont_initialize_flask_app_submodules from tests.test_utils.www import client_with_login @@ -102,46 +103,58 @@ def reset_modules_after_every_test(backup_modules): @pytest.fixture(autouse=True) -def dags(log_app): - dag = DAG(DAG_ID, start_date=DEFAULT_DATE) - dag_removed = DAG(DAG_ID_REMOVED, start_date=DEFAULT_DATE) +def dags(log_app, create_dummy_dag, session): + dag, _ = create_dummy_dag( + dag_id=DAG_ID, + task_id=TASK_ID, + start_date=DEFAULT_DATE, + with_dagrun=False, + session=session, + ) + dag_removed, _ = create_dummy_dag( + dag_id=DAG_ID_REMOVED, + task_id=TASK_ID, + start_date=DEFAULT_DATE, + with_dagrun=False, + session=session, + ) bag = DagBag(include_examples=False) bag.bag_dag(dag=dag, root_dag=dag) bag.bag_dag(dag=dag_removed, root_dag=dag_removed) + bag.sync_to_db(session=session) + log_app.dag_bag = bag - # Since we don't want to store the code for the DAG defined in this file - with unittest.mock.patch('airflow.models.dag.DagCode.bulk_sync_to_db'): - dag.sync_to_db() - dag_removed.sync_to_db() - bag.sync_to_db() + yield dag, dag_removed - log_app.dag_bag = bag - return dag, dag_removed + clear_db_dags() @pytest.fixture(autouse=True) -def tis(dags): +def tis(dags, session): dag, dag_removed = dags - ti = TaskInstance( - task=DummyOperator(task_id=TASK_ID, dag=dag), + dagrun = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + state=DagRunState.RUNNING, + session=session, ) + (ti,) = dagrun.task_instances ti.try_number = 1 - ti_removed_dag = TaskInstance( - task=DummyOperator(task_id=TASK_ID, dag=dag_removed), + dagrun_removed = dag_removed.create_dagrun( + run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + state=DagRunState.RUNNING, + session=session, ) + (ti_removed_dag,) = dagrun_removed.task_instances ti_removed_dag.try_number = 1 - with create_session() as session: - session.merge(ti) - session.merge(ti_removed_dag) - yield ti, ti_removed_dag - with create_session() as session: - session.query(TaskInstance).delete() + clear_db_runs() @pytest.fixture() @@ -152,13 +165,13 @@ def log_admin_client(log_app): @pytest.mark.parametrize( "state, try_number, num_logs", [ - (State.NONE, 0, 0), - (State.UP_FOR_RETRY, 2, 2), - (State.UP_FOR_RESCHEDULE, 0, 1), - (State.UP_FOR_RESCHEDULE, 1, 2), - (State.RUNNING, 1, 1), - (State.SUCCESS, 1, 1), - (State.FAILED, 3, 3), + (None, 0, 0), + (TaskInstanceState.UP_FOR_RETRY, 2, 2), + (TaskInstanceState.UP_FOR_RESCHEDULE, 0, 1), + (TaskInstanceState.UP_FOR_RESCHEDULE, 1, 2), + (TaskInstanceState.RUNNING, 1, 1), + (TaskInstanceState.SUCCESS, 1, 1), + (TaskInstanceState.FAILED, 3, 3), ], ids=[ "none", diff --git a/tests/www/views/test_views_rendered.py b/tests/www/views/test_views_rendered.py index 1f0f8ae999b5e..b749fa174faa5 100644 --- a/tests/www/views/test_views_rendered.py +++ b/tests/www/views/test_views_rendered.py @@ -20,11 +20,14 @@ import pytest -from airflow.models import DAG, RenderedTaskInstanceFields, TaskInstance -from airflow.models.serialized_dag import SerializedDagModel +from airflow.models import DAG, RenderedTaskInstanceFields from airflow.operators.bash import BashOperator +from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils import timezone from airflow.utils.session import create_session +from airflow.utils.state import DagRunState, TaskInstanceState +from airflow.utils.types import DagRunType 
+from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_rendered_ti_fields from tests.test_utils.www import check_content_in_response, check_content_not_in_response DEFAULT_DATE = timezone.datetime(2020, 3, 1) @@ -58,42 +61,64 @@ def task2(dag): ) +@pytest.fixture(scope="module", autouse=True) +def init_blank_db(): + """Make sure there are no runs before we test anything. + + This really shouldn't be needed, but tests elsewhere leave the db dirty. + """ + clear_db_dags() + clear_db_runs() + clear_rendered_ti_fields() + + @pytest.fixture(autouse=True) def reset_db(dag, task1, task2): - """Reset DB for each test. + yield + clear_db_dags() + clear_db_runs() + clear_rendered_ti_fields() - This writes the DAG to the DB, and clears rendered fields so we have a clean - slate for each test. Note that task1 and task2 are included in the argument - to make sure they are registered to the DAG for serialization. - The pre-test cleanup really shouldn't be necessary, but the test DB was not - initialized in a clean state to begin with :( - """ - with create_session() as session: - SerializedDagModel.write_dag(dag) - session.query(RenderedTaskInstanceFields).delete() - yield - with create_session() as session: - session.query(RenderedTaskInstanceFields).delete() - session.query(SerializedDagModel).delete() +@pytest.fixture() +def create_dag_run(dag, task1, task2): + def _create_dag_run(*, execution_date, session): + dag_run = dag.create_dagrun( + state=DagRunState.RUNNING, + execution_date=execution_date, + data_interval=(execution_date, execution_date), + run_type=DagRunType.SCHEDULED, + session=session, + ) + ti1 = dag_run.get_task_instance(task1.task_id, session=session) + ti1.state = TaskInstanceState.SUCCESS + ti2 = dag_run.get_task_instance(task2.task_id, session=session) + ti2.state = TaskInstanceState.SCHEDULED + session.flush() + return dag_run + + return _create_dag_run @pytest.fixture() def patch_app(app, dag): with mock.patch.object(app, "dag_bag") as mock_dag_bag: - mock_dag_bag.get_dag.return_value = dag + mock_dag_bag.get_dag.return_value = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) yield app @pytest.mark.usefixtures("patch_app") -def test_rendered_template_view(admin_client, task1): +def test_rendered_template_view(admin_client, create_dag_run, task1): """ Test that the Rendered View contains the values from RenderedTaskInstanceFields """ assert task1.bash_command == '{{ task_instance_key_str }}' - ti = TaskInstance(task1, DEFAULT_DATE) with create_session() as session: + dag_run = create_dag_run(execution_date=DEFAULT_DATE, session=session) + ti = dag_run.get_task_instance(task1.task_id, session=session) + assert ti is not None, "task instance not found" + ti.refresh_from_task(task1) session.add(RenderedTaskInstanceFields(ti)) url = f'rendered-templates?task_id=task1&dag_id=testdag&execution_date={quote_plus(str(DEFAULT_DATE))}' @@ -103,43 +128,39 @@ def test_rendered_template_view(admin_client, task1): @pytest.mark.usefixtures("patch_app") -def test_rendered_template_view_for_unexecuted_tis(admin_client, task1): +def test_rendered_template_view_for_unexecuted_tis(admin_client, create_dag_run, task1): """ Test that the Rendered View is able to show rendered values even for TIs that have not yet executed """ assert task1.bash_command == '{{ task_instance_key_str }}' - url = f'rendered-templates?task_id=task1&dag_id=task1&execution_date={quote_plus(str(DEFAULT_DATE))}' + with create_session() as session: + create_dag_run(execution_date=DEFAULT_DATE, session=session) + 
+ url = f'rendered-templates?task_id=task1&dag_id=testdag&execution_date={quote_plus(str(DEFAULT_DATE))}' resp = admin_client.get(url, follow_redirects=True) check_content_in_response("testdag__task1__20200301", resp) -def test_user_defined_filter_and_macros_raise_error(app, admin_client, dag, task2): - """ - Test that the Rendered View is able to show rendered values - even for TIs that have not yet executed - """ - dag = SerializedDagModel.get(dag.dag_id).dag - with mock.patch.object(app, "dag_bag") as mock_dag_bag: - mock_dag_bag.get_dag.return_value = dag +@pytest.mark.usefixtures("patch_app") +def test_user_defined_filter_and_macros_raise_error(admin_client, create_dag_run, task2): + assert task2.bash_command == 'echo {{ fullname("Apache", "Airflow") | hello }}' - assert task2.bash_command == 'echo {{ fullname("Apache", "Airflow") | hello }}' + with create_session() as session: + create_dag_run(execution_date=DEFAULT_DATE, session=session) - url = ( - f'rendered-templates?task_id=task2&dag_id=testdag&' - f'execution_date={quote_plus(str(DEFAULT_DATE))}' - ) + url = f'rendered-templates?task_id=task2&dag_id=testdag&execution_date={quote_plus(str(DEFAULT_DATE))}' - resp = admin_client.get(url, follow_redirects=True) + resp = admin_client.get(url, follow_redirects=True) - check_content_not_in_response("echo Hello Apache Airflow", resp) - check_content_in_response( - "Webserver does not have access to User-defined Macros or Filters when " - "Dag Serialization is enabled. Hence for the task that have not yet " - "started running, please use 'airflow tasks render' for " - "debugging the rendering of template_fields.
<br><br>OriginalError: no "
-            "filter named 'hello'",
-            resp,
-        )
+    check_content_not_in_response("echo Hello Apache Airflow", resp)
+    check_content_in_response(
+        "Webserver does not have access to User-defined Macros or Filters when "
+        "Dag Serialization is enabled. Hence for the task that have not yet "
+        "started running, please use 'airflow tasks render' for "
+        "debugging the rendering of template_fields.<br><br>OriginalError: no "
+        "filter named 'hello'",
+        resp,
+    )
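
A minimal sketch, not taken from the patch itself, of the DagRun-first pattern these test updates converge on, assuming a plain Airflow 2.2 test database; the example_dag / example_task names and the date are placeholders, not identifiers from this change set:

# Sketch only: illustrates creating the DagRun before touching TaskInstances.
from airflow.models import DAG, TaskInstance
from airflow.operators.dummy import DummyOperator
from airflow.utils import timezone
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType

EXECUTION_DATE = timezone.datetime(2021, 9, 1)

with DAG("example_dag", start_date=EXECUTION_DATE) as dag:
    task = DummyOperator(task_id="example_task")

# The DagRun row has to exist first; every TaskInstance row now references it.
dagrun = dag.create_dagrun(
    run_type=DagRunType.MANUAL,
    state=DagRunState.RUNNING,
    execution_date=EXECUTION_DATE,
)

# Pre-2.2 construction, TaskInstance(task=task, execution_date=EXECUTION_DATE),
# is gone. Either pick up the instance the run created and re-attach its task ...
ti = dagrun.get_task_instance(task.task_id)
ti.refresh_from_task(task)

# ... or build one explicitly against the run's run_id.
ti_explicit = TaskInstance(task=task, run_id=dagrun.run_id)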