pierrejeambrun commented on code in PR #61058:
URL: https://github.com/apache/airflow/pull/61058#discussion_r2746909295
##########
airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_gantt.py:
##########
@@ -0,0 +1,310 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from operator import attrgetter
+
+import pendulum
+import pytest
+
+from airflow._shared.timezones import timezone
+from airflow.models.dagbag import DBDagBag
+from airflow.providers.standard.operators.empty import EmptyOperator
+from airflow.utils.session import provide_session
+from airflow.utils.state import DagRunState, TaskInstanceState
+from airflow.utils.types import DagRunTriggeredByType, DagRunType
+
+from tests_common.test_utils.asserts import assert_queries_count
+from tests_common.test_utils.db import clear_db_assets, clear_db_dags, clear_db_runs, clear_db_serialized_dags
+from tests_common.test_utils.mock_operators import MockOperator
+
+pytestmark = pytest.mark.db_test
+
+DAG_ID = "test_gantt_dag"
+DAG_ID_2 = "test_gantt_dag_2"
+DAG_ID_3 = "test_gantt_dag_3"
+TASK_ID = "task"
+TASK_ID_2 = "task2"
+TASK_ID_3 = "task3"
+MAPPED_TASK_ID = "mapped_task"
+
+GANTT_TASK_1 = {
+    "task_id": "task",
+    "try_number": 1,
+    "state": "success",
+    "start_date": "2024-11-30T10:00:00Z",
+    "end_date": "2024-11-30T10:05:00Z",
+    "is_group": False,
+    "is_mapped": False,
+}
+
+GANTT_TASK_2 = {
+    "task_id": "task2",
+    "try_number": 1,
+    "state": "failed",
+    "start_date": "2024-11-30T10:05:00Z",
+    "end_date": "2024-11-30T10:10:00Z",
+    "is_group": False,
+    "is_mapped": False,
+}
+
+GANTT_TASK_3 = {
+    "task_id": "task3",
+    "try_number": 1,
+    "state": "running",
+    "start_date": "2024-11-30T10:10:00Z",
+    "end_date": None,
+    "is_group": False,
+    "is_mapped": False,
+}
+
+
+@pytest.fixture(autouse=True, scope="module")
+def examples_dag_bag():
+    return DBDagBag()
+
+
+@pytest.fixture(autouse=True)
+@provide_session
+def setup(dag_maker, session=None):
+    clear_db_runs()
+    clear_db_dags()
+    clear_db_serialized_dags()
+
+    triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST}
+
+    # DAG 1: Multiple tasks with different states (success, failed, running)
+    with dag_maker(dag_id=DAG_ID, serialized=True, session=session) as dag:
+        EmptyOperator(task_id=TASK_ID)
+        EmptyOperator(task_id=TASK_ID_2)
+        EmptyOperator(task_id=TASK_ID_3)
+
+    logical_date = timezone.datetime(2024, 11, 30)
+    data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date)
+
+    run_1 = dag_maker.create_dagrun(
+        run_id="run_1",
+        state=DagRunState.RUNNING,
+        run_type=DagRunType.MANUAL,
+        logical_date=logical_date,
+        data_interval=data_interval,
+        **triggered_by_kwargs,
+    )
+
+    for ti in sorted(run_1.task_instances, key=attrgetter("task_id")):
+        if ti.task_id == TASK_ID:
+            ti.state = TaskInstanceState.SUCCESS
+            ti.try_number = 1
+            ti.start_date = pendulum.DateTime(2024, 11, 30, 10, 0, 0, tzinfo=pendulum.UTC)
+            ti.end_date = pendulum.DateTime(2024, 11, 30, 10, 5, 0, tzinfo=pendulum.UTC)
+        elif ti.task_id == TASK_ID_2:
+            ti.state = TaskInstanceState.FAILED
+            ti.try_number = 1
+            ti.start_date = pendulum.DateTime(2024, 11, 30, 10, 5, 0, tzinfo=pendulum.UTC)
+            ti.end_date = pendulum.DateTime(2024, 11, 30, 10, 10, 0, tzinfo=pendulum.UTC)
+        elif ti.task_id == TASK_ID_3:
+            ti.state = TaskInstanceState.RUNNING
+            ti.try_number = 1
+            ti.start_date = pendulum.DateTime(2024, 11, 30, 10, 10, 0, tzinfo=pendulum.UTC)
+            ti.end_date = None
+
+    # DAG 2: With mapped tasks (only non-mapped should be returned)
+    with dag_maker(dag_id=DAG_ID_2, serialized=True, session=session) as dag_2:
+        EmptyOperator(task_id=TASK_ID)
+        MockOperator.partial(task_id=MAPPED_TASK_ID).expand(arg1=["a", "b", "c"])
+
+    logical_date_2 = timezone.datetime(2024, 12, 1)
+    data_interval_2 = dag_2.timetable.infer_manual_data_interval(run_after=logical_date_2)
+
+    run_2 = dag_maker.create_dagrun(
+        run_id="run_2",
+        state=DagRunState.SUCCESS,
+        run_type=DagRunType.MANUAL,
+        logical_date=logical_date_2,
+        data_interval=data_interval_2,
+        **triggered_by_kwargs,
+    )
+
+    for ti in run_2.task_instances:
+        ti.state = TaskInstanceState.SUCCESS
+        ti.try_number = 1
+        ti.start_date = pendulum.DateTime(2024, 12, 1, 10, 0, 0, tzinfo=pendulum.UTC)
+        ti.end_date = pendulum.DateTime(2024, 12, 1, 10, 5, 0, tzinfo=pendulum.UTC)
+
+    # DAG 3: With UP_FOR_RETRY state (should be excluded from results)
+    with dag_maker(dag_id=DAG_ID_3, serialized=True, session=session) as dag_3:
+        EmptyOperator(task_id=TASK_ID)
+        EmptyOperator(task_id=TASK_ID_2)
+
+    logical_date_3 = timezone.datetime(2024, 12, 2)
+    data_interval_3 = dag_3.timetable.infer_manual_data_interval(run_after=logical_date_3)
+
+    run_3 = dag_maker.create_dagrun(
+        run_id="run_3",
+        state=DagRunState.RUNNING,
+        run_type=DagRunType.MANUAL,
+        logical_date=logical_date_3,
+        data_interval=data_interval_3,
+        **triggered_by_kwargs,
+    )
+
+    for ti in sorted(run_3.task_instances, key=attrgetter("task_id")):
+        if ti.task_id == TASK_ID:
+            ti.state = TaskInstanceState.SUCCESS
+            ti.try_number = 1
+            ti.start_date = pendulum.DateTime(2024, 12, 2, 10, 0, 0, tzinfo=pendulum.UTC)
+            ti.end_date = pendulum.DateTime(2024, 12, 2, 10, 5, 0, tzinfo=pendulum.UTC)
+        elif ti.task_id == TASK_ID_2:
+            # UP_FOR_RETRY should be excluded (historical tries are in TaskInstanceHistory)
+            ti.state = TaskInstanceState.UP_FOR_RETRY
+            ti.try_number = 2
+            ti.start_date = pendulum.DateTime(2024, 12, 2, 10, 5, 0, tzinfo=pendulum.UTC)
+            ti.end_date = pendulum.DateTime(2024, 12, 2, 10, 10, 0, tzinfo=pendulum.UTC)
+
+    session.commit()
+
+
+@pytest.fixture(autouse=True)
+def _clean():
+    clear_db_runs()
+    clear_db_assets()
+    yield
+    clear_db_runs()
+    clear_db_assets()
+
+
+@pytest.mark.usefixtures("setup")
+class TestGetGanttDataEndpoint:
+    def test_should_response_200(self, test_client):
+        with assert_queries_count(3):
+            response = test_client.get(f"/gantt/{DAG_ID}/run_1")
+        assert response.status_code == 200
+        data = response.json()
+        assert data["dag_id"] == DAG_ID
+        assert data["run_id"] == "run_1"
+        actual = sorted(data["task_instances"], key=lambda x: x["task_id"])
+        assert actual == [GANTT_TASK_1, GANTT_TASK_2, GANTT_TASK_3]
+
+    @pytest.mark.parametrize(
+        ("dag_id", "run_id", "expected_task_ids", "expected_states"),
+        [
+            pytest.param(
+                DAG_ID,
+                "run_1",
+                ["task", "task2", "task3"],
+                {"success", "failed", "running"},
+                id="dag1_multiple_states",
+            ),
+            pytest.param(
+                DAG_ID_2,
+                "run_2",
+                ["task"],
+                {"success"},
+                id="dag2_filters_mapped_tasks",
+            ),
+            pytest.param(
+                DAG_ID_3,
+                "run_3",
+                ["task"],
+                {"success"},
+                id="dag3_excludes_up_for_retry",
+            ),
+        ],
+    )
+    def test_task_filtering_and_states(self, test_client, dag_id, run_id, expected_task_ids, expected_states):
+        response = test_client.get(f"/gantt/{dag_id}/{run_id}")
+        assert response.status_code == 200
+        data = response.json()
+
+        actual_task_ids = sorted([ti["task_id"] for ti in data["task_instances"]])
+        assert actual_task_ids == expected_task_ids
+
+        actual_states = {ti["state"] for ti in data["task_instances"]}
+        assert actual_states == expected_states
+
+    @pytest.mark.parametrize(
+        ("dag_id", "run_id", "task_id", "expected_start", "expected_end", "expected_state"),
+        [
+            pytest.param(
+                DAG_ID,
+                "run_1",
+                "task",
+                "2024-11-30T10:00:00Z",
+                "2024-11-30T10:05:00Z",
+                "success",
+                id="success_task_has_dates",
+            ),
+            pytest.param(
+                DAG_ID,
+                "run_1",

Review Comment:
   Can you add db query guards, just to ensure that this doesn't explode later when updating that endpoint (and to make sure there isn't an N+1 db queries problem; there doesn't seem to be one, btw)?

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
