pierrejeambrun commented on code in PR #43875:
URL: https://github.com/apache/airflow/pull/43875#discussion_r1856830840


##########
airflow/api_fastapi/core_api/routes/public/dag_run.py:
##########
@@ -296,3 +302,71 @@ def get_dag_runs(
         dag_runs=dag_runs,
         total_entries=total_entries,
     )
+
+
+@dag_run_router.post(
+    "",
+    responses=create_openapi_http_exception_doc(
+        [
+            status.HTTP_400_BAD_REQUEST,
+            status.HTTP_404_NOT_FOUND,
+            status.HTTP_409_CONFLICT,
+        ]
+    ),
+)
+def trigger_dag_run(
+    dag_id, body: TriggerDAGRunPostBody, request: Request, session: 
Annotated[Session, Depends(get_session)]
+) -> DAGRunResponse:
+    """Trigger a DAG."""
+    dm = session.scalar(select(DagModel).where(DagModel.is_active, 
DagModel.dag_id == dag_id).limit(1))
+    if not dm:
+        raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG with dag_id: 
'{dag_id}' not found")
+
+    if dm.has_import_errors:
+        raise HTTPException(
+            status.HTTP_400_BAD_REQUEST,
+            f"DAG with dag_id: '{dag_id}' has import errors and cannot be 
triggered",
+        )
+
+    run_id = body.dag_run_id
+    logical_date = pendulum.instance(body.logical_date)
+    dagrun_instance = session.scalar(
+        select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == 
run_id).limit(1)
+    )
+
+    if not dagrun_instance:
+        try:
+            dag: DAG = request.app.state.dag_bag.get_dag(dag_id)
+
+            if body.data_interval_start and body.data_interval_end:
+                data_interval = DataInterval(
+                    start=pendulum.instance(body.data_interval_start),
+                    end=pendulum.instance(body.data_interval_end),
+                )
+            else:
+                data_interval = 
dag.timetable.infer_manual_data_interval(run_after=logical_date)
+            dag_version = DagVersion.get_latest_version(dag.dag_id)
+            dag_run = dag.create_dagrun(
+                run_type=DagRunType.MANUAL,
+                run_id=run_id,
+                logical_date=logical_date,
+                data_interval=data_interval,
+                state=DagRunState.QUEUED,
+                conf=body.conf,
+                external_trigger=True,
+                dag_version=dag_version,
+                session=session,
+                triggered_by=DagRunTriggeredByType.REST_API,
+            )
+            dag_run_note = body.note
+            if dag_run_note:
+                current_user_id = None  # refer to 
https://github.com/apache/airflow/issues/43534
+                dag_run.note = (dag_run_note, current_user_id)
+            return dag_run
+        except ValueError as e:
+            raise HTTPException(status.HTTP_400_BAD_REQUEST, str(e))
+
+    raise HTTPException(
+        status.HTTP_409_CONFLICT,
+        f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{body.dag_run_id}' 
already exists",
+    )

Review Comment:
   DB duplicate entry exceptions are already handled by the application.
   
   You can just always try to execute this code; if the database fails with a 
duplicate entry, a nice 409 error will automatically be returned.



##########
airflow/api_fastapi/core_api/datamodels/dag_run.py:
##########
@@ -73,3 +76,37 @@ class DAGRunCollectionResponse(BaseModel):
 
     dag_runs: list[DAGRunResponse]
     total_entries: int
+
+
+class TriggerDAGRunPostBody(BaseModel):
+    """Trigger DAG Run Serializer for POST body."""
+
+    dag_run_id: str | None = None
+    data_interval_start: AwareDatetime | None = None
+    data_interval_end: AwareDatetime | None = None
+
+    conf: dict = Field(default_factory=dict)
+    note: str | None = None
+
+    model_config = {"extra": "forbid"}
+
+    @model_validator(mode="after")
+    def check_data_intervals(cls, values):
+        if (values.data_interval_start is None) != (values.data_interval_end 
is None):
+            raise HTTPException(
+                status.HTTP_422_UNPROCESSABLE_ENTITY,
+                "Either both data_interval_start and data_interval_end must be 
provided or both must be None",
+            )

Review Comment:
   Pydantic takes care of formatting 422 errors correctly. Those have a 
specific structure of their own.
   
   You can just raise `ValueError`; pydantic catches those.



##########
airflow/api_fastapi/core_api/routes/public/dag_run.py:
##########
@@ -296,3 +302,71 @@ def get_dag_runs(
         dag_runs=dag_runs,
         total_entries=total_entries,
     )
+
+
+@dag_run_router.post(
+    "",
+    responses=create_openapi_http_exception_doc(
+        [
+            status.HTTP_400_BAD_REQUEST,
+            status.HTTP_404_NOT_FOUND,
+            status.HTTP_409_CONFLICT,

Review Comment:
   This 409 needs to stay, though, so that it is reflected in the documentation.



##########
tests/api_fastapi/core_api/routes/public/test_dag_run.py:
##########
@@ -652,3 +658,308 @@ def test_clear_dag_run_unprocessable_entity(self, 
test_client):
         body = response.json()
         assert body["detail"][0]["msg"] == "Field required"
         assert body["detail"][0]["loc"][0] == "body"
+
+
+# @pytest.mark.no_setup

Review Comment:
   To be removed.



##########
tests/api_fastapi/core_api/routes/public/test_dag_run.py:
##########
@@ -652,3 +658,308 @@ def test_clear_dag_run_unprocessable_entity(self, 
test_client):
         body = response.json()
         assert body["detail"][0]["msg"] == "Field required"
         assert body["detail"][0]["loc"][0] == "body"
+
+
+# @pytest.mark.no_setup
+class TestTriggerDagRun:
+    def _dags_for_trigger_tests(self, session=None):
+        inactive_dag = DagModel(
+            dag_id="inactive",
+            fileloc="/tmp/dag_del_1.py",
+            timetable_summary="2 2 * * *",
+            is_active=False,
+            is_paused=True,
+            owners="test_owner,another_test_owner",
+            next_dagrun=datetime(2021, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
+        )
+
+        import_errors_dag = DagModel(
+            dag_id="import_errors",
+            fileloc="/tmp/dag_del_2.py",
+            timetable_summary="2 2 * * *",
+            is_active=True,
+            owners="test_owner,another_test_owner",
+            next_dagrun=datetime(2021, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
+        )
+        import_errors_dag.has_import_errors = True
+
+        session.add(inactive_dag)
+        session.add(import_errors_dag)
+        session.commit()
+
+    @time_machine.travel(timezone.utcnow(), tick=False)
+    @pytest.mark.parametrize(
+        "dag_run_id, note, data_interval_start, data_interval_end",
+        [
+            ("dag_run_5", "test-note", None, None),
+            (
+                "dag_run_6",
+                "test-note",
+                "2024-01-03T00:00:00+00:00",
+                "2024-01-04T05:00:00+00:00",
+            ),
+            (None, None, None, None),
+        ],
+    )
+    def test_should_respond_200(
+        self,
+        test_client,
+        dag_run_id,
+        note,
+        data_interval_start,
+        data_interval_end,
+    ):
+        fixed_now = timezone.utcnow().isoformat()
+
+        request_json = {"note": note}
+        if dag_run_id is not None:
+            request_json["dag_run_id"] = dag_run_id
+        if data_interval_start is not None:
+            request_json["data_interval_start"] = data_interval_start
+        if data_interval_end is not None:
+            request_json["data_interval_end"] = data_interval_end
+
+        response = test_client.post(
+            f"/public/dags/{DAG1_ID}/dagRuns",
+            json={
+                "dag_run_id": dag_run_id,
+                "note": note,
+                "data_interval_start": data_interval_start,
+                "data_interval_end": data_interval_end,
+            },

Review Comment:
   Maybe ?
   ```suggestion
               json=request_json
   ```



##########
airflow/api_fastapi/core_api/datamodels/dag_run.py:
##########
@@ -73,3 +76,37 @@ class DAGRunCollectionResponse(BaseModel):
 
     dag_runs: list[DAGRunResponse]
     total_entries: int
+
+
+class TriggerDAGRunPostBody(BaseModel):
+    """Trigger DAG Run Serializer for POST body."""
+
+    dag_run_id: str | None = None
+    data_interval_start: AwareDatetime | None = None
+    data_interval_end: AwareDatetime | None = None
+
+    conf: dict = Field(default_factory=dict)
+    note: str | None = None
+
+    model_config = {"extra": "forbid"}

Review Comment:
   Nice,
   
   Maybe remove it in this PR so there is no special case. #44306 will 
introduce it for all or none.



##########
tests/api_fastapi/core_api/routes/public/test_dag_run.py:
##########
@@ -652,3 +658,308 @@ def test_clear_dag_run_unprocessable_entity(self, 
test_client):
         body = response.json()
         assert body["detail"][0]["msg"] == "Field required"
         assert body["detail"][0]["loc"][0] == "body"
+
+
+# @pytest.mark.no_setup
+class TestTriggerDagRun:
+    def _dags_for_trigger_tests(self, session=None):
+        inactive_dag = DagModel(
+            dag_id="inactive",
+            fileloc="/tmp/dag_del_1.py",
+            timetable_summary="2 2 * * *",
+            is_active=False,
+            is_paused=True,
+            owners="test_owner,another_test_owner",
+            next_dagrun=datetime(2021, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
+        )
+
+        import_errors_dag = DagModel(
+            dag_id="import_errors",
+            fileloc="/tmp/dag_del_2.py",
+            timetable_summary="2 2 * * *",
+            is_active=True,
+            owners="test_owner,another_test_owner",
+            next_dagrun=datetime(2021, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
+        )
+        import_errors_dag.has_import_errors = True
+
+        session.add(inactive_dag)
+        session.add(import_errors_dag)
+        session.commit()
+
+    @time_machine.travel(timezone.utcnow(), tick=False)
+    @pytest.mark.parametrize(
+        "dag_run_id, note, data_interval_start, data_interval_end",
+        [
+            ("dag_run_5", "test-note", None, None),
+            (
+                "dag_run_6",
+                "test-note",
+                "2024-01-03T00:00:00+00:00",
+                "2024-01-04T05:00:00+00:00",
+            ),
+            (None, None, None, None),
+        ],
+    )
+    def test_should_respond_200(
+        self,
+        test_client,
+        dag_run_id,
+        note,
+        data_interval_start,
+        data_interval_end,
+    ):
+        fixed_now = timezone.utcnow().isoformat()
+
+        request_json = {"note": note}
+        if dag_run_id is not None:
+            request_json["dag_run_id"] = dag_run_id
+        if data_interval_start is not None:
+            request_json["data_interval_start"] = data_interval_start
+        if data_interval_end is not None:
+            request_json["data_interval_end"] = data_interval_end
+
+        response = test_client.post(
+            f"/public/dags/{DAG1_ID}/dagRuns",
+            json={
+                "dag_run_id": dag_run_id,
+                "note": note,
+                "data_interval_start": data_interval_start,
+                "data_interval_end": data_interval_end,
+            },
+        )
+        assert response.status_code == 200
+
+        if dag_run_id is None:
+            expected_dag_run_id = f"manual__{fixed_now}"
+        else:
+            expected_dag_run_id = dag_run_id
+
+        expected_data_interval_start = fixed_now.replace("+00:00", "Z")
+        expected_data_interval_end = fixed_now.replace("+00:00", "Z")
+        if data_interval_start is not None and data_interval_end is not None:
+            expected_data_interval_start = 
data_interval_start.replace("+00:00", "Z")
+            expected_data_interval_end = data_interval_end.replace("+00:00", 
"Z")
+
+        expected_response_json = {
+            "conf": {},
+            "dag_id": DAG1_ID,
+            "run_id": expected_dag_run_id,
+            "end_date": None,
+            "logical_date": fixed_now.replace("+00:00", "Z"),
+            "external_trigger": True,
+            "start_date": None,
+            "state": "queued",
+            "data_interval_end": expected_data_interval_end,
+            "data_interval_start": expected_data_interval_start,
+            "queued_at": fixed_now.replace("+00:00", "Z"),
+            "last_scheduling_decision": None,
+            "run_type": "manual",
+            "note": note,
+            "triggered_by": "rest_api",
+        }
+
+        assert response.json() == expected_response_json
+
+    @pytest.mark.parametrize(
+        "post_body, expected_detail",
+        [
+            (
+                {"executiondate": "2020-11-10T08:25:56Z"},
+                {
+                    "detail": [
+                        {
+                            "input": "2020-11-10T08:25:56Z",
+                            "loc": ["body", "executiondate"],
+                            "msg": "Extra inputs are not permitted",
+                            "type": "extra_forbidden",
+                        }
+                    ]
+                },
+            ),
+            (
+                {"logical_date": "2020-11-10T08:25:56"},
+                {
+                    "detail": [
+                        {
+                            "input": "2020-11-10T08:25:56",
+                            "loc": ["body", "logical_date"],
+                            "msg": "Extra inputs are not permitted",
+                            "type": "extra_forbidden",
+                        }
+                    ]
+                },
+            ),
+            (
+                {"data_interval_start": "2020-11-10T08:25:56"},
+                {
+                    "detail": [
+                        {
+                            "input": "2020-11-10T08:25:56",
+                            "loc": ["body", "data_interval_start"],
+                            "msg": "Input should have timezone info",
+                            "type": "timezone_aware",
+                        }
+                    ]
+                },
+            ),
+            (
+                {"data_interval_end": "2020-11-10T08:25:56"},
+                {
+                    "detail": [
+                        {
+                            "input": "2020-11-10T08:25:56",
+                            "loc": ["body", "data_interval_end"],
+                            "msg": "Input should have timezone info",
+                            "type": "timezone_aware",
+                        }
+                    ]
+                },
+            ),
+            (
+                {"dag_run_id": 20},
+                {
+                    "detail": [
+                        {
+                            "input": 20,
+                            "loc": ["body", "dag_run_id"],
+                            "msg": "Input should be a valid string",
+                            "type": "string_type",
+                        }
+                    ]
+                },
+            ),
+            (
+                {"note": 20},
+                {
+                    "detail": [
+                        {
+                            "input": 20,
+                            "loc": ["body", "note"],
+                            "msg": "Input should be a valid string",
+                            "type": "string_type",
+                        }
+                    ]
+                },
+            ),
+            (
+                {"conf": 20},
+                {
+                    "detail": [
+                        {
+                            "input": 20,
+                            "loc": ["body", "conf"],
+                            "msg": "Input should be a valid dictionary",
+                            "type": "dict_type",
+                        }
+                    ]
+                },
+            ),
+        ],
+    )
+    def test_invalid_data(self, test_client, post_body, expected_detail):
+        response = test_client.post(f"/public/dags/{DAG1_ID}/dagRuns", 
json=post_body)
+        assert response.status_code == 422
+        assert response.json() == expected_detail
+
+    @mock.patch("airflow.models.DAG.create_dagrun")
+    def test_dagrun_creation_exception_is_handled(self, mock_create_dagrun, 
test_client):
+        error_message = "Encountered Error"
+
+        mock_create_dagrun.side_effect = ValueError(error_message)
+
+        response = test_client.post(f"/public/dags/{DAG1_ID}/dagRuns", json={})
+        assert response.status_code == 400
+        assert response.json() == {"detail": error_message}
+
+    def test_should_respond_404_if_a_dag_is_inactive(self, test_client, 
session):
+        self._dags_for_trigger_tests(session)
+        response = test_client.post("/public/dags/inactive/dagRuns", json={})
+        assert response.status_code == 404
+        assert response.json()["detail"] == "DAG with dag_id: 'inactive' not 
found"
+
+    def test_should_respond_400_if_a_dag_has_import_errors(self, test_client, 
session):
+        self._dags_for_trigger_tests(session)
+        response = test_client.post("/public/dags/import_errors/dagRuns", 
json={})
+        assert response.status_code == 400
+        assert (
+            response.json()["detail"]
+            == "DAG with dag_id: 'import_errors' has import errors and cannot 
be triggered"
+        )
+
+    @time_machine.travel(timezone.utcnow(), tick=False)
+    def test_should_response_200_for_duplicate_logical_date(self, test_client):
+        RUN_ID_1 = "random_1"
+        RUN_ID_2 = "random_2"
+        now = timezone.utcnow().isoformat().replace("+00:00", "Z")
+        note = "duplicate logical date test"
+        response_1 = test_client.post(
+            f"/public/dags/{DAG1_ID}/dagRuns",
+            json={"dag_run_id": RUN_ID_1, "note": note},
+        )
+        response_2 = test_client.post(
+            f"/public/dags/{DAG1_ID}/dagRuns",
+            json={"dag_run_id": RUN_ID_2, "note": note},
+        )
+
+        assert response_1.status_code == response_2.status_code == 200
+        body1 = response_1.json()
+        body2 = response_2.json()
+
+        for each_run_id, each_body in [(RUN_ID_1, body1), (RUN_ID_2, body2)]:
+            assert each_body == {
+                "run_id": each_run_id,
+                "dag_id": DAG1_ID,
+                "logical_date": now,
+                "queued_at": now,
+                "start_date": None,
+                "end_date": None,
+                "data_interval_start": now,
+                "data_interval_end": now,
+                "last_scheduling_decision": None,
+                "run_type": "manual",
+                "state": "queued",
+                "external_trigger": True,
+                "triggered_by": "rest_api",
+                "conf": {},
+                "note": note,
+            }
+
+    @pytest.mark.parametrize(
+        "data_interval_start, data_interval_end",
+        [
+            (
+                LOGICAL_DATE1.isoformat(),
+                None,
+            ),
+            (
+                None,
+                LOGICAL_DATE1.isoformat(),
+            ),
+        ],
+    )
+    def test_should_response_422_for_missing_start_date_or_end_date(

Review Comment:
   I think we are missing a test where we try to create a dag_run with 
invalid param values.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to