amoghrajesh commented on code in PR #48651:
URL: https://github.com/apache/airflow/pull/48651#discussion_r2026468850


##########
airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -53,7 +55,9 @@
 from airflow.utils import timezone
 from airflow.utils.state import DagRunState, TaskInstanceState
 
-router = VersionedAPIRouter(
+router = VersionedAPIRouter()
+
+ti_id_router = VersionedAPIRouter(

Review Comment:
   Can we add a comment here explaining why we need this separate router? It would make the intent clear to readers too.



##########
providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py:
##########
@@ -81,42 +76,13 @@ def clean_db():
     clear_db_runs()
 
 
-@pytest.fixture
-def dag_zip_maker(testing_dag_bundle):
-    class DagZipMaker:
-        def __call__(self, *dag_files):
-            self.__dag_files = [os.sep.join([TEST_DAGS_FOLDER.__str__(), dag_file]) for dag_file in dag_files]
-            dag_files_hash = md5("".join(self.__dag_files).encode()).hexdigest()
-            self.__tmp_dir = os.sep.join([tempfile.tempdir, dag_files_hash])
-
-            self.__zip_file_name = os.sep.join([self.__tmp_dir, f"{dag_files_hash}.zip"])
-
-            if not os.path.exists(self.__tmp_dir):
-                os.mkdir(self.__tmp_dir)
-            return self
-
-        def __enter__(self):
-            with zipfile.ZipFile(self.__zip_file_name, "x") as zf:
-                for dag_file in self.__dag_files:
-                    zf.write(dag_file, os.path.basename(dag_file))
-            dagbag = DagBag(dag_folder=self.__tmp_dir, include_examples=False)
-            dagbag.sync_to_db("testing", None)
-            return dagbag
-
-        def __exit__(self, exc_type, exc_val, exc_tb):
-            os.unlink(self.__zip_file_name)
-            os.rmdir(self.__tmp_dir)
-
-    return DagZipMaker()
-
-
-@pytest.mark.usefixtures("testing_dag_bundle")
-class TestExternalTaskSensor:
+@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for v3.0+")

Review Comment:
   Clever!



##########
airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py:
##########
@@ -218,3 +218,99 @@ def test_dag_run_not_found(self, client):
         response = client.post(f"/execution/dag-runs/{dag_id}/{run_id}/clear")
 
         assert response.status_code == 404
+
+
+class TestGetDagRunCount:
+    def setup_method(self):
+        clear_db_runs()
+
+    def teardown_method(self):
+        clear_db_runs()
+
+    def test_get_count_basic(self, client, session, dag_maker):
+        with dag_maker("test_dag"):
+            pass
+        dag_maker.create_dagrun()
+        session.commit()
+
+        response = client.get("/execution/dag-runs/count", params={"dag_id": "test_dag"})
+        assert response.status_code == 200
+        assert response.json() == 1
+
+    def test_get_count_with_states(self, client, session, dag_maker):
+        """Test counting DAG runs in specific states."""
+        with dag_maker("test_get_count_with_states"):
+            pass
+
+        # Create DAG runs with different states
+        dag_maker.create_dagrun(
+            state=State.SUCCESS, logical_date=timezone.datetime(2025, 1, 1), run_id="test_run_id1"
+        )
+        dag_maker.create_dagrun(
+            state=State.FAILED, logical_date=timezone.datetime(2025, 1, 2), run_id="test_run_id2"
+        )
+        dag_maker.create_dagrun(
+            state=State.RUNNING, logical_date=timezone.datetime(2025, 1, 3), run_id="test_run_id3"
+        )
+        session.commit()
+
+        response = client.get(
+            "/execution/dag-runs/count",
+            params={"dag_id": "test_get_count_with_states", "states": [State.SUCCESS, State.FAILED]},

Review Comment:
   Can we also add `State.RUNNING` here? Or maybe parametrize the test over several state combinations?
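
   A rough sketch of what a table-driven version could look like. `count_runs` is a hypothetical in-memory stand-in for the `/execution/dag-runs/count` call and `STATE_CASES` is an illustrative table, not code from this PR; with pytest the same table would be passed to `pytest.mark.parametrize`.

```python
# Hypothetical stand-in for the three DAG runs created in the test above.
RUN_STATES = ["success", "failed", "running"]

# (states filter, expected count); None means "no filter".
STATE_CASES = [
    (None, 3),
    (["success"], 1),
    (["running"], 1),
    (["success", "failed"], 2),
    (["success", "failed", "running"], 3),
    (["queued"], 0),
]


def count_runs(states=None):
    """Stand-in for the count endpoint: count runs whose state is in `states`."""
    if not states:
        # Mirrors the endpoint: an empty or missing filter matches every run.
        return len(RUN_STATES)
    return sum(1 for state in RUN_STATES if state in states)


def check_state_cases():
    """Run every case from the table.

    In the real test this loop would become
    @pytest.mark.parametrize("states, expected", STATE_CASES).
    """
    for states, expected in STATE_CASES:
        assert count_runs(states) == expected, (states, expected)
```

   Keeping the cases in one table makes it cheap to cover `RUNNING` and mixed filters without duplicating the setup.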



##########
airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -564,8 +568,86 @@ def get_previous_successful_dagrun(
     return PrevSuccessfulDagRunResponse.model_validate(dag_run)
 
 
-@router.only_exists_in_older_versions
-@router.post(
+@router.get("/count", status_code=status.HTTP_200_OK)

Review Comment:
   We will need to add a Cadwyn version change (migration) for the new endpoints: https://docs.cadwyn.dev/concepts/version_changes/



##########
task-sdk/src/airflow/sdk/api/client.py:
##########
@@ -200,6 +202,31 @@ def get_reschedule_start_date(self, id: uuid.UUID, try_number: int = 1) -> TaskR
         resp = self.client.get(f"task-reschedules/{id}/start_date", params={"try_number": try_number})
         return TaskRescheduleStartDate.model_construct(start_date=resp.json())
 
+    def get_count(
+        self,
+        dag_id: str,
+        task_ids: list[str] | None = None,
+        task_group_id: str | None = None,
+        logical_dates: list[datetime] | None = None,
+        run_ids: list[str] | None = None,
+        states: list[str] | None = None,
+    ) -> TICount:
+        """Get count of task instances matching the given criteria."""
+        params = {
+            "dag_id": dag_id,
+            "task_ids": task_ids,
+            "task_group_id": task_group_id,
+            "logical_dates": [d.isoformat() for d in logical_dates] if logical_dates is not None else None,
+            "run_ids": run_ids,
+            "states": states,
+        }
+
+        # Remove None values from params
+        params = {k: v for k, v in params.items() if v is not None}
+
+        resp = self.client.get("task-instances/count", params=params)
+        return TICount(count=resp.json())

Review Comment:
   Should we wrap this call in a try/except in case the server responds with a 500?
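
   Something along these lines, as a minimal sketch. `ServerError`, `FakeClient` and `get_count_or_raise` are hypothetical stand-ins so the snippet runs on its own; the real client would catch whatever error type the SDK's HTTP layer raises for 5xx responses.

```python
class ServerError(Exception):
    """Hypothetical stand-in for the HTTP error raised on an error response."""

    def __init__(self, status_code):
        super().__init__(f"server responded with {status_code}")
        self.status_code = status_code


class FakeClient:
    """Hypothetical client that either returns a count or fails with a status."""

    def __init__(self, count=0, fail_with=None):
        self.count = count
        self.fail_with = fail_with

    def get(self, path, params=None):
        if self.fail_with is not None:
            raise ServerError(self.fail_with)
        return self.count


def get_count_or_raise(client, params):
    """Fetch the count, turning a server-side failure into a clearer error."""
    try:
        return client.get("task-instances/count", params=params)
    except ServerError as err:
        if err.status_code >= 500:
            # Add context so the caller sees which request failed and why.
            raise RuntimeError(
                f"count request failed with HTTP {err.status_code}"
            ) from err
        raise  # 4xx and other errors are left untouched
```

   Whether to re-raise, retry, or return a default is a design choice for the PR author; the sketch only shows where the guard would sit.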



##########
providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py:
##########
@@ -1584,7 +1856,10 @@ def dag_bag_head_tail():
         )
         head >> body >> tail
 
-    dag_bag.bag_dag(dag=dag)
+    if AIRFLOW_V_3_0_PLUS:
+        dag_bag.bag_dag(dag=dag)
+    else:
+        dag_bag.bag_dag(dag=dag, root_dag=dag)
 

Review Comment:
   Let's just define a small utility function for this; this version check is repeated in too many places here.
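
   For example, roughly like the sketch below. `bag_dag_compat` is a hypothetical name, and the version flag is passed in explicitly only so the snippet is self-contained; in the test module the helper would read `AIRFLOW_V_3_0_PLUS` directly.

```python
def bag_dag_compat(dag_bag, dag, airflow_v_3_0_plus):
    """Add `dag` to `dag_bag`, passing `root_dag` only on pre-3.0 Airflow."""
    if airflow_v_3_0_plus:
        dag_bag.bag_dag(dag=dag)
    else:
        dag_bag.bag_dag(dag=dag, root_dag=dag)


class RecordingDagBag:
    """Hypothetical stand-in that records the keyword arguments it receives."""

    def __init__(self):
        self.calls = []

    def bag_dag(self, **kwargs):
        self.calls.append(kwargs)
```

   Each call site then shrinks to a single `bag_dag_compat(...)` line, and the version branch lives in one place.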



##########
airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py:
##########
@@ -150,3 +152,27 @@ def get_dagrun_state(
         )
 
     return DagRunStateResponse(state=dag_run.state)
+
+
+@router.get("/count", status_code=status.HTTP_200_OK)
+def get_dr_count(
+    dag_id: str,
+    session: SessionDep,
+    logical_dates: Annotated[list[UtcDateTime] | None, Query()] = None,
+    run_ids: Annotated[list[str] | None, Query()] = None,
+    states: Annotated[list[str] | None, Query()] = None,
+) -> int:
+    """Get the count of DAG runs matching the given criteria."""
+    query = select(func.count()).select_from(DagRun).where(DagRun.dag_id == dag_id)
+
+    if logical_dates:
+        query = query.where(DagRun.logical_date.in_(logical_dates))
+
+    if run_ids:
+        query = query.where(DagRun.run_id.in_(run_ids))
+
+    if states:
+        query = query.where(DagRun.state.in_(states))
+
+    count = session.scalar(query)
+    return count or 0

Review Comment:
   We will need to add a Cadwyn version change (migration) for the new endpoints: https://docs.cadwyn.dev/concepts/version_changes/



##########
airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py:
##########
@@ -1223,3 +1224,167 @@ def test_get_start_date_with_try_number(self, client, session, create_task_insta
         response = client.get(f"/execution/task-reschedules/{ti.id}/start_date?try_number=2")
         assert response.status_code == 200
         assert response.json() == "2024-01-02T00:00:00Z"
+
+
+class TestGetCount:

Review Comment:
   ```suggestion
   class TestGetTICount:
   ```



##########
task-sdk/src/airflow/sdk/api/client.py:
##########
@@ -452,6 +479,27 @@ def get_state(self, dag_id: str, run_id: str) -> DagRunStateResponse:
         resp = self.client.get(f"dag-runs/{dag_id}/{run_id}/state")
         return DagRunStateResponse.model_validate_json(resp.read())
 
+    def get_count(
+        self,
+        dag_id: str,
+        logical_dates: list[datetime] | None = None,
+        run_ids: list[str] | None = None,
+        states: list[str] | None = None,
+    ) -> DRCount:
+        """Get count of DAG runs matching the given criteria."""
+        params = {
+            "dag_id": dag_id,
+            "logical_dates": [d.isoformat() for d in logical_dates] if logical_dates is not None else None,
+            "run_ids": run_ids,
+            "states": states,
+        }
+
+        # Remove None values from params
+        params = {k: v for k, v in params.items() if v is not None}
+
+        resp = self.client.get("dag-runs/count", params=params)
+        return DRCount(count=resp.json())

Review Comment:
   Same here: should we wrap this call in a try/except for cases like a 500 response?


