From 825df6db600fd3ba81f708e60c12be8e2b515a2f Mon Sep 17 00:00:00 2001 From: Pierre Jeambrun Date: Tue, 28 Jan 2025 11:22:37 +0100 Subject: [PATCH] AIP-84 Add patch task_instance dry_run endpoint (#46018) --- .../core_api/datamodels/task_instances.py | 3 +- .../core_api/openapi/v1-generated.yaml | 194 +++++++- .../core_api/routes/public/task_instances.py | 145 ++++-- airflow/ui/openapi-gen/queries/common.ts | 6 + airflow/ui/openapi-gen/queries/queries.ts | 122 ++++- .../ui/openapi-gen/requests/schemas.gen.ts | 8 +- .../ui/openapi-gen/requests/services.gen.ts | 90 +++- airflow/ui/openapi-gen/requests/types.gen.ts | 95 +++- .../routes/public/test_task_instances.py | 443 ++++++++++++++---- 9 files changed, 967 insertions(+), 139 deletions(-) diff --git a/airflow/api_fastapi/core_api/datamodels/task_instances.py b/airflow/api_fastapi/core_api/datamodels/task_instances.py index 4754e67f2d3da..7cecb96ca42ce 100644 --- a/airflow/api_fastapi/core_api/datamodels/task_instances.py +++ b/airflow/api_fastapi/core_api/datamodels/task_instances.py @@ -198,8 +198,7 @@ def validate_model(cls, data: Any) -> Any: class PatchTaskInstanceBody(BaseModel): """Request body for Clear Task Instances endpoint.""" - dry_run: bool = True - new_state: str | None = None + new_state: TaskInstanceState | None = None note: Annotated[str, StringConstraints(max_length=1000)] | None = None include_upstream: bool = False include_downstream: bool = False diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index b40538e02444c..62df27b568ecd 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -4518,7 +4518,7 @@ paths: tags: - Task Instance summary: Patch Task Instance - description: Update the state of a task instance. + description: Update a task instance. operationId: patch_task_instance parameters: - name: dag_id @@ -5125,7 +5125,7 @@ paths: tags: - Task Instance summary: Patch Task Instance - description: Update the state of a task instance. + description: Update a task instance. operationId: patch_task_instance parameters: - name: dag_id @@ -5675,6 +5675,189 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}/dry_run: + patch: + tags: + - Task Instance + summary: Patch Task Instance Dry Run + description: Update a task instance dry_run mode. + operationId: patch_task_instance_dry_run + parameters: + - name: dag_id + in: path + required: true + schema: + type: string + title: Dag Id + - name: dag_run_id + in: path + required: true + schema: + type: string + title: Dag Run Id + - name: task_id + in: path + required: true + schema: + type: string + title: Task Id + - name: map_index + in: path + required: true + schema: + type: integer + title: Map Index + - name: update_mask + in: query + required: false + schema: + anyOf: + - type: array + items: + type: string + - type: 'null' + title: Update Mask + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/PatchTaskInstanceBody' + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/TaskInstanceCollectionResponse' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Bad Request + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '409': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Conflict + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/dry_run: + patch: + tags: + - Task Instance + summary: Patch Task Instance Dry Run + description: Update a task instance dry_run mode. + operationId: patch_task_instance_dry_run + parameters: + - name: dag_id + in: path + required: true + schema: + type: string + title: Dag Id + - name: dag_run_id + in: path + required: true + schema: + type: string + title: Dag Run Id + - name: task_id + in: path + required: true + schema: + type: string + title: Task Id + - name: map_index + in: query + required: false + schema: + type: integer + default: -1 + title: Map Index + - name: update_mask + in: query + required: false + schema: + anyOf: + - type: array + items: + type: string + - type: 'null' + title: Update Mask + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/PatchTaskInstanceBody' + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/TaskInstanceCollectionResponse' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Bad Request + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '409': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Conflict + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' /public/dags/{dag_id}/tasks: get: tags: @@ -8931,15 +9114,10 @@ components: description: Node serializer for responses. PatchTaskInstanceBody: properties: - dry_run: - type: boolean - title: Dry Run - default: true new_state: anyOf: - - type: string + - $ref: '#/components/schemas/TaskInstanceState' - type: 'null' - title: New State note: anyOf: - type: string diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py index 1ecc18a6513fb..c97e190fb2cd8 100644 --- a/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -62,6 +62,7 @@ from airflow.exceptions import TaskNotFound from airflow.jobs.scheduler_job_runner import DR from airflow.models import Base, DagRun +from airflow.models.dag import DAG from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances from airflow.models.taskinstancehistory import TaskInstanceHistory as TIH from airflow.ti_deps.dep_context import DepContext @@ -661,19 +662,7 @@ def post_clear_task_instances( ) -@task_instances_router.patch( - task_instances_prefix + "/{task_id}", - responses=create_openapi_http_exception_doc( - [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT], - ), -) -@task_instances_router.patch( - task_instances_prefix + "/{task_id}/{map_index}", - responses=create_openapi_http_exception_doc( - [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT], - ), -) -def patch_task_instance( +def _patch_ti_validate_request( dag_id: str, dag_run_id: str, task_id: str, @@ -682,8 +671,7 @@ def patch_task_instance( session: SessionDep, map_index: int = -1, update_mask: list[str] | None = Query(None), -) -> TaskInstanceResponse: - """Update the state of a task instance.""" +) -> tuple[DAG, TI, dict]: dag = request.app.state.dag_bag.get_dag(dag_id) if not dag: raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG {dag_id} not found") @@ -717,34 +705,123 @@ def patch_task_instance( fields_to_update = body.model_fields_set if update_mask: fields_to_update = fields_to_update.intersection(update_mask) - data = body.model_dump(include=fields_to_update, by_alias=True) else: try: PatchTaskInstanceBody.model_validate(body) except ValidationError as e: raise RequestValidationError(errors=e.errors()) - data = body.model_dump(by_alias=True) + + return dag, ti, body.model_dump(include=fields_to_update, by_alias=True) + + +@task_instances_router.patch( + task_instances_prefix + "/{task_id}/dry_run", + responses=create_openapi_http_exception_doc( + [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT], + ), +) +@task_instances_router.patch( + task_instances_prefix + "/{task_id}/{map_index}/dry_run", + responses=create_openapi_http_exception_doc( + [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT], + ), +) +def patch_task_instance_dry_run( + dag_id: str, + dag_run_id: str, + task_id: str, + request: Request, + body: PatchTaskInstanceBody, + session: SessionDep, + map_index: int = -1, + update_mask: list[str] | None = Query(None), +) -> TaskInstanceCollectionResponse: + """Update a task instance dry_run mode.""" + dag, ti, data = _patch_ti_validate_request( + dag_id, dag_run_id, task_id, request, body, session, map_index, update_mask + ) + + tis: list[TI] = [] + + if data.get("new_state"): + tis = dag.set_task_instance_state( + task_id=task_id, + run_id=dag_run_id, + map_indexes=[map_index], + state=data["new_state"], + upstream=body.include_upstream, + downstream=body.include_downstream, + future=body.include_future, + past=body.include_past, + commit=False, + session=session, + ) + + if not tis: + raise HTTPException( + status.HTTP_409_CONFLICT, f"Task id {task_id} is already in {data['new_state']} state" + ) + elif "note" in data: + tis = [ti] + + return TaskInstanceCollectionResponse( + task_instances=[ + TaskInstanceResponse.model_validate( + ti, + from_attributes=True, + ) + for ti in tis + ], + total_entries=len(tis), + ) + + +@task_instances_router.patch( + task_instances_prefix + "/{task_id}", + responses=create_openapi_http_exception_doc( + [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT], + ), +) +@task_instances_router.patch( + task_instances_prefix + "/{task_id}/{map_index}", + responses=create_openapi_http_exception_doc( + [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT], + ), +) +def patch_task_instance( + dag_id: str, + dag_run_id: str, + task_id: str, + request: Request, + body: PatchTaskInstanceBody, + session: SessionDep, + map_index: int = -1, + update_mask: list[str] | None = Query(None), +) -> TaskInstanceResponse: + """Update a task instance.""" + dag, ti, data = _patch_ti_validate_request( + dag_id, dag_run_id, task_id, request, body, session, map_index, update_mask + ) for key, _ in data.items(): if key == "new_state": - if not body.dry_run: - tis: list[TI] = dag.set_task_instance_state( - task_id=task_id, - run_id=dag_run_id, - map_indexes=[map_index], - state=body.new_state, - upstream=body.include_upstream, - downstream=body.include_downstream, - future=body.include_future, - past=body.include_past, - commit=True, - session=session, + tis: list[TI] = dag.set_task_instance_state( + task_id=task_id, + run_id=dag_run_id, + map_indexes=[map_index], + state=data["new_state"], + upstream=body.include_upstream, + downstream=body.include_downstream, + future=body.include_future, + past=body.include_past, + commit=True, + session=session, + ) + if not tis: + raise HTTPException( + status.HTTP_409_CONFLICT, f"Task id {task_id} is already in {data['new_state']} state" ) - if not tis: - raise HTTPException( - status.HTTP_409_CONFLICT, f"Task id {task_id} is already in {data['new_state']} state" - ) - ti = tis[0] if isinstance(tis, list) else tis + ti = tis[0] if isinstance(tis, list) else tis elif key == "note": if update_mask or body.note is not None: # @TODO: replace None passed for user_id with actual user id when diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index 73236c9b2d6c0..0e137620a4e6d 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -1690,6 +1690,12 @@ export type TaskInstanceServicePatchTaskInstanceMutationResult = Awaited< export type TaskInstanceServicePatchTaskInstance1MutationResult = Awaited< ReturnType >; +export type TaskInstanceServicePatchTaskInstanceDryRunMutationResult = Awaited< + ReturnType +>; +export type TaskInstanceServicePatchTaskInstanceDryRun1MutationResult = Awaited< + ReturnType +>; export type PoolServicePatchPoolMutationResult = Awaited>; export type PoolServiceBulkPoolsMutationResult = Awaited>; export type VariableServicePatchVariableMutationResult = Awaited< diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index 0fe9e425b703c..c4d201767e17d 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -3596,7 +3596,7 @@ export const useDagServicePatchDag = < }); /** * Patch Task Instance - * Update the state of a task instance. + * Update a task instance. * @param data The data for the request. * @param data.dagId * @param data.dagRunId @@ -3655,7 +3655,7 @@ export const useTaskInstanceServicePatchTaskInstance = < }); /** * Patch Task Instance - * Update the state of a task instance. + * Update a task instance. * @param data The data for the request. * @param data.dagId * @param data.dagRunId @@ -3712,6 +3712,124 @@ export const useTaskInstanceServicePatchTaskInstance1 = < }) as unknown as Promise, ...options, }); +/** + * Patch Task Instance Dry Run + * Update a task instance dry_run mode. + * @param data The data for the request. + * @param data.dagId + * @param data.dagRunId + * @param data.taskId + * @param data.mapIndex + * @param data.requestBody + * @param data.updateMask + * @returns TaskInstanceCollectionResponse Successful Response + * @throws ApiError + */ +export const useTaskInstanceServicePatchTaskInstanceDryRun = < + TData = Common.TaskInstanceServicePatchTaskInstanceDryRunMutationResult, + TError = unknown, + TContext = unknown, +>( + options?: Omit< + UseMutationOptions< + TData, + TError, + { + dagId: string; + dagRunId: string; + mapIndex: number; + requestBody: PatchTaskInstanceBody; + taskId: string; + updateMask?: string[]; + }, + TContext + >, + "mutationFn" + >, +) => + useMutation< + TData, + TError, + { + dagId: string; + dagRunId: string; + mapIndex: number; + requestBody: PatchTaskInstanceBody; + taskId: string; + updateMask?: string[]; + }, + TContext + >({ + mutationFn: ({ dagId, dagRunId, mapIndex, requestBody, taskId, updateMask }) => + TaskInstanceService.patchTaskInstanceDryRun({ + dagId, + dagRunId, + mapIndex, + requestBody, + taskId, + updateMask, + }) as unknown as Promise, + ...options, + }); +/** + * Patch Task Instance Dry Run + * Update a task instance dry_run mode. + * @param data The data for the request. + * @param data.dagId + * @param data.dagRunId + * @param data.taskId + * @param data.requestBody + * @param data.mapIndex + * @param data.updateMask + * @returns TaskInstanceCollectionResponse Successful Response + * @throws ApiError + */ +export const useTaskInstanceServicePatchTaskInstanceDryRun1 = < + TData = Common.TaskInstanceServicePatchTaskInstanceDryRun1MutationResult, + TError = unknown, + TContext = unknown, +>( + options?: Omit< + UseMutationOptions< + TData, + TError, + { + dagId: string; + dagRunId: string; + mapIndex?: number; + requestBody: PatchTaskInstanceBody; + taskId: string; + updateMask?: string[]; + }, + TContext + >, + "mutationFn" + >, +) => + useMutation< + TData, + TError, + { + dagId: string; + dagRunId: string; + mapIndex?: number; + requestBody: PatchTaskInstanceBody; + taskId: string; + updateMask?: string[]; + }, + TContext + >({ + mutationFn: ({ dagId, dagRunId, mapIndex, requestBody, taskId, updateMask }) => + TaskInstanceService.patchTaskInstanceDryRun1({ + dagId, + dagRunId, + mapIndex, + requestBody, + taskId, + updateMask, + }) as unknown as Promise, + ...options, + }); /** * Patch Pool * Update a Pool. diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index 3def7c5825cbe..170b977beb0ed 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -3913,21 +3913,15 @@ export const $NodeResponse = { export const $PatchTaskInstanceBody = { properties: { - dry_run: { - type: "boolean", - title: "Dry Run", - default: true, - }, new_state: { anyOf: [ { - type: "string", + $ref: "#/components/schemas/TaskInstanceState", }, { type: "null", }, ], - title: "New State", }, note: { anyOf: [ diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 91edabfc3dc72..bf6e528da11ed 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -145,6 +145,10 @@ import type { GetMappedTaskInstanceTryDetailsResponse, PostClearTaskInstancesData, PostClearTaskInstancesResponse, + PatchTaskInstanceDryRunData, + PatchTaskInstanceDryRunResponse, + PatchTaskInstanceDryRun1Data, + PatchTaskInstanceDryRun1Response, GetLogData, GetLogResponse, GetImportErrorData, @@ -1981,7 +1985,7 @@ export class TaskInstanceService { /** * Patch Task Instance - * Update the state of a task instance. + * Update a task instance. * @param data The data for the request. * @param data.dagId * @param data.dagRunId @@ -2249,7 +2253,7 @@ export class TaskInstanceService { /** * Patch Task Instance - * Update the state of a task instance. + * Update a task instance. * @param data The data for the request. * @param data.dagId * @param data.dagRunId @@ -2486,6 +2490,88 @@ export class TaskInstanceService { }); } + /** + * Patch Task Instance Dry Run + * Update a task instance dry_run mode. + * @param data The data for the request. + * @param data.dagId + * @param data.dagRunId + * @param data.taskId + * @param data.mapIndex + * @param data.requestBody + * @param data.updateMask + * @returns TaskInstanceCollectionResponse Successful Response + * @throws ApiError + */ + public static patchTaskInstanceDryRun( + data: PatchTaskInstanceDryRunData, + ): CancelablePromise { + return __request(OpenAPI, { + method: "PATCH", + url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}/dry_run", + path: { + dag_id: data.dagId, + dag_run_id: data.dagRunId, + task_id: data.taskId, + map_index: data.mapIndex, + }, + query: { + update_mask: data.updateMask, + }, + body: data.requestBody, + mediaType: "application/json", + errors: { + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", + 409: "Conflict", + 422: "Validation Error", + }, + }); + } + + /** + * Patch Task Instance Dry Run + * Update a task instance dry_run mode. + * @param data The data for the request. + * @param data.dagId + * @param data.dagRunId + * @param data.taskId + * @param data.requestBody + * @param data.mapIndex + * @param data.updateMask + * @returns TaskInstanceCollectionResponse Successful Response + * @throws ApiError + */ + public static patchTaskInstanceDryRun1( + data: PatchTaskInstanceDryRun1Data, + ): CancelablePromise { + return __request(OpenAPI, { + method: "PATCH", + url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/dry_run", + path: { + dag_id: data.dagId, + dag_run_id: data.dagRunId, + task_id: data.taskId, + }, + query: { + map_index: data.mapIndex, + update_mask: data.updateMask, + }, + body: data.requestBody, + mediaType: "application/json", + errors: { + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", + 409: "Conflict", + 422: "Validation Error", + }, + }); + } + /** * Get Log * Get logs for a specific task instance. diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index cf1cacdf1158a..6ff1a083ec6e8 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -1024,8 +1024,7 @@ export type type = * Request body for Clear Task Instances endpoint. */ export type PatchTaskInstanceBody = { - dry_run?: boolean; - new_state?: string | null; + new_state?: TaskInstanceState | null; note?: string | null; include_upstream?: boolean; include_downstream?: boolean; @@ -2159,6 +2158,28 @@ export type PostClearTaskInstancesData = { export type PostClearTaskInstancesResponse = TaskInstanceCollectionResponse; +export type PatchTaskInstanceDryRunData = { + dagId: string; + dagRunId: string; + mapIndex: number; + requestBody: PatchTaskInstanceBody; + taskId: string; + updateMask?: Array | null; +}; + +export type PatchTaskInstanceDryRunResponse = TaskInstanceCollectionResponse; + +export type PatchTaskInstanceDryRun1Data = { + dagId: string; + dagRunId: string; + mapIndex?: number; + requestBody: PatchTaskInstanceBody; + taskId: string; + updateMask?: Array | null; +}; + +export type PatchTaskInstanceDryRun1Response = TaskInstanceCollectionResponse; + export type GetLogData = { accept?: "application/json" | "text/plain" | "*/*"; dagId: string; @@ -4261,6 +4282,76 @@ export type $OpenApiTs = { }; }; }; + "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}/dry_run": { + patch: { + req: PatchTaskInstanceDryRunData; + res: { + /** + * Successful Response + */ + 200: TaskInstanceCollectionResponse; + /** + * Bad Request + */ + 400: HTTPExceptionResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Conflict + */ + 409: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; + "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/dry_run": { + patch: { + req: PatchTaskInstanceDryRun1Data; + res: { + /** + * Successful Response + */ + 200: TaskInstanceCollectionResponse; + /** + * Bad Request + */ + 400: HTTPExceptionResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Conflict + */ + 409: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/logs/{try_number}": { get: { req: GetLogData; diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py index bc2cef7c03e03..ebf718340d3d6 100644 --- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py @@ -2693,7 +2693,6 @@ def test_should_call_mocked_api(self, mock_set_ti_state, test_client, session): response = test_client.patch( self.ENDPOINT_URL, json={ - "dry_run": False, "new_state": self.NEW_STATE, }, ) @@ -2743,68 +2742,12 @@ def test_should_call_mocked_api(self, mock_set_ti_state, test_client, session): task_id=self.TASK_ID, ) - @mock.patch("airflow.models.dag.DAG.set_task_instance_state") - def test_should_not_call_mocked_api_for_dry_run(self, mock_set_task_instance_state, test_client, session): - self.create_task_instances(session) - - mock_set_task_instance_state.return_value = session.scalars( - select(TaskInstance).where( - TaskInstance.dag_id == self.DAG_ID, - TaskInstance.task_id == self.TASK_ID, - TaskInstance.run_id == self.RUN_ID, - TaskInstance.map_index == -1, - ) - ).one_or_none() - - response = test_client.patch( - self.ENDPOINT_URL, - json={ - "dry_run": True, - "new_state": self.NEW_STATE, - }, - ) - assert response.status_code == 200 - assert response.json() == { - "dag_id": self.DAG_ID, - "dag_run_id": self.RUN_ID, - "logical_date": "2020-01-01T00:00:00Z", - "task_id": self.TASK_ID, - "duration": 10000.0, - "end_date": "2020-01-03T00:00:00Z", - "executor": None, - "executor_config": "{}", - "hostname": "", - "id": mock.ANY, - "map_index": -1, - "max_tries": 0, - "note": "placeholder-note", - "operator": "PythonOperator", - "pid": 100, - "pool": "default_pool", - "pool_slots": 1, - "priority_weight": 9, - "queue": "default_queue", - "queued_when": None, - "start_date": "2020-01-02T00:00:00Z", - "state": "running", - "task_display_name": self.TASK_ID, - "try_number": 0, - "unixname": getuser(), - "rendered_fields": {}, - "rendered_map_index": None, - "trigger": None, - "triggerer_job": None, - } - - mock_set_task_instance_state.assert_not_called() - def test_should_update_task_instance_state(self, test_client, session): self.create_task_instances(session) test_client.patch( self.ENDPOINT_URL, json={ - "dry_run": False, "new_state": self.NEW_STATE, }, ) @@ -2813,20 +2756,6 @@ def test_should_update_task_instance_state(self, test_client, session): assert response2.status_code == 200 assert response2.json()["state"] == self.NEW_STATE - def test_should_update_task_instance_state_default_dry_run_to_true(self, test_client, session): - self.create_task_instances(session) - - test_client.patch( - self.ENDPOINT_URL, - json={ - "new_state": self.NEW_STATE, - }, - ) - - response2 = test_client.get(self.ENDPOINT_URL) - assert response2.status_code == 200 - assert response2.json()["state"] == "running" # no change in state - def test_should_update_mapped_task_instance_state(self, test_client, session): map_index = 1 tis = self.create_task_instances(session) @@ -2838,7 +2767,6 @@ def test_should_update_mapped_task_instance_state(self, test_client, session): response = test_client.patch( f"{self.ENDPOINT_URL}/{map_index}", json={ - "dry_run": False, "new_state": self.NEW_STATE, }, ) @@ -2858,7 +2786,6 @@ def test_should_update_mapped_task_instance_state(self, test_client, session): ), 404, { - "dry_run": True, "new_state": "failed", }, ] @@ -2877,7 +2804,6 @@ def test_should_200_for_unknown_fields(self, test_client, session): response = test_client.patch( self.ENDPOINT_URL, json={ - "dryrun": True, "new_state": self.NEW_STATE, }, ) @@ -2887,7 +2813,6 @@ def test_should_raise_404_for_non_existent_dag(self, test_client): response = test_client.patch( "/public/dags/non-existent-dag/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", json={ - "dry_run": False, "new_state": self.NEW_STATE, }, ) @@ -2898,7 +2823,6 @@ def test_should_raise_404_for_non_existent_task_in_dag(self, test_client): response = test_client.patch( "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/non_existent_task", json={ - "dry_run": False, "new_state": self.NEW_STATE, }, ) @@ -2911,7 +2835,6 @@ def test_should_raise_404_not_found_dag(self, test_client): response = test_client.patch( self.ENDPOINT_URL, json={ - "dry_run": True, "new_state": self.NEW_STATE, }, ) @@ -2921,7 +2844,6 @@ def test_should_raise_404_not_found_task(self, test_client): response = test_client.patch( self.ENDPOINT_URL, json={ - "dry_run": True, "new_state": self.NEW_STATE, }, ) @@ -2932,14 +2854,12 @@ def test_should_raise_404_not_found_task(self, test_client): [ ( { - "dry_run": True, "new_state": "failede", }, f"'failede' is not one of ['{State.SUCCESS}', '{State.FAILED}', '{State.SKIPPED}']", ), ( { - "dry_run": True, "new_state": "queued", }, f"'queued' is not one of ['{State.SUCCESS}', '{State.FAILED}', '{State.SKIPPED}']", @@ -3048,7 +2968,6 @@ def test_update_mask_should_call_mocked_api( self.ENDPOINT_URL, params={"update_mask": "new_state"}, json={ - "dry_run": False, "new_state": new_state, }, ) @@ -3222,7 +3141,367 @@ def test_should_raise_409_for_updating_same_task_instance_state( response = test_client.patch( self.ENDPOINT_URL, json={ - "dry_run": False, + "new_state": "success", + }, + ) + assert response.status_code == 409 + assert "Task id print_the_context is already in success state" in response.text + + +class TestPatchTaskInstanceDryRun(TestTaskInstanceEndpoint): + ENDPOINT_URL = ( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context" + ) + NEW_STATE = "failed" + DAG_ID = "example_python_operator" + TASK_ID = "print_the_context" + RUN_ID = "TEST_DAG_RUN_ID" + + @mock.patch("airflow.models.dag.DAG.set_task_instance_state") + def test_should_call_mocked_api(self, mock_set_ti_state, test_client, session): + self.create_task_instances(session) + + mock_set_ti_state.return_value = [ + session.scalars( + select(TaskInstance).where( + TaskInstance.dag_id == self.DAG_ID, + TaskInstance.task_id == self.TASK_ID, + TaskInstance.run_id == self.RUN_ID, + TaskInstance.map_index == -1, + ) + ).one_or_none() + ] + + response = test_client.patch( + f"{self.ENDPOINT_URL}/dry_run", + json={ + "new_state": self.NEW_STATE, + }, + ) + assert response.status_code == 200 + assert response.json() == { + "task_instances": [ + { + "dag_id": self.DAG_ID, + "dag_run_id": self.RUN_ID, + "logical_date": "2020-01-01T00:00:00Z", + "task_id": self.TASK_ID, + "duration": 10000.0, + "end_date": "2020-01-03T00:00:00Z", + "executor": None, + "executor_config": "{}", + "hostname": "", + "id": mock.ANY, + "map_index": -1, + "max_tries": 0, + "note": "placeholder-note", + "operator": "PythonOperator", + "pid": 100, + "pool": "default_pool", + "pool_slots": 1, + "priority_weight": 9, + "queue": "default_queue", + "queued_when": None, + "start_date": "2020-01-02T00:00:00Z", + "state": "running", + "task_display_name": self.TASK_ID, + "try_number": 0, + "unixname": getuser(), + "rendered_fields": {}, + "rendered_map_index": None, + "trigger": None, + "triggerer_job": None, + } + ], + "total_entries": 1, + } + + mock_set_ti_state.assert_called_once_with( + commit=False, + downstream=False, + upstream=False, + future=False, + map_indexes=[-1], + past=False, + run_id=self.RUN_ID, + session=mock.ANY, + state=self.NEW_STATE, + task_id=self.TASK_ID, + ) + + @pytest.mark.parametrize( + "payload", + [ + { + "new_state": "success", + }, + { + "note": "something", + }, + { + "new_state": "success", + "note": "something", + }, + ], + ) + def test_should_not_update(self, test_client, session, payload): + self.create_task_instances(session) + + task_before = test_client.get(self.ENDPOINT_URL).json() + + response = test_client.patch( + f"{self.ENDPOINT_URL}/dry_run", + json=payload, + ) + + assert response.status_code == 200 + assert [ti["task_id"] for ti in response.json()["task_instances"]] == ["print_the_context"] + + task_after = test_client.get(self.ENDPOINT_URL).json() + + assert task_before == task_after + + def test_should_not_update_mapped_task_instance(self, test_client, session): + map_index = 1 + tis = self.create_task_instances(session) + ti = TaskInstance(task=tis[0].task, run_id=tis[0].run_id, map_index=map_index) + ti.rendered_task_instance_fields = RTIF(ti, render_templates=False) + session.add(ti) + session.commit() + + task_before = test_client.get(f"{self.ENDPOINT_URL}/{map_index}").json() + + response = test_client.patch( + f"{self.ENDPOINT_URL}/{map_index}/dry_run", + json={ + "new_state": self.NEW_STATE, + }, + ) + + assert response.status_code == 200 + assert [ti["task_id"] for ti in response.json()["task_instances"]] == ["print_the_context"] + + task_after = test_client.get(f"{self.ENDPOINT_URL}/{map_index}").json() + + assert task_before == task_after + + @pytest.mark.parametrize( + "error, code, payload", + [ + [ + ( + "Task Instance not found for dag_id=example_python_operator" + ", run_id=TEST_DAG_RUN_ID, task_id=print_the_context" + ), + 404, + { + "new_state": "failed", + }, + ] + ], + ) + def test_should_handle_errors(self, error, code, payload, test_client, session): + response = test_client.patch( + f"{self.ENDPOINT_URL}/dry_run", + json=payload, + ) + assert response.status_code == code + assert response.json()["detail"] == error + + def test_should_200_for_unknown_fields(self, test_client, session): + self.create_task_instances(session) + response = test_client.patch( + f"{self.ENDPOINT_URL}/dry_run", + json={ + "new_state": self.NEW_STATE, + }, + ) + assert response.status_code == 200 + + def test_should_raise_404_for_non_existent_dag(self, test_client): + response = test_client.patch( + "/public/dags/non-existent-dag/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/dry_run", + json={ + "new_state": self.NEW_STATE, + }, + ) + assert response.status_code == 404 + assert response.json() == {"detail": "DAG non-existent-dag not found"} + + def test_should_raise_404_for_non_existent_task_in_dag(self, test_client): + response = test_client.patch( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/non_existent_task/dry_run", + json={ + "new_state": self.NEW_STATE, + }, + ) + assert response.status_code == 404 + assert response.json() == { + "detail": "Task 'non_existent_task' not found in DAG 'example_python_operator'" + } + + def test_should_raise_404_not_found_dag(self, test_client): + response = test_client.patch( + f"{self.ENDPOINT_URL}/dry_run", + json={ + "new_state": self.NEW_STATE, + }, + ) + assert response.status_code == 404 + + def test_should_raise_404_not_found_task(self, test_client): + response = test_client.patch( + f"{self.ENDPOINT_URL}/dry_run", + json={ + "new_state": self.NEW_STATE, + }, + ) + assert response.status_code == 404 + + @pytest.mark.parametrize( + "payload, expected", + [ + ( + { + "new_state": "failede", + }, + f"'failede' is not one of ['{State.SUCCESS}', '{State.FAILED}', '{State.SKIPPED}']", + ), + ( + { + "new_state": "queued", + }, + f"'queued' is not one of ['{State.SUCCESS}', '{State.FAILED}', '{State.SKIPPED}']", + ), + ], + ) + def test_should_raise_422_for_invalid_task_instance_state(self, payload, expected, test_client, session): + self.create_task_instances(session) + response = test_client.patch( + f"{self.ENDPOINT_URL}/dry_run", + json=payload, + ) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "value_error", + "loc": ["body", "new_state"], + "msg": f"Value error, {expected}", + "input": payload["new_state"], + "ctx": {"error": {}}, + } + ] + } + + @pytest.mark.parametrize( + "new_state,expected_status_code,expected_json,set_ti_state_call_count", + [ + ( + "failed", + 200, + { + "task_instances": [ + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID", + "logical_date": "2020-01-01T00:00:00Z", + "task_id": "print_the_context", + "duration": 10000.0, + "end_date": "2020-01-03T00:00:00Z", + "executor": None, + "executor_config": "{}", + "hostname": "", + "id": mock.ANY, + "map_index": -1, + "max_tries": 0, + "note": "placeholder-note", + "operator": "PythonOperator", + "pid": 100, + "pool": "default_pool", + "pool_slots": 1, + "priority_weight": 9, + "queue": "default_queue", + "queued_when": None, + "start_date": "2020-01-02T00:00:00Z", + "state": "running", + "task_display_name": "print_the_context", + "try_number": 0, + "unixname": getuser(), + "rendered_fields": {}, + "rendered_map_index": None, + "trigger": None, + "triggerer_job": None, + } + ], + "total_entries": 1, + }, + 1, + ), + ( + None, + 422, + { + "detail": [ + { + "type": "value_error", + "loc": ["body", "new_state"], + "msg": "Value error, 'new_state' should not be empty", + "input": None, + "ctx": {"error": {}}, + } + ] + }, + 0, + ), + ], + ) + @mock.patch("airflow.models.dag.DAG.set_task_instance_state") + def test_update_mask_should_call_mocked_api( + self, + mock_set_ti_state, + test_client, + session, + new_state, + expected_status_code, + expected_json, + set_ti_state_call_count, + ): + self.create_task_instances(session) + + mock_set_ti_state.return_value = [ + session.scalars( + select(TaskInstance).where( + TaskInstance.dag_id == self.DAG_ID, + TaskInstance.task_id == self.TASK_ID, + TaskInstance.run_id == self.RUN_ID, + TaskInstance.map_index == -1, + ) + ).one_or_none() + ] + + response = test_client.patch( + f"{self.ENDPOINT_URL}/dry_run", + params={"update_mask": "new_state"}, + json={ + "new_state": new_state, + }, + ) + assert response.status_code == expected_status_code + assert response.json() == expected_json + assert mock_set_ti_state.call_count == set_ti_state_call_count + + @mock.patch("airflow.models.dag.DAG.set_task_instance_state") + def test_should_raise_409_for_updating_same_task_instance_state( + self, mock_set_ti_state, test_client, session + ): + self.create_task_instances(session) + + mock_set_ti_state.return_value = None + + response = test_client.patch( + f"{self.ENDPOINT_URL}/dry_run", + json={ "new_state": "success", }, )