This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a1fbdb3c1eb AIP-84 Migrate Trigger Dag Run endpoint to FastAPI  (#43875)
a1fbdb3c1eb is described below

commit a1fbdb3c1ebe1ef7964dfbf71d7833ff2451482d
Author: Kalyan R <[email protected]>
AuthorDate: Wed Nov 27 05:32:48 2024 +0530

    AIP-84 Migrate Trigger Dag Run endpoint to FastAPI  (#43875)
    
    * init
    
    * wip
    
    * remove logical_date
    
    * fix trigger dag_run
    
    * tests WIP
    
    * working tests
    
    * remove logical_date from post body
    
    * remove logical_date from tests
    
    * fix
    
    * include return type
    
    * fix conf
    
    * feedback
    
    * fix tests
    
    * Update tests/api_fastapi/core_api/routes/public/test_dag_run.py
    
    * feedback
---
 .../api_connexion/endpoints/dag_run_endpoint.py    |   1 +
 airflow/api_fastapi/core_api/datamodels/dag_run.py |  35 ++-
 .../api_fastapi/core_api/openapi/v1-generated.yaml |  91 ++++++
 .../api_fastapi/core_api/routes/public/dag_run.py  |  70 ++++-
 airflow/ui/openapi-gen/queries/common.ts           |   3 +
 airflow/ui/openapi-gen/queries/queries.ts          |  44 +++
 airflow/ui/openapi-gen/requests/schemas.gen.ts     |  58 ++++
 airflow/ui/openapi-gen/requests/services.gen.ts    |  33 +++
 airflow/ui/openapi-gen/requests/types.gen.ts       |  53 ++++
 .../core_api/routes/public/test_dag_run.py         | 327 ++++++++++++++++++++-
 10 files changed, 704 insertions(+), 11 deletions(-)

diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py
index dadfb3e4f42..985efc7fc89 100644
--- a/airflow/api_connexion/endpoints/dag_run_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py
@@ -305,6 +305,7 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
     return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries))
 
 
+@mark_fastapi_migration_done
 @security.requires_access_dag("POST", DagAccessEntity.RUN)
 @action_logging
 @provide_session
diff --git a/airflow/api_fastapi/core_api/datamodels/dag_run.py b/airflow/api_fastapi/core_api/datamodels/dag_run.py
index 55240d15e55..ab812627787 100644
--- a/airflow/api_fastapi/core_api/datamodels/dag_run.py
+++ b/airflow/api_fastapi/core_api/datamodels/dag_run.py
@@ -20,9 +20,11 @@ from __future__ import annotations
 from datetime import datetime
 from enum import Enum
 
-from pydantic import AwareDatetime, Field, NonNegativeInt
+from pydantic import AwareDatetime, Field, NonNegativeInt, computed_field, model_validator
 
 from airflow.api_fastapi.core_api.base import BaseModel
+from airflow.models import DagRun
+from airflow.utils import timezone
 from airflow.utils.state import DagRunState
 from airflow.utils.types import DagRunTriggeredByType, DagRunType
 
@@ -75,6 +77,37 @@ class DAGRunCollectionResponse(BaseModel):
     total_entries: int
 
 
+class TriggerDAGRunPostBody(BaseModel):
+    """Trigger DAG Run Serializer for POST body."""
+
+    dag_run_id: str | None = None
+    data_interval_start: AwareDatetime | None = None
+    data_interval_end: AwareDatetime | None = None
+
+    conf: dict = Field(default_factory=dict)
+    note: str | None = None
+
+    @model_validator(mode="after")
+    def check_data_intervals(cls, values):
+        if (values.data_interval_start is None) != (values.data_interval_end is None):
+            raise ValueError(
+                "Either both data_interval_start and data_interval_end must be provided or both must be None"
+            )
+        return values
+
+    @model_validator(mode="after")
+    def validate_dag_run_id(self):
+        if not self.dag_run_id:
+            self.dag_run_id = DagRun.generate_run_id(DagRunType.MANUAL, self.logical_date)
+        return self
+
+    # Mypy issue https://github.com/python/mypy/issues/1362
+    @computed_field  # type: ignore[misc]
+    @property
+    def logical_date(self) -> datetime:
+        return timezone.utcnow()
+
+
 class DAGRunsBatchBody(BaseModel):
     """List DAG Runs body for batch endpoint."""
 
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index c53f68a8438..46fd0382eb9 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -1828,6 +1828,67 @@ paths:
             application/json:
               schema:
                 $ref: '#/components/schemas/HTTPValidationError'
+    post:
+      tags:
+      - DagRun
+      summary: Trigger Dag Run
+      description: Trigger a DAG.
+      operationId: trigger_dag_run
+      parameters:
+      - name: dag_id
+        in: path
+        required: true
+        schema:
+          title: Dag Id
+      requestBody:
+        required: true
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/TriggerDAGRunPostBody'
+      responses:
+        '200':
+          description: Successful Response
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/DAGRunResponse'
+        '401':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Unauthorized
+        '403':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Forbidden
+        '400':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Bad Request
+        '404':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Not Found
+        '409':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Conflict
+        '422':
+          description: Validation Error
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPValidationError'
   /public/dags/{dag_id}/dagRuns/list:
     post:
       tags:
@@ -8672,6 +8733,36 @@ components:
       - microseconds
       title: TimeDelta
       description: TimeDelta can be used to interact with datetime.timedelta objects.
+    TriggerDAGRunPostBody:
+      properties:
+        dag_run_id:
+          anyOf:
+          - type: string
+          - type: 'null'
+          title: Dag Run Id
+        data_interval_start:
+          anyOf:
+          - type: string
+            format: date-time
+          - type: 'null'
+          title: Data Interval Start
+        data_interval_end:
+          anyOf:
+          - type: string
+            format: date-time
+          - type: 'null'
+          title: Data Interval End
+        conf:
+          type: object
+          title: Conf
+        note:
+          anyOf:
+          - type: string
+          - type: 'null'
+          title: Note
+      type: object
+      title: TriggerDAGRunPostBody
+      description: Trigger DAG Run Serializer for POST body.
     TriggerResponse:
       properties:
         id:
diff --git a/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow/api_fastapi/core_api/routes/public/dag_run.py
index 5b3eb0e66e5..1e95f75273c 100644
--- a/airflow/api_fastapi/core_api/routes/public/dag_run.py
+++ b/airflow/api_fastapi/core_api/routes/public/dag_run.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 from typing import Annotated, Literal, cast
 
+import pendulum
 from fastapi import Depends, HTTPException, Query, Request, status
 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -50,13 +51,19 @@ from airflow.api_fastapi.core_api.datamodels.dag_run import (
     DAGRunPatchStates,
     DAGRunResponse,
     DAGRunsBatchBody,
+    TriggerDAGRunPostBody,
 )
 from airflow.api_fastapi.core_api.datamodels.task_instances import (
     TaskInstanceCollectionResponse,
     TaskInstanceResponse,
 )
 from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
-from airflow.models import DAG, DagRun
+from airflow.exceptions import ParamValidationError
+from airflow.models import DAG, DagModel, DagRun
+from airflow.models.dag_version import DagVersion
+from airflow.timetables.base import DataInterval
+from airflow.utils.state import DagRunState
+from airflow.utils.types import DagRunTriggeredByType, DagRunType
 
 dag_run_router = AirflowRouter(tags=["DagRun"], prefix="/dags/{dag_id}/dagRuns")
 
@@ -303,6 +310,67 @@ def get_dag_runs(
     )
 
 
+@dag_run_router.post(
+    "",
+    responses=create_openapi_http_exception_doc(
+        [
+            status.HTTP_400_BAD_REQUEST,
+            status.HTTP_404_NOT_FOUND,
+            status.HTTP_409_CONFLICT,
+        ]
+    ),
+)
+def trigger_dag_run(
+    dag_id, body: TriggerDAGRunPostBody, request: Request, session: Annotated[Session, Depends(get_session)]
+) -> DAGRunResponse:
+    """Trigger a DAG."""
+    dm = session.scalar(select(DagModel).where(DagModel.is_active, DagModel.dag_id == dag_id).limit(1))
+    if not dm:
+        raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG with dag_id: '{dag_id}' not found")
+
+    if dm.has_import_errors:
+        raise HTTPException(
+            status.HTTP_400_BAD_REQUEST,
+            f"DAG with dag_id: '{dag_id}' has import errors and cannot be triggered",
+        )
+
+    run_id = body.dag_run_id
+    logical_date = pendulum.instance(body.logical_date)
+
+    try:
+        dag: DAG = request.app.state.dag_bag.get_dag(dag_id)
+
+        if body.data_interval_start and body.data_interval_end:
+            data_interval = DataInterval(
+                start=pendulum.instance(body.data_interval_start),
+                end=pendulum.instance(body.data_interval_end),
+            )
+        else:
+            data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date)
+        dag_version = DagVersion.get_latest_version(dag.dag_id)
+        dag_run = dag.create_dagrun(
+            run_type=DagRunType.MANUAL,
+            run_id=run_id,
+            logical_date=logical_date,
+            data_interval=data_interval,
+            state=DagRunState.QUEUED,
+            conf=body.conf,
+            external_trigger=True,
+            dag_version=dag_version,
+            session=session,
+            triggered_by=DagRunTriggeredByType.REST_API,
+        )
+        dag_run_note = body.note
+        if dag_run_note:
+            current_user_id = None  # refer to https://github.com/apache/airflow/issues/43534
+            dag_run.note = (dag_run_note, current_user_id)
+        return dag_run
+    except ValueError as e:
+        raise HTTPException(status.HTTP_400_BAD_REQUEST, str(e))
+    except ParamValidationError as e:
+        raise HTTPException(status.HTTP_400_BAD_REQUEST, str(e))
+
+
 @dag_run_router.post("/list", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]))
 def get_list_dag_runs_batch(
     dag_id: Literal["~"], body: DAGRunsBatchBody, session: Annotated[Session, Depends(get_session)]
diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts
index e4bc5f300d8..cb7f5c7a537 100644
--- a/airflow/ui/openapi-gen/queries/common.ts
+++ b/airflow/ui/openapi-gen/queries/common.ts
@@ -1573,6 +1573,9 @@ export type ConnectionServiceTestConnectionMutationResult = Awaited<
 export type DagRunServiceClearDagRunMutationResult = Awaited<
   ReturnType<typeof DagRunService.clearDagRun>
 >;
+export type DagRunServiceTriggerDagRunMutationResult = Awaited<
+  ReturnType<typeof DagRunService.triggerDagRun>
+>;
 export type DagRunServiceGetListDagRunsBatchMutationResult = Awaited<
   ReturnType<typeof DagRunService.getListDagRunsBatch>
 >;
diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts
index 6ff3e83ccce..d644beb7294 100644
--- a/airflow/ui/openapi-gen/queries/queries.ts
+++ b/airflow/ui/openapi-gen/queries/queries.ts
@@ -48,6 +48,7 @@ import {
   PoolPostBody,
   PoolPostBulkBody,
   TaskInstancesBatchBody,
+  TriggerDAGRunPostBody,
   VariableBody,
 } from "../requests/types.gen";
 import * as Common from "./common";
@@ -2726,6 +2727,49 @@ export const useDagRunServiceClearDagRun = <
       }) as unknown as Promise<TData>,
     ...options,
   });
+/**
+ * Trigger Dag Run
+ * Trigger a DAG.
+ * @param data The data for the request.
+ * @param data.dagId
+ * @param data.requestBody
+ * @returns DAGRunResponse Successful Response
+ * @throws ApiError
+ */
+export const useDagRunServiceTriggerDagRun = <
+  TData = Common.DagRunServiceTriggerDagRunMutationResult,
+  TError = unknown,
+  TContext = unknown,
+>(
+  options?: Omit<
+    UseMutationOptions<
+      TData,
+      TError,
+      {
+        dagId: unknown;
+        requestBody: TriggerDAGRunPostBody;
+      },
+      TContext
+    >,
+    "mutationFn"
+  >,
+) =>
+  useMutation<
+    TData,
+    TError,
+    {
+      dagId: unknown;
+      requestBody: TriggerDAGRunPostBody;
+    },
+    TContext
+  >({
+    mutationFn: ({ dagId, requestBody }) =>
+      DagRunService.triggerDagRun({
+        dagId,
+        requestBody,
+      }) as unknown as Promise<TData>,
+    ...options,
+  });
 /**
  * Get List Dag Runs Batch
  * Get a list of DAG Runs.
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts
index 8002b9d37f6..f657ba8d4e2 100644
--- a/airflow/ui/openapi-gen/requests/schemas.gen.ts
+++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -4852,6 +4852,64 @@ export const $TimeDelta = {
     "TimeDelta can be used to interact with datetime.timedelta objects.",
 } as const;
 
+export const $TriggerDAGRunPostBody = {
+  properties: {
+    dag_run_id: {
+      anyOf: [
+        {
+          type: "string",
+        },
+        {
+          type: "null",
+        },
+      ],
+      title: "Dag Run Id",
+    },
+    data_interval_start: {
+      anyOf: [
+        {
+          type: "string",
+          format: "date-time",
+        },
+        {
+          type: "null",
+        },
+      ],
+      title: "Data Interval Start",
+    },
+    data_interval_end: {
+      anyOf: [
+        {
+          type: "string",
+          format: "date-time",
+        },
+        {
+          type: "null",
+        },
+      ],
+      title: "Data Interval End",
+    },
+    conf: {
+      type: "object",
+      title: "Conf",
+    },
+    note: {
+      anyOf: [
+        {
+          type: "string",
+        },
+        {
+          type: "null",
+        },
+      ],
+      title: "Note",
+    },
+  },
+  type: "object",
+  title: "TriggerDAGRunPostBody",
+  description: "Trigger DAG Run Serializer for POST body.",
+} as const;
+
 export const $TriggerResponse = {
   properties: {
     id: {
diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts
index d8cb33bceec..4c8944447ce 100644
--- a/airflow/ui/openapi-gen/requests/services.gen.ts
+++ b/airflow/ui/openapi-gen/requests/services.gen.ts
@@ -70,6 +70,8 @@ import type {
   ClearDagRunResponse,
   GetDagRunsData,
   GetDagRunsResponse,
+  TriggerDagRunData,
+  TriggerDagRunResponse,
   GetListDagRunsBatchData,
   GetListDagRunsBatchResponse,
   GetDagSourceData,
@@ -1193,6 +1195,37 @@ export class DagRunService {
     });
   }
 
+  /**
+   * Trigger Dag Run
+   * Trigger a DAG.
+   * @param data The data for the request.
+   * @param data.dagId
+   * @param data.requestBody
+   * @returns DAGRunResponse Successful Response
+   * @throws ApiError
+   */
+  public static triggerDagRun(
+    data: TriggerDagRunData,
+  ): CancelablePromise<TriggerDagRunResponse> {
+    return __request(OpenAPI, {
+      method: "POST",
+      url: "/public/dags/{dag_id}/dagRuns",
+      path: {
+        dag_id: data.dagId,
+      },
+      body: data.requestBody,
+      mediaType: "application/json",
+      errors: {
+        400: "Bad Request",
+        401: "Unauthorized",
+        403: "Forbidden",
+        404: "Not Found",
+        409: "Conflict",
+        422: "Validation Error",
+      },
+    });
+  }
+
   /**
    * Get List Dag Runs Batch
    * Get a list of DAG Runs.
diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts
index bdcce0157dc..e15de90441b 100644
--- a/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -1126,6 +1126,19 @@ export type TimeDelta = {
   microseconds: number;
 };
 
+/**
+ * Trigger DAG Run Serializer for POST body.
+ */
+export type TriggerDAGRunPostBody = {
+  dag_run_id?: string | null;
+  data_interval_start?: string | null;
+  data_interval_end?: string | null;
+  conf?: {
+    [key: string]: unknown;
+  };
+  note?: string | null;
+};
+
 /**
  * Trigger serializer for responses.
  */
@@ -1494,6 +1507,13 @@ export type GetDagRunsData = {
 
 export type GetDagRunsResponse = DAGRunCollectionResponse;
 
+export type TriggerDagRunData = {
+  dagId: unknown;
+  requestBody: TriggerDAGRunPostBody;
+};
+
+export type TriggerDagRunResponse = DAGRunResponse;
+
 export type GetListDagRunsBatchData = {
   dagId: "~";
   requestBody: DAGRunsBatchBody;
@@ -2853,6 +2873,39 @@ export type $OpenApiTs = {
         422: HTTPValidationError;
       };
     };
+    post: {
+      req: TriggerDagRunData;
+      res: {
+        /**
+         * Successful Response
+         */
+        200: DAGRunResponse;
+        /**
+         * Bad Request
+         */
+        400: HTTPExceptionResponse;
+        /**
+         * Unauthorized
+         */
+        401: HTTPExceptionResponse;
+        /**
+         * Forbidden
+         */
+        403: HTTPExceptionResponse;
+        /**
+         * Not Found
+         */
+        404: HTTPExceptionResponse;
+        /**
+         * Conflict
+         */
+        409: HTTPExceptionResponse;
+        /**
+         * Validation Error
+         */
+        422: HTTPValidationError;
+      };
+    };
   };
   "/public/dags/{dag_id}/dagRuns/list": {
     post: {
diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py b/tests/api_fastapi/core_api/routes/public/test_dag_run.py
index 6e9f2b69eb2..d453c973c8d 100644
--- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py
+++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py
@@ -17,15 +17,19 @@
 
 from __future__ import annotations
 
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta
+from unittest import mock
 
 import pytest
+import time_machine
 from sqlalchemy import select
 
-from airflow.models import DagRun
+from airflow.models import DagModel, DagRun
 from airflow.models.asset import AssetEvent, AssetModel
+from airflow.models.param import Param
 from airflow.operators.empty import EmptyOperator
 from airflow.sdk.definitions.asset import Asset
+from airflow.utils import timezone
 from airflow.utils.session import provide_session
 from airflow.utils.state import DagRunState, State
 from airflow.utils.types import DagRunTriggeredByType, DagRunType
@@ -63,20 +67,24 @@ START_DATE2 = datetime(2024, 4, 15, 0, 0, tzinfo=timezone.utc)
 LOGICAL_DATE3 = datetime(2024, 5, 16, 0, 0, tzinfo=timezone.utc)
 LOGICAL_DATE4 = datetime(2024, 5, 25, 0, 0, tzinfo=timezone.utc)
 DAG1_RUN1_NOTE = "test_note"
+DAG2_PARAM = {"validated_number": Param(1, minimum=1, maximum=10)}
 
 DAG_RUNS_LIST = [DAG1_RUN1_ID, DAG1_RUN2_ID, DAG2_RUN1_ID, DAG2_RUN2_ID]
 
 
 @pytest.fixture(autouse=True)
 @provide_session
-def setup(dag_maker, session=None):
+def setup(request, dag_maker, session=None):
     clear_db_runs()
     clear_db_dags()
     clear_db_serialized_dags()
 
+    if "no_setup" in request.keywords:
+        return
+
     with dag_maker(
         DAG1_ID,
-        schedule="@daily",
+        schedule=None,
         start_date=START_DATE1,
     ):
         task1 = EmptyOperator(task_id="task_1")
@@ -102,11 +110,7 @@ def setup(dag_maker, session=None):
         logical_date=LOGICAL_DATE2,
     )
 
-    with dag_maker(
-        DAG2_ID,
-        schedule=None,
-        start_date=START_DATE2,
-    ):
+    with dag_maker(DAG2_ID, schedule=None, start_date=START_DATE2, params=DAG2_PARAM):
         EmptyOperator(task_id="task_2")
     dag_maker.create_dagrun(
         run_id=DAG2_RUN1_ID,
@@ -1048,3 +1052,308 @@ class TestClearDagRun:
         body = response.json()
         assert body["detail"][0]["msg"] == "Field required"
         assert body["detail"][0]["loc"][0] == "body"
+
+
+class TestTriggerDagRun:
+    def _dags_for_trigger_tests(self, session=None):
+        inactive_dag = DagModel(
+            dag_id="inactive",
+            fileloc="/tmp/dag_del_1.py",
+            timetable_summary="2 2 * * *",
+            is_active=False,
+            is_paused=True,
+            owners="test_owner,another_test_owner",
+            next_dagrun=datetime(2021, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
+        )
+
+        import_errors_dag = DagModel(
+            dag_id="import_errors",
+            fileloc="/tmp/dag_del_2.py",
+            timetable_summary="2 2 * * *",
+            is_active=True,
+            owners="test_owner,another_test_owner",
+            next_dagrun=datetime(2021, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
+        )
+        import_errors_dag.has_import_errors = True
+
+        session.add(inactive_dag)
+        session.add(import_errors_dag)
+        session.commit()
+
+    @time_machine.travel(timezone.utcnow(), tick=False)
+    @pytest.mark.parametrize(
+        "dag_run_id, note, data_interval_start, data_interval_end",
+        [
+            ("dag_run_5", "test-note", None, None),
+            (
+                "dag_run_6",
+                "test-note",
+                "2024-01-03T00:00:00+00:00",
+                "2024-01-04T05:00:00+00:00",
+            ),
+            (None, None, None, None),
+        ],
+    )
+    def test_should_respond_200(
+        self,
+        test_client,
+        dag_run_id,
+        note,
+        data_interval_start,
+        data_interval_end,
+    ):
+        fixed_now = timezone.utcnow().isoformat()
+
+        request_json = {"note": note}
+        if dag_run_id is not None:
+            request_json["dag_run_id"] = dag_run_id
+        if data_interval_start is not None:
+            request_json["data_interval_start"] = data_interval_start
+        if data_interval_end is not None:
+            request_json["data_interval_end"] = data_interval_end
+
+        response = test_client.post(
+            f"/public/dags/{DAG1_ID}/dagRuns",
+            json=request_json,
+        )
+        assert response.status_code == 200
+
+        if dag_run_id is None:
+            expected_dag_run_id = f"manual__{fixed_now}"
+        else:
+            expected_dag_run_id = dag_run_id
+
+        expected_data_interval_start = fixed_now.replace("+00:00", "Z")
+        expected_data_interval_end = fixed_now.replace("+00:00", "Z")
+        if data_interval_start is not None and data_interval_end is not None:
+            expected_data_interval_start = data_interval_start.replace("+00:00", "Z")
+            expected_data_interval_end = data_interval_end.replace("+00:00", "Z")
+
+        expected_response_json = {
+            "conf": {},
+            "dag_id": DAG1_ID,
+            "dag_run_id": expected_dag_run_id,
+            "end_date": None,
+            "logical_date": fixed_now.replace("+00:00", "Z"),
+            "external_trigger": True,
+            "start_date": None,
+            "state": "queued",
+            "data_interval_end": expected_data_interval_end,
+            "data_interval_start": expected_data_interval_start,
+            "queued_at": fixed_now.replace("+00:00", "Z"),
+            "last_scheduling_decision": None,
+            "run_type": "manual",
+            "note": note,
+            "triggered_by": "rest_api",
+        }
+
+        assert response.json() == expected_response_json
+
+    @pytest.mark.parametrize(
+        "post_body, expected_detail",
+        [
+            # Uncomment these 2 test cases once https://github.com/apache/airflow/pull/44306 is merged
+            # (
+            #     {"executiondate": "2020-11-10T08:25:56Z"},
+            #     {
+            #         "detail": [
+            #             {
+            #                 "input": "2020-11-10T08:25:56Z",
+            #                 "loc": ["body", "executiondate"],
+            #                 "msg": "Extra inputs are not permitted",
+            #                 "type": "extra_forbidden",
+            #             }
+            #         ]
+            #     },
+            # ),
+            # (
+            #     {"logical_date": "2020-11-10T08:25:56"},
+            #     {
+            #         "detail": [
+            #             {
+            #                 "input": "2020-11-10T08:25:56",
+            #                 "loc": ["body", "logical_date"],
+            #                 "msg": "Extra inputs are not permitted",
+            #                 "type": "extra_forbidden",
+            #             }
+            #         ]
+            #     },
+            # ),
+            (
+                {"data_interval_start": "2020-11-10T08:25:56"},
+                {
+                    "detail": [
+                        {
+                            "input": "2020-11-10T08:25:56",
+                            "loc": ["body", "data_interval_start"],
+                            "msg": "Input should have timezone info",
+                            "type": "timezone_aware",
+                        }
+                    ]
+                },
+            ),
+            (
+                {"data_interval_end": "2020-11-10T08:25:56"},
+                {
+                    "detail": [
+                        {
+                            "input": "2020-11-10T08:25:56",
+                            "loc": ["body", "data_interval_end"],
+                            "msg": "Input should have timezone info",
+                            "type": "timezone_aware",
+                        }
+                    ]
+                },
+            ),
+            (
+                {"dag_run_id": 20},
+                {
+                    "detail": [
+                        {
+                            "input": 20,
+                            "loc": ["body", "dag_run_id"],
+                            "msg": "Input should be a valid string",
+                            "type": "string_type",
+                        }
+                    ]
+                },
+            ),
+            (
+                {"note": 20},
+                {
+                    "detail": [
+                        {
+                            "input": 20,
+                            "loc": ["body", "note"],
+                            "msg": "Input should be a valid string",
+                            "type": "string_type",
+                        }
+                    ]
+                },
+            ),
+            (
+                {"conf": 20},
+                {
+                    "detail": [
+                        {
+                            "input": 20,
+                            "loc": ["body", "conf"],
+                            "msg": "Input should be a valid dictionary",
+                            "type": "dict_type",
+                        }
+                    ]
+                },
+            ),
+        ],
+    )
+    def test_invalid_data(self, test_client, post_body, expected_detail):
+        response = test_client.post(f"/public/dags/{DAG1_ID}/dagRuns", json=post_body)
+        assert response.status_code == 422
+        assert response.json() == expected_detail
+
+    @mock.patch("airflow.models.DAG.create_dagrun")
+    def test_dagrun_creation_exception_is_handled(self, mock_create_dagrun, test_client):
+        error_message = "Encountered Error"
+
+        mock_create_dagrun.side_effect = ValueError(error_message)
+
+        response = test_client.post(f"/public/dags/{DAG1_ID}/dagRuns", json={})
+        assert response.status_code == 400
+        assert response.json() == {"detail": error_message}
+
+    def test_should_respond_404_if_a_dag_is_inactive(self, test_client, session):
+        self._dags_for_trigger_tests(session)
+        response = test_client.post("/public/dags/inactive/dagRuns", json={})
+        assert response.status_code == 404
+        assert response.json()["detail"] == "DAG with dag_id: 'inactive' not found"
+
+    def test_should_respond_400_if_a_dag_has_import_errors(self, test_client, session):
+        self._dags_for_trigger_tests(session)
+        response = test_client.post("/public/dags/import_errors/dagRuns", json={})
+        assert response.status_code == 400
+        assert (
+            response.json()["detail"]
+            == "DAG with dag_id: 'import_errors' has import errors and cannot be triggered"
+        )
+
+    @time_machine.travel(timezone.utcnow(), tick=False)
+    def test_should_response_200_for_duplicate_logical_date(self, test_client):
+        RUN_ID_1 = "random_1"
+        RUN_ID_2 = "random_2"
+        now = timezone.utcnow().isoformat().replace("+00:00", "Z")
+        note = "duplicate logical date test"
+        response_1 = test_client.post(
+            f"/public/dags/{DAG1_ID}/dagRuns",
+            json={"dag_run_id": RUN_ID_1, "note": note},
+        )
+        response_2 = test_client.post(
+            f"/public/dags/{DAG1_ID}/dagRuns",
+            json={"dag_run_id": RUN_ID_2, "note": note},
+        )
+
+        assert response_1.status_code == response_2.status_code == 200
+        body1 = response_1.json()
+        body2 = response_2.json()
+
+        for each_run_id, each_body in [(RUN_ID_1, body1), (RUN_ID_2, body2)]:
+            assert each_body == {
+                "dag_run_id": each_run_id,
+                "dag_id": DAG1_ID,
+                "logical_date": now,
+                "queued_at": now,
+                "start_date": None,
+                "end_date": None,
+                "data_interval_start": now,
+                "data_interval_end": now,
+                "last_scheduling_decision": None,
+                "run_type": "manual",
+                "state": "queued",
+                "external_trigger": True,
+                "triggered_by": "rest_api",
+                "conf": {},
+                "note": note,
+            }
+
+    @pytest.mark.parametrize(
+        "data_interval_start, data_interval_end",
+        [
+            (
+                LOGICAL_DATE1.isoformat(),
+                None,
+            ),
+            (
+                None,
+                LOGICAL_DATE1.isoformat(),
+            ),
+        ],
+    )
+    def test_should_response_422_for_missing_start_date_or_end_date(
+        self, test_client, data_interval_start, data_interval_end
+    ):
+        response = test_client.post(
+            f"/public/dags/{DAG1_ID}/dagRuns",
+            json={"data_interval_start": data_interval_start, "data_interval_end": data_interval_end},
+        )
+        assert response.status_code == 422
+        assert (
+            response.json()["detail"][0]["msg"]
+            == "Value error, Either both data_interval_start and data_interval_end must be provided or both must be None"
+        )
+
+    def test_raises_validation_error_for_invalid_params(self, test_client):
+        response = test_client.post(
+            f"/public/dags/{DAG2_ID}/dagRuns",
+            json={"conf": {"validated_number": 5000}},
+        )
+        assert response.status_code == 400
+        assert "Invalid input for param validated_number" in response.json()["detail"]
+
+    def test_response_404(self, test_client):
+        response = test_client.post("/public/dags/randoms/dagRuns", json={})
+        assert response.status_code == 404
+        assert response.json()["detail"] == "DAG with dag_id: 'randoms' not found"
+
+    def test_response_409(self, test_client):
+        response = test_client.post(f"/public/dags/{DAG1_ID}/dagRuns", json={"dag_run_id": DAG1_RUN1_ID})
+        assert response.status_code == 409
+        assert response.json()["detail"] == "Unique constraint violation"

Reply via email to