This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a1fbdb3c1eb AIP-84 Migrate Trigger Dag Run endpoint to FastAPI
(#43875)
a1fbdb3c1eb is described below
commit a1fbdb3c1ebe1ef7964dfbf71d7833ff2451482d
Author: Kalyan R <[email protected]>
AuthorDate: Wed Nov 27 05:32:48 2024 +0530
AIP-84 Migrate Trigger Dag Run endpoint to FastAPI (#43875)
* init
* wip
* remove logical_date
* fix trigger dag_run
* tests WIP
* working tests
* remove logical_date from post body
* remove logical_date from tests
* fix
* include return type
* fix conf
* feedback
* fix tests
* Update tests/api_fastapi/core_api/routes/public/test_dag_run.py
* feedback
---
.../api_connexion/endpoints/dag_run_endpoint.py | 1 +
airflow/api_fastapi/core_api/datamodels/dag_run.py | 35 ++-
.../api_fastapi/core_api/openapi/v1-generated.yaml | 91 ++++++
.../api_fastapi/core_api/routes/public/dag_run.py | 70 ++++-
airflow/ui/openapi-gen/queries/common.ts | 3 +
airflow/ui/openapi-gen/queries/queries.ts | 44 +++
airflow/ui/openapi-gen/requests/schemas.gen.ts | 58 ++++
airflow/ui/openapi-gen/requests/services.gen.ts | 33 +++
airflow/ui/openapi-gen/requests/types.gen.ts | 53 ++++
.../core_api/routes/public/test_dag_run.py | 327 ++++++++++++++++++++-
10 files changed, 704 insertions(+), 11 deletions(-)
diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py
b/airflow/api_connexion/endpoints/dag_run_endpoint.py
index dadfb3e4f42..985efc7fc89 100644
--- a/airflow/api_connexion/endpoints/dag_run_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py
@@ -305,6 +305,7 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
     return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries))
+@mark_fastapi_migration_done
@security.requires_access_dag("POST", DagAccessEntity.RUN)
@action_logging
@provide_session
diff --git a/airflow/api_fastapi/core_api/datamodels/dag_run.py
b/airflow/api_fastapi/core_api/datamodels/dag_run.py
index 55240d15e55..ab812627787 100644
--- a/airflow/api_fastapi/core_api/datamodels/dag_run.py
+++ b/airflow/api_fastapi/core_api/datamodels/dag_run.py
@@ -20,9 +20,11 @@ from __future__ import annotations
from datetime import datetime
from enum import Enum
-from pydantic import AwareDatetime, Field, NonNegativeInt
+from pydantic import AwareDatetime, Field, NonNegativeInt, computed_field, model_validator
from airflow.api_fastapi.core_api.base import BaseModel
+from airflow.models import DagRun
+from airflow.utils import timezone
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
@@ -75,6 +77,37 @@ class DAGRunCollectionResponse(BaseModel):
total_entries: int
+class TriggerDAGRunPostBody(BaseModel):
+ """Trigger DAG Run Serializer for POST body."""
+
+ dag_run_id: str | None = None
+ data_interval_start: AwareDatetime | None = None
+ data_interval_end: AwareDatetime | None = None
+
+ conf: dict = Field(default_factory=dict)
+ note: str | None = None
+
+ @model_validator(mode="after")
+ def check_data_intervals(cls, values):
+        if (values.data_interval_start is None) != (values.data_interval_end is None):
+            raise ValueError(
+                "Either both data_interval_start and data_interval_end must be provided or both must be None"
+            )
+        return values
+
+ @model_validator(mode="after")
+ def validate_dag_run_id(self):
+        if not self.dag_run_id:
+            self.dag_run_id = DagRun.generate_run_id(DagRunType.MANUAL, self.logical_date)
+ return self
+
+ # Mypy issue https://github.com/python/mypy/issues/1362
+ @computed_field # type: ignore[misc]
+ @property
+ def logical_date(self) -> datetime:
+ return timezone.utcnow()
+
+
class DAGRunsBatchBody(BaseModel):
"""List DAG Runs body for batch endpoint."""
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index c53f68a8438..46fd0382eb9 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -1828,6 +1828,67 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
+ post:
+ tags:
+ - DagRun
+ summary: Trigger Dag Run
+ description: Trigger a DAG.
+ operationId: trigger_dag_run
+ parameters:
+ - name: dag_id
+ in: path
+ required: true
+ schema:
+ title: Dag Id
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/TriggerDAGRunPostBody'
+ responses:
+ '200':
+ description: Successful Response
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/DAGRunResponse'
+ '401':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Unauthorized
+ '403':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Forbidden
+ '400':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Bad Request
+ '404':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Not Found
+ '409':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Conflict
+ '422':
+ description: Validation Error
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPValidationError'
/public/dags/{dag_id}/dagRuns/list:
post:
tags:
@@ -8672,6 +8733,36 @@ components:
- microseconds
title: TimeDelta
description: TimeDelta can be used to interact with datetime.timedelta
objects.
+ TriggerDAGRunPostBody:
+ properties:
+ dag_run_id:
+ anyOf:
+ - type: string
+ - type: 'null'
+ title: Dag Run Id
+ data_interval_start:
+ anyOf:
+ - type: string
+ format: date-time
+ - type: 'null'
+ title: Data Interval Start
+ data_interval_end:
+ anyOf:
+ - type: string
+ format: date-time
+ - type: 'null'
+ title: Data Interval End
+ conf:
+ type: object
+ title: Conf
+ note:
+ anyOf:
+ - type: string
+ - type: 'null'
+ title: Note
+ type: object
+ title: TriggerDAGRunPostBody
+ description: Trigger DAG Run Serializer for POST body.
TriggerResponse:
properties:
id:
diff --git a/airflow/api_fastapi/core_api/routes/public/dag_run.py
b/airflow/api_fastapi/core_api/routes/public/dag_run.py
index 5b3eb0e66e5..1e95f75273c 100644
--- a/airflow/api_fastapi/core_api/routes/public/dag_run.py
+++ b/airflow/api_fastapi/core_api/routes/public/dag_run.py
@@ -19,6 +19,7 @@ from __future__ import annotations
from typing import Annotated, Literal, cast
+import pendulum
from fastapi import Depends, HTTPException, Query, Request, status
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -50,13 +51,19 @@ from airflow.api_fastapi.core_api.datamodels.dag_run import
(
DAGRunPatchStates,
DAGRunResponse,
DAGRunsBatchBody,
+ TriggerDAGRunPostBody,
)
from airflow.api_fastapi.core_api.datamodels.task_instances import (
TaskInstanceCollectionResponse,
TaskInstanceResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import
create_openapi_http_exception_doc
-from airflow.models import DAG, DagRun
+from airflow.exceptions import ParamValidationError
+from airflow.models import DAG, DagModel, DagRun
+from airflow.models.dag_version import DagVersion
+from airflow.timetables.base import DataInterval
+from airflow.utils.state import DagRunState
+from airflow.utils.types import DagRunTriggeredByType, DagRunType
dag_run_router = AirflowRouter(tags=["DagRun"],
prefix="/dags/{dag_id}/dagRuns")
@@ -303,6 +310,67 @@ def get_dag_runs(
)
+@dag_run_router.post(
+ "",
+ responses=create_openapi_http_exception_doc(
+ [
+ status.HTTP_400_BAD_REQUEST,
+ status.HTTP_404_NOT_FOUND,
+ status.HTTP_409_CONFLICT,
+ ]
+ ),
+)
+def trigger_dag_run(
+    dag_id, body: TriggerDAGRunPostBody, request: Request, session: Annotated[Session, Depends(get_session)]
+) -> DAGRunResponse:
+ """Trigger a DAG."""
+    dm = session.scalar(select(DagModel).where(DagModel.is_active, DagModel.dag_id == dag_id).limit(1))
+ if not dm:
+        raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG with dag_id: '{dag_id}' not found")
+
+ if dm.has_import_errors:
+ raise HTTPException(
+ status.HTTP_400_BAD_REQUEST,
+ f"DAG with dag_id: '{dag_id}' has import errors and cannot be
triggered",
+ )
+
+ run_id = body.dag_run_id
+ logical_date = pendulum.instance(body.logical_date)
+
+ try:
+ dag: DAG = request.app.state.dag_bag.get_dag(dag_id)
+
+ if body.data_interval_start and body.data_interval_end:
+ data_interval = DataInterval(
+ start=pendulum.instance(body.data_interval_start),
+ end=pendulum.instance(body.data_interval_end),
+ )
+ else:
+            data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date)
+ dag_version = DagVersion.get_latest_version(dag.dag_id)
+ dag_run = dag.create_dagrun(
+ run_type=DagRunType.MANUAL,
+ run_id=run_id,
+ logical_date=logical_date,
+ data_interval=data_interval,
+ state=DagRunState.QUEUED,
+ conf=body.conf,
+ external_trigger=True,
+ dag_version=dag_version,
+ session=session,
+ triggered_by=DagRunTriggeredByType.REST_API,
+ )
+ dag_run_note = body.note
+ if dag_run_note:
+            current_user_id = None  # refer to https://github.com/apache/airflow/issues/43534
+ dag_run.note = (dag_run_note, current_user_id)
+ return dag_run
+ except ValueError as e:
+ raise HTTPException(status.HTTP_400_BAD_REQUEST, str(e))
+ except ParamValidationError as e:
+ raise HTTPException(status.HTTP_400_BAD_REQUEST, str(e))
+
+
@dag_run_router.post("/list",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]))
def get_list_dag_runs_batch(
dag_id: Literal["~"], body: DAGRunsBatchBody, session: Annotated[Session,
Depends(get_session)]
diff --git a/airflow/ui/openapi-gen/queries/common.ts
b/airflow/ui/openapi-gen/queries/common.ts
index e4bc5f300d8..cb7f5c7a537 100644
--- a/airflow/ui/openapi-gen/queries/common.ts
+++ b/airflow/ui/openapi-gen/queries/common.ts
@@ -1573,6 +1573,9 @@ export type ConnectionServiceTestConnectionMutationResult
= Awaited<
export type DagRunServiceClearDagRunMutationResult = Awaited<
ReturnType<typeof DagRunService.clearDagRun>
>;
+export type DagRunServiceTriggerDagRunMutationResult = Awaited<
+ ReturnType<typeof DagRunService.triggerDagRun>
+>;
export type DagRunServiceGetListDagRunsBatchMutationResult = Awaited<
ReturnType<typeof DagRunService.getListDagRunsBatch>
>;
diff --git a/airflow/ui/openapi-gen/queries/queries.ts
b/airflow/ui/openapi-gen/queries/queries.ts
index 6ff3e83ccce..d644beb7294 100644
--- a/airflow/ui/openapi-gen/queries/queries.ts
+++ b/airflow/ui/openapi-gen/queries/queries.ts
@@ -48,6 +48,7 @@ import {
PoolPostBody,
PoolPostBulkBody,
TaskInstancesBatchBody,
+ TriggerDAGRunPostBody,
VariableBody,
} from "../requests/types.gen";
import * as Common from "./common";
@@ -2726,6 +2727,49 @@ export const useDagRunServiceClearDagRun = <
}) as unknown as Promise<TData>,
...options,
});
+/**
+ * Trigger Dag Run
+ * Trigger a DAG.
+ * @param data The data for the request.
+ * @param data.dagId
+ * @param data.requestBody
+ * @returns DAGRunResponse Successful Response
+ * @throws ApiError
+ */
+export const useDagRunServiceTriggerDagRun = <
+ TData = Common.DagRunServiceTriggerDagRunMutationResult,
+ TError = unknown,
+ TContext = unknown,
+>(
+ options?: Omit<
+ UseMutationOptions<
+ TData,
+ TError,
+ {
+ dagId: unknown;
+ requestBody: TriggerDAGRunPostBody;
+ },
+ TContext
+ >,
+ "mutationFn"
+ >,
+) =>
+ useMutation<
+ TData,
+ TError,
+ {
+ dagId: unknown;
+ requestBody: TriggerDAGRunPostBody;
+ },
+ TContext
+ >({
+ mutationFn: ({ dagId, requestBody }) =>
+ DagRunService.triggerDagRun({
+ dagId,
+ requestBody,
+ }) as unknown as Promise<TData>,
+ ...options,
+ });
/**
* Get List Dag Runs Batch
* Get a list of DAG Runs.
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts
b/airflow/ui/openapi-gen/requests/schemas.gen.ts
index 8002b9d37f6..f657ba8d4e2 100644
--- a/airflow/ui/openapi-gen/requests/schemas.gen.ts
+++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -4852,6 +4852,64 @@ export const $TimeDelta = {
"TimeDelta can be used to interact with datetime.timedelta objects.",
} as const;
+export const $TriggerDAGRunPostBody = {
+ properties: {
+ dag_run_id: {
+ anyOf: [
+ {
+ type: "string",
+ },
+ {
+ type: "null",
+ },
+ ],
+ title: "Dag Run Id",
+ },
+ data_interval_start: {
+ anyOf: [
+ {
+ type: "string",
+ format: "date-time",
+ },
+ {
+ type: "null",
+ },
+ ],
+ title: "Data Interval Start",
+ },
+ data_interval_end: {
+ anyOf: [
+ {
+ type: "string",
+ format: "date-time",
+ },
+ {
+ type: "null",
+ },
+ ],
+ title: "Data Interval End",
+ },
+ conf: {
+ type: "object",
+ title: "Conf",
+ },
+ note: {
+ anyOf: [
+ {
+ type: "string",
+ },
+ {
+ type: "null",
+ },
+ ],
+ title: "Note",
+ },
+ },
+ type: "object",
+ title: "TriggerDAGRunPostBody",
+ description: "Trigger DAG Run Serializer for POST body.",
+} as const;
+
export const $TriggerResponse = {
properties: {
id: {
diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts
b/airflow/ui/openapi-gen/requests/services.gen.ts
index d8cb33bceec..4c8944447ce 100644
--- a/airflow/ui/openapi-gen/requests/services.gen.ts
+++ b/airflow/ui/openapi-gen/requests/services.gen.ts
@@ -70,6 +70,8 @@ import type {
ClearDagRunResponse,
GetDagRunsData,
GetDagRunsResponse,
+ TriggerDagRunData,
+ TriggerDagRunResponse,
GetListDagRunsBatchData,
GetListDagRunsBatchResponse,
GetDagSourceData,
@@ -1193,6 +1195,37 @@ export class DagRunService {
});
}
+ /**
+ * Trigger Dag Run
+ * Trigger a DAG.
+ * @param data The data for the request.
+ * @param data.dagId
+ * @param data.requestBody
+ * @returns DAGRunResponse Successful Response
+ * @throws ApiError
+ */
+ public static triggerDagRun(
+ data: TriggerDagRunData,
+ ): CancelablePromise<TriggerDagRunResponse> {
+ return __request(OpenAPI, {
+ method: "POST",
+ url: "/public/dags/{dag_id}/dagRuns",
+ path: {
+ dag_id: data.dagId,
+ },
+ body: data.requestBody,
+ mediaType: "application/json",
+ errors: {
+ 400: "Bad Request",
+ 401: "Unauthorized",
+ 403: "Forbidden",
+ 404: "Not Found",
+ 409: "Conflict",
+ 422: "Validation Error",
+ },
+ });
+ }
+
/**
* Get List Dag Runs Batch
* Get a list of DAG Runs.
diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts
b/airflow/ui/openapi-gen/requests/types.gen.ts
index bdcce0157dc..e15de90441b 100644
--- a/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -1126,6 +1126,19 @@ export type TimeDelta = {
microseconds: number;
};
+/**
+ * Trigger DAG Run Serializer for POST body.
+ */
+export type TriggerDAGRunPostBody = {
+ dag_run_id?: string | null;
+ data_interval_start?: string | null;
+ data_interval_end?: string | null;
+ conf?: {
+ [key: string]: unknown;
+ };
+ note?: string | null;
+};
+
/**
* Trigger serializer for responses.
*/
@@ -1494,6 +1507,13 @@ export type GetDagRunsData = {
export type GetDagRunsResponse = DAGRunCollectionResponse;
+export type TriggerDagRunData = {
+ dagId: unknown;
+ requestBody: TriggerDAGRunPostBody;
+};
+
+export type TriggerDagRunResponse = DAGRunResponse;
+
export type GetListDagRunsBatchData = {
dagId: "~";
requestBody: DAGRunsBatchBody;
@@ -2853,6 +2873,39 @@ export type $OpenApiTs = {
422: HTTPValidationError;
};
};
+ post: {
+ req: TriggerDagRunData;
+ res: {
+ /**
+ * Successful Response
+ */
+ 200: DAGRunResponse;
+ /**
+ * Bad Request
+ */
+ 400: HTTPExceptionResponse;
+ /**
+ * Unauthorized
+ */
+ 401: HTTPExceptionResponse;
+ /**
+ * Forbidden
+ */
+ 403: HTTPExceptionResponse;
+ /**
+ * Not Found
+ */
+ 404: HTTPExceptionResponse;
+ /**
+ * Conflict
+ */
+ 409: HTTPExceptionResponse;
+ /**
+ * Validation Error
+ */
+ 422: HTTPValidationError;
+ };
+ };
};
"/public/dags/{dag_id}/dagRuns/list": {
post: {
diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py
b/tests/api_fastapi/core_api/routes/public/test_dag_run.py
index 6e9f2b69eb2..d453c973c8d 100644
--- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py
+++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py
@@ -17,15 +17,19 @@
from __future__ import annotations
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta
+from unittest import mock
import pytest
+import time_machine
from sqlalchemy import select
-from airflow.models import DagRun
+from airflow.models import DagModel, DagRun
from airflow.models.asset import AssetEvent, AssetModel
+from airflow.models.param import Param
from airflow.operators.empty import EmptyOperator
from airflow.sdk.definitions.asset import Asset
+from airflow.utils import timezone
from airflow.utils.session import provide_session
from airflow.utils.state import DagRunState, State
from airflow.utils.types import DagRunTriggeredByType, DagRunType
@@ -63,20 +67,24 @@ START_DATE2 = datetime(2024, 4, 15, 0, 0,
tzinfo=timezone.utc)
LOGICAL_DATE3 = datetime(2024, 5, 16, 0, 0, tzinfo=timezone.utc)
LOGICAL_DATE4 = datetime(2024, 5, 25, 0, 0, tzinfo=timezone.utc)
DAG1_RUN1_NOTE = "test_note"
+DAG2_PARAM = {"validated_number": Param(1, minimum=1, maximum=10)}
DAG_RUNS_LIST = [DAG1_RUN1_ID, DAG1_RUN2_ID, DAG2_RUN1_ID, DAG2_RUN2_ID]
@pytest.fixture(autouse=True)
@provide_session
-def setup(dag_maker, session=None):
+def setup(request, dag_maker, session=None):
clear_db_runs()
clear_db_dags()
clear_db_serialized_dags()
+ if "no_setup" in request.keywords:
+ return
+
with dag_maker(
DAG1_ID,
- schedule="@daily",
+ schedule=None,
start_date=START_DATE1,
):
task1 = EmptyOperator(task_id="task_1")
@@ -102,11 +110,7 @@ def setup(dag_maker, session=None):
logical_date=LOGICAL_DATE2,
)
- with dag_maker(
- DAG2_ID,
- schedule=None,
- start_date=START_DATE2,
- ):
+ with dag_maker(DAG2_ID, schedule=None, start_date=START_DATE2,
params=DAG2_PARAM):
EmptyOperator(task_id="task_2")
dag_maker.create_dagrun(
run_id=DAG2_RUN1_ID,
@@ -1048,3 +1052,308 @@ class TestClearDagRun:
body = response.json()
assert body["detail"][0]["msg"] == "Field required"
assert body["detail"][0]["loc"][0] == "body"
+
+
+class TestTriggerDagRun:
+ def _dags_for_trigger_tests(self, session=None):
+ inactive_dag = DagModel(
+ dag_id="inactive",
+ fileloc="/tmp/dag_del_1.py",
+ timetable_summary="2 2 * * *",
+ is_active=False,
+ is_paused=True,
+ owners="test_owner,another_test_owner",
+ next_dagrun=datetime(2021, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
+ )
+
+ import_errors_dag = DagModel(
+ dag_id="import_errors",
+ fileloc="/tmp/dag_del_2.py",
+ timetable_summary="2 2 * * *",
+ is_active=True,
+ owners="test_owner,another_test_owner",
+ next_dagrun=datetime(2021, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
+ )
+ import_errors_dag.has_import_errors = True
+
+ session.add(inactive_dag)
+ session.add(import_errors_dag)
+ session.commit()
+
+ @time_machine.travel(timezone.utcnow(), tick=False)
+ @pytest.mark.parametrize(
+ "dag_run_id, note, data_interval_start, data_interval_end",
+ [
+ ("dag_run_5", "test-note", None, None),
+ (
+ "dag_run_6",
+ "test-note",
+ "2024-01-03T00:00:00+00:00",
+ "2024-01-04T05:00:00+00:00",
+ ),
+ (None, None, None, None),
+ ],
+ )
+ def test_should_respond_200(
+ self,
+ test_client,
+ dag_run_id,
+ note,
+ data_interval_start,
+ data_interval_end,
+ ):
+ fixed_now = timezone.utcnow().isoformat()
+
+ request_json = {"note": note}
+ if dag_run_id is not None:
+ request_json["dag_run_id"] = dag_run_id
+ if data_interval_start is not None:
+ request_json["data_interval_start"] = data_interval_start
+ if data_interval_end is not None:
+ request_json["data_interval_end"] = data_interval_end
+
+ response = test_client.post(
+ f"/public/dags/{DAG1_ID}/dagRuns",
+ json=request_json,
+ )
+ assert response.status_code == 200
+
+ if dag_run_id is None:
+ expected_dag_run_id = f"manual__{fixed_now}"
+ else:
+ expected_dag_run_id = dag_run_id
+
+ expected_data_interval_start = fixed_now.replace("+00:00", "Z")
+ expected_data_interval_end = fixed_now.replace("+00:00", "Z")
+ if data_interval_start is not None and data_interval_end is not None:
+ expected_data_interval_start =
data_interval_start.replace("+00:00", "Z")
+ expected_data_interval_end = data_interval_end.replace("+00:00",
"Z")
+
+ expected_response_json = {
+ "conf": {},
+ "dag_id": DAG1_ID,
+ "dag_run_id": expected_dag_run_id,
+ "end_date": None,
+ "logical_date": fixed_now.replace("+00:00", "Z"),
+ "external_trigger": True,
+ "start_date": None,
+ "state": "queued",
+ "data_interval_end": expected_data_interval_end,
+ "data_interval_start": expected_data_interval_start,
+ "queued_at": fixed_now.replace("+00:00", "Z"),
+ "last_scheduling_decision": None,
+ "run_type": "manual",
+ "note": note,
+ "triggered_by": "rest_api",
+ }
+
+ assert response.json() == expected_response_json
+
+ @pytest.mark.parametrize(
+ "post_body, expected_detail",
+ [
+ # Uncomment these 2 test cases once
https://github.com/apache/airflow/pull/44306 is merged
+ # (
+ # {"executiondate": "2020-11-10T08:25:56Z"},
+ # {
+ # "detail": [
+ # {
+ # "input": "2020-11-10T08:25:56Z",
+ # "loc": ["body", "executiondate"],
+ # "msg": "Extra inputs are not permitted",
+ # "type": "extra_forbidden",
+ # }
+ # ]
+ # },
+ # ),
+ # (
+ # {"logical_date": "2020-11-10T08:25:56"},
+ # {
+ # "detail": [
+ # {
+ # "input": "2020-11-10T08:25:56",
+ # "loc": ["body", "logical_date"],
+ # "msg": "Extra inputs are not permitted",
+ # "type": "extra_forbidden",
+ # }
+ # ]
+ # },
+ # ),
+ (
+ {"data_interval_start": "2020-11-10T08:25:56"},
+ {
+ "detail": [
+ {
+ "input": "2020-11-10T08:25:56",
+ "loc": ["body", "data_interval_start"],
+ "msg": "Input should have timezone info",
+ "type": "timezone_aware",
+ }
+ ]
+ },
+ ),
+ (
+ {"data_interval_end": "2020-11-10T08:25:56"},
+ {
+ "detail": [
+ {
+ "input": "2020-11-10T08:25:56",
+ "loc": ["body", "data_interval_end"],
+ "msg": "Input should have timezone info",
+ "type": "timezone_aware",
+ }
+ ]
+ },
+ ),
+ (
+ {"dag_run_id": 20},
+ {
+ "detail": [
+ {
+ "input": 20,
+ "loc": ["body", "dag_run_id"],
+ "msg": "Input should be a valid string",
+ "type": "string_type",
+ }
+ ]
+ },
+ ),
+ (
+ {"note": 20},
+ {
+ "detail": [
+ {
+ "input": 20,
+ "loc": ["body", "note"],
+ "msg": "Input should be a valid string",
+ "type": "string_type",
+ }
+ ]
+ },
+ ),
+ (
+ {"conf": 20},
+ {
+ "detail": [
+ {
+ "input": 20,
+ "loc": ["body", "conf"],
+ "msg": "Input should be a valid dictionary",
+ "type": "dict_type",
+ }
+ ]
+ },
+ ),
+ ],
+ )
+ def test_invalid_data(self, test_client, post_body, expected_detail):
+ response = test_client.post(f"/public/dags/{DAG1_ID}/dagRuns",
json=post_body)
+ assert response.status_code == 422
+ assert response.json() == expected_detail
+
+ @mock.patch("airflow.models.DAG.create_dagrun")
+ def test_dagrun_creation_exception_is_handled(self, mock_create_dagrun,
test_client):
+ error_message = "Encountered Error"
+
+ mock_create_dagrun.side_effect = ValueError(error_message)
+
+ response = test_client.post(f"/public/dags/{DAG1_ID}/dagRuns", json={})
+ assert response.status_code == 400
+ assert response.json() == {"detail": error_message}
+
+ def test_should_respond_404_if_a_dag_is_inactive(self, test_client,
session):
+ self._dags_for_trigger_tests(session)
+ response = test_client.post("/public/dags/inactive/dagRuns", json={})
+ assert response.status_code == 404
+ assert response.json()["detail"] == "DAG with dag_id: 'inactive' not
found"
+
+ def test_should_respond_400_if_a_dag_has_import_errors(self, test_client,
session):
+ self._dags_for_trigger_tests(session)
+ response = test_client.post("/public/dags/import_errors/dagRuns",
json={})
+ assert response.status_code == 400
+ assert (
+ response.json()["detail"]
+ == "DAG with dag_id: 'import_errors' has import errors and cannot
be triggered"
+ )
+
+ @time_machine.travel(timezone.utcnow(), tick=False)
+ def test_should_response_200_for_duplicate_logical_date(self, test_client):
+ RUN_ID_1 = "random_1"
+ RUN_ID_2 = "random_2"
+ now = timezone.utcnow().isoformat().replace("+00:00", "Z")
+ note = "duplicate logical date test"
+ response_1 = test_client.post(
+ f"/public/dags/{DAG1_ID}/dagRuns",
+ json={"dag_run_id": RUN_ID_1, "note": note},
+ )
+ response_2 = test_client.post(
+ f"/public/dags/{DAG1_ID}/dagRuns",
+ json={"dag_run_id": RUN_ID_2, "note": note},
+ )
+
+ assert response_1.status_code == response_2.status_code == 200
+ body1 = response_1.json()
+ body2 = response_2.json()
+
+ for each_run_id, each_body in [(RUN_ID_1, body1), (RUN_ID_2, body2)]:
+ assert each_body == {
+ "dag_run_id": each_run_id,
+ "dag_id": DAG1_ID,
+ "logical_date": now,
+ "queued_at": now,
+ "start_date": None,
+ "end_date": None,
+ "data_interval_start": now,
+ "data_interval_end": now,
+ "last_scheduling_decision": None,
+ "run_type": "manual",
+ "state": "queued",
+ "external_trigger": True,
+ "triggered_by": "rest_api",
+ "conf": {},
+ "note": note,
+ }
+
+ @pytest.mark.parametrize(
+ "data_interval_start, data_interval_end",
+ [
+ (
+ LOGICAL_DATE1.isoformat(),
+ None,
+ ),
+ (
+ None,
+ LOGICAL_DATE1.isoformat(),
+ ),
+ ],
+ )
+ def test_should_response_422_for_missing_start_date_or_end_date(
+ self, test_client, data_interval_start, data_interval_end
+ ):
+ response = test_client.post(
+ f"/public/dags/{DAG1_ID}/dagRuns",
+ json={"data_interval_start": data_interval_start,
"data_interval_end": data_interval_end},
+ )
+ assert response.status_code == 422
+ assert (
+ response.json()["detail"][0]["msg"]
+ == "Value error, Either both data_interval_start and
data_interval_end must be provided or both must be None"
+ )
+
+ def test_raises_validation_error_for_invalid_params(self, test_client):
+ response = test_client.post(
+ f"/public/dags/{DAG2_ID}/dagRuns",
+ json={"conf": {"validated_number": 5000}},
+ )
+ assert response.status_code == 400
+ assert "Invalid input for param validated_number" in
response.json()["detail"]
+
+ def test_response_404(self, test_client):
+ response = test_client.post("/public/dags/randoms/dagRuns", json={})
+ assert response.status_code == 404
+ assert response.json()["detail"] == "DAG with dag_id: 'randoms' not
found"
+
+ def test_response_409(self, test_client):
+ response = test_client.post(f"/public/dags/{DAG1_ID}/dagRuns",
json={"dag_run_id": DAG1_RUN1_ID})
+ assert response.status_code == 409
+ assert response.json()["detail"] == "Unique constraint violation"