pierrejeambrun commented on code in PR #68702:
URL: https://github.com/apache/airflow/pull/68702#discussion_r3441051584


##########
airflow-core/src/airflow/api_fastapi/core_api/datamodels/dag_run.py:
##########
@@ -241,3 +274,61 @@ class DAGRunsBatchBody(StrictBaseModel):
     duration_lt: float | None = None
 
     conf_contains: str | None = None
+
+
+class ClearPartitionsBody(StrictBaseModel):
+    """Request body for the clearPartitions endpoint (column-reset: set partition fields to None)."""
+
+    run_id: str | None = Field(
+        default=None,
+        description="Select runs by exact run_id. Mutually exclusive with ``partition_key`` and partition date window.",
+    )
+    partition_key: str | None = Field(
+        default=None,
+        description="Select runs by exact partition key match. Mutually exclusive with ``run_id`` and partition date window.",
+    )
+    partition_date_start: datetime | None = Field(
+        default=None,
+        description="Inclusive start of the partition date window (calendar-day granular). Mutually exclusive with ``run_id`` and ``partition_key``.",
+    )
+    partition_date_end: datetime | None = Field(
+        default=None,
+        description="Inclusive end of the partition date window (calendar-day granular). Mutually exclusive with ``run_id`` and ``partition_key``.",
+    )
+    clear_task_instances: bool = Field(
+        default=False,
+        description="Also clear task instances on the matched runs.",
+    )
+    dry_run: bool = Field(
+        default=True,
+        description="If True, compute counts without writing any changes.",
+    )
+
+    @model_validator(mode="after")
+    def validate_exactly_one_selector(self) -> ClearPartitionsBody:
+        has_run_id = self.run_id is not None
+        has_partition_key = self.partition_key is not None
+        has_partition_date_window = (
+            self.partition_date_start is not None or self.partition_date_end is not None
+        )
+        selectors_active = sum([has_run_id, has_partition_key, has_partition_date_window])
+        if selectors_active != 1:
+            raise ValueError(
+                "Exactly one of run_id, partition_key, or a partition date window "
+                "(partition_date_start / partition_date_end) must be provided."
+            )
+        if (
+            self.partition_date_start is not None
+            and self.partition_date_end is not None
+            and self.partition_date_start > self.partition_date_end
+        ):
+            raise ValueError("partition_date_start must be on or before partition_date_end.")
+        return self

Review Comment:
   Some fields and the validator are completely duplicated in `BulkDAGRunClearBody`. Can we factor them out into a base class?
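A minimal sketch of that factorization, assuming the shared pieces are the three partition selector fields and the date-ordering check. It uses stdlib `dataclasses` as a stand-in for pydantic's `StrictBaseModel` so it runs on its own; `PartitionSelectorBase`, `ClearPartitionsBodySketch` and `BulkClearBodySketch` are hypothetical names, and the `dag_runs` type is assumed. In the real models the shared check would stay a `model_validator(mode="after")` on the base class, which pydantic subclasses inherit.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime


@dataclass
class PartitionSelectorBase:
    # Hypothetical shared base; the real models would subclass StrictBaseModel.
    partition_key: str | None = None
    partition_date_start: datetime | None = None
    partition_date_end: datetime | None = None

    def __post_init__(self) -> None:
        # Stand-in for a @model_validator(mode="after") defined once on the base.
        self.validate_partition_window()

    def validate_partition_window(self) -> None:
        start, end = self.partition_date_start, self.partition_date_end
        if start is not None and end is not None and start > end:
            raise ValueError("partition_date_start must be on or before partition_date_end.")

    @property
    def has_partition_date_window(self) -> bool:
        return self.partition_date_start is not None or self.partition_date_end is not None


@dataclass
class ClearPartitionsBodySketch(PartitionSelectorBase):
    run_id: str | None = None
    clear_task_instances: bool = False
    dry_run: bool = True

    def __post_init__(self) -> None:
        super().__post_init__()  # shared window-ordering check first
        active = sum([self.run_id is not None, self.partition_key is not None, self.has_partition_date_window])
        if active != 1:
            raise ValueError(
                "Exactly one of run_id, partition_key, or a partition date window "
                "(partition_date_start / partition_date_end) must be provided."
            )


@dataclass
class BulkClearBodySketch(PartitionSelectorBase):
    # Hypothetical: only the dag_runs name comes from the review context; its type is assumed.
    dag_runs: list[str] = field(default_factory=list)
```

Each subclass keeps only its own fields and extra rules; the ordering check and the window helper live in one place.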



##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py:
##########
@@ -150,6 +155,91 @@ def perform_clear_dag_run(
     return dag_run_cleared
 
 
+_TI_CHUNK_SIZE = 500
+
+
+def clear_partition_fields(
+    *,
+    dag: SerializedDAG,
+    body: ClearPartitionsBody,
+    dag_id: str,
+    session: Session,
+) -> tuple[int, int]:
+    """
+    Reset partition_key and partition_date to None on matching runs.
+
+    Returns (dag_runs_cleared, task_instances_cleared).
+    Mirrors ``airflow partitions clear`` column-reset behavior.
+    """
+    stmt = select(DagRun).where(DagRun.dag_id == dag_id)
+    if body.run_id is not None:
+        stmt = stmt.where(DagRun.run_id == body.run_id)
+    elif body.partition_key is not None:
+        stmt = stmt.where(DagRun.partition_key == body.partition_key)
+    else:
+        stmt = stmt.where(or_(DagRun.partition_key.is_not(None), DagRun.partition_date.is_not(None)))
+        if body.partition_date_start is not None:
+            lower = dag.timetable.resolve_day_bound(body.partition_date_start.date())
+            stmt = stmt.where(DagRun.partition_date >= lower)
+        if body.partition_date_end is not None:
+            upper = dag.timetable.resolve_day_bound(body.partition_date_end.date() + timedelta(days=1))
+            stmt = stmt.where(DagRun.partition_date < upper)

Review Comment:
   The partition-selector and `resolve_day_bound` date-window logic is now duplicated between `clear_dag_runs` (partition branch) and `clear_partition_fields`. They differ slightly (one selects `DagRun.run_id`, the other `DagRun`), but the window resolution is identical. Is it worth extracting a small shared helper so the two can't drift?
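One possible shape for the shared part is a pure function that resolves the window bounds, leaving each call site its own `select(...)` target. This is a sketch under assumptions: `resolve_partition_window` is a hypothetical name, and `utc_midnight` is a hypothetical stand-in for `dag.timetable.resolve_day_bound` so the example runs without Airflow.

```python
from __future__ import annotations

from datetime import date, datetime, timedelta, timezone
from typing import Callable

# Any callable mapping a calendar day to the datetime where that day starts.
DayBoundResolver = Callable[[date], datetime]


def resolve_partition_window(
    start: datetime | None,
    end: datetime | None,
    resolve_day_bound: DayBoundResolver,
) -> tuple[datetime | None, datetime | None]:
    """Hypothetical helper: return (lower, upper) for lower <= partition_date < upper."""
    lower = resolve_day_bound(start.date()) if start is not None else None
    # The end day is inclusive, so the exclusive upper bound is the start of the next day.
    upper = resolve_day_bound(end.date() + timedelta(days=1)) if end is not None else None
    return lower, upper


def utc_midnight(day: date) -> datetime:
    """Stand-in for dag.timetable.resolve_day_bound: midnight UTC of ``day``."""
    return datetime(day.year, day.month, day.day, tzinfo=timezone.utc)
```

Both `clear_dag_runs` and `clear_partition_fields` would then add `DagRun.partition_date >= lower` and `DagRun.partition_date < upper` only for the bounds that are not None, so the day-granularity rule is defined once.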
   



##########
airflow-core/src/airflow/api_fastapi/core_api/security.py:
##########
@@ -823,6 +823,15 @@ def inner(
                 continue
             entity_methods.append((entity_dag_id, "PUT"))
 
+        partition_selectors_present = (
+            body.partition_key is not None
+            or body.partition_date_start is not None
+            or body.partition_date_end is not None
+        )
+        if not body.dag_runs and partition_selectors_present:
+            if dag_id and dag_id != "~":
+                entity_methods.append((dag_id, "PUT"))
+

Review Comment:
   Why is that necessary? This code path doesn't appear to be reached by the code mentioned.



##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py:
##########
@@ -150,6 +155,91 @@ def perform_clear_dag_run(
     return dag_run_cleared
 
 
+_TI_CHUNK_SIZE = 500
+
+
+def clear_partition_fields(
+    *,
+    dag: SerializedDAG,
+    body: ClearPartitionsBody,
+    dag_id: str,
+    session: Session,
+) -> tuple[int, int]:
+    """
+    Reset partition_key and partition_date to None on matching runs.
+
+    Returns (dag_runs_cleared, task_instances_cleared).
+    Mirrors ``airflow partitions clear`` column-reset behavior.
+    """
+    stmt = select(DagRun).where(DagRun.dag_id == dag_id)
+    if body.run_id is not None:
+        stmt = stmt.where(DagRun.run_id == body.run_id)
+    elif body.partition_key is not None:
+        stmt = stmt.where(DagRun.partition_key == body.partition_key)
+    else:
+        stmt = stmt.where(or_(DagRun.partition_key.is_not(None), DagRun.partition_date.is_not(None)))
+        if body.partition_date_start is not None:
+            lower = dag.timetable.resolve_day_bound(body.partition_date_start.date())
+            stmt = stmt.where(DagRun.partition_date >= lower)
+        if body.partition_date_end is not None:
+            upper = dag.timetable.resolve_day_bound(body.partition_date_end.date() + timedelta(days=1))
+            stmt = stmt.where(DagRun.partition_date < upper)
+    stmt = stmt.order_by(DagRun.partition_date, DagRun.run_id)
+
+    dag_runs_cleared = 0
+    # Buffers for batched TI fetching — mirrors _flush_buffer in partition_command.py
+    ti_buffer_run_ids: list[str] = []
+    ti_carry: list[TaskInstance] = []
+    tis_cleared_total = 0
+
+    def _flush_ti_buffer(*, drain: bool = False) -> int:
+        flushed = 0
+        if ti_buffer_run_ids:
+            chunk_tis = list(
+                session.scalars(select(TaskInstance).where(TaskInstance.run_id.in_(ti_buffer_run_ids)))
+            )
+            ti_buffer_run_ids.clear()
+            ti_carry.extend(chunk_tis)
+        while len(ti_carry) >= _TI_CHUNK_SIZE:
+            slice_tis = ti_carry[:_TI_CHUNK_SIZE]
+            del ti_carry[:_TI_CHUNK_SIZE]
+            clear_task_instances(slice_tis, session=session)
+            flushed += len(slice_tis)
+        if drain and ti_carry:
+            clear_task_instances(ti_carry, session=session)
+            flushed += len(ti_carry)
+        return flushed
+
+    # For dry_run TI count
+    tis_dry_total = 0
+
+    for run in session.scalars(stmt).yield_per(100):
+        fields_already_cleared = run.partition_key is None and run.partition_date is None
+        if fields_already_cleared and not body.clear_task_instances:
+            continue
+        if not fields_already_cleared:
+            if not body.dry_run:
+                run.partition_key = None
+                run.partition_date = None
+            dag_runs_cleared += 1
+        if body.clear_task_instances:
+            if body.dry_run:
+                run_tis = list(session.scalars(select(TaskInstance).where(TaskInstance.run_id == run.run_id)))
+                tis_dry_total += len(run_tis)

Review Comment:
   This scales the number of queries with the number of runs (the size of the partition).

   Is it possible to replace this with a single query instead?
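A sketch of the single-query version, run against an in-memory SQLite database with a hypothetical two-table schema so it stands alone: the task instances of every matching run are counted in one aggregate query instead of one `SELECT` per run. In the service code the equivalent would presumably be something like `select(func.count()).select_from(TaskInstance).where(TaskInstance.dag_id == dag_id, TaskInstance.run_id.in_(run_ids_subquery))`, where `run_ids_subquery` reuses the filters on `stmt`; that shape is an assumption, not taken from the PR. `build_demo_db` and `count_tis_for_partition` are hypothetical names.

```python
import sqlite3

SCHEMA = """
CREATE TABLE dag_run (dag_id TEXT, run_id TEXT, partition_key TEXT,
                      PRIMARY KEY (dag_id, run_id));
CREATE TABLE task_instance (dag_id TEXT, run_id TEXT, task_id TEXT);
"""


def build_demo_db() -> sqlite3.Connection:
    """Hypothetical sample data; note that dag_b reuses run_id 'run_1'."""
    conn = sqlite3.connect(":memory:")
    conn.executescript(SCHEMA)
    conn.executemany(
        "INSERT INTO dag_run VALUES (?, ?, ?)",
        [("dag_a", "run_1", "p1"), ("dag_a", "run_2", "p1"),
         ("dag_a", "run_3", "p2"), ("dag_b", "run_1", "p1")],
    )
    conn.executemany(
        "INSERT INTO task_instance VALUES (?, ?, ?)",
        [("dag_a", "run_1", "t1"), ("dag_a", "run_1", "t2"),
         ("dag_a", "run_2", "t1"), ("dag_a", "run_3", "t1"),
         ("dag_b", "run_1", "t1"), ("dag_b", "run_1", "t2"), ("dag_b", "run_1", "t3")],
    )
    return conn


def count_tis_for_partition(conn: sqlite3.Connection, dag_id: str, partition_key: str) -> int:
    # One round trip, no matter how many runs the partition contains.
    (count,) = conn.execute(
        """
        SELECT COUNT(*)
        FROM task_instance AS ti
        WHERE ti.dag_id = ?
          AND ti.run_id IN (
              SELECT dr.run_id FROM dag_run AS dr
              WHERE dr.dag_id = ? AND dr.partition_key = ?
          )
        """,
        (dag_id, dag_id, partition_key),
    ).fetchone()
    return count
```

With the sample data, partition `p1` of `dag_a` counts 3 task instances; the three instances under `dag_b`'s `run_1` are excluded because the outer filter pins `dag_id`.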



##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py:
##########
@@ -150,6 +155,91 @@ def perform_clear_dag_run(
     return dag_run_cleared
 
 
+_TI_CHUNK_SIZE = 500
+
+
+def clear_partition_fields(
+    *,
+    dag: SerializedDAG,
+    body: ClearPartitionsBody,
+    dag_id: str,
+    session: Session,
+) -> tuple[int, int]:
+    """
+    Reset partition_key and partition_date to None on matching runs.
+
+    Returns (dag_runs_cleared, task_instances_cleared).
+    Mirrors ``airflow partitions clear`` column-reset behavior.
+    """
+    stmt = select(DagRun).where(DagRun.dag_id == dag_id)
+    if body.run_id is not None:
+        stmt = stmt.where(DagRun.run_id == body.run_id)
+    elif body.partition_key is not None:
+        stmt = stmt.where(DagRun.partition_key == body.partition_key)
+    else:
+        stmt = stmt.where(or_(DagRun.partition_key.is_not(None), DagRun.partition_date.is_not(None)))
+        if body.partition_date_start is not None:
+            lower = dag.timetable.resolve_day_bound(body.partition_date_start.date())
+            stmt = stmt.where(DagRun.partition_date >= lower)
+        if body.partition_date_end is not None:
+            upper = dag.timetable.resolve_day_bound(body.partition_date_end.date() + timedelta(days=1))
+            stmt = stmt.where(DagRun.partition_date < upper)
+    stmt = stmt.order_by(DagRun.partition_date, DagRun.run_id)
+
+    dag_runs_cleared = 0
+    # Buffers for batched TI fetching — mirrors _flush_buffer in partition_command.py
+    ti_buffer_run_ids: list[str] = []
+    ti_carry: list[TaskInstance] = []
+    tis_cleared_total = 0
+
+    def _flush_ti_buffer(*, drain: bool = False) -> int:
+        flushed = 0
+        if ti_buffer_run_ids:
+            chunk_tis = list(
+                session.scalars(select(TaskInstance).where(TaskInstance.run_id.in_(ti_buffer_run_ids)))

Review Comment:
   We should probably add a safeguard on `dag_id` here too, because `run_id` is only unique within a DAG, not across DAGs.
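A minimal illustration of why the filter matters, using an in-memory SQLite table with a hypothetical schema: when two DAGs share a `run_id`, an unscoped `IN` also returns the other DAG's task instances. In the service code the fix would presumably be `select(TaskInstance).where(TaskInstance.dag_id == dag_id, TaskInstance.run_id.in_(ti_buffer_run_ids))`; the per-run dry-run query has the same gap. `build_two_dag_db` and `fetch_buffered_tis` are hypothetical names.

```python
from __future__ import annotations

import sqlite3


def build_two_dag_db() -> sqlite3.Connection:
    """Hypothetical data: two DAGs that both have a run called 'run_1'."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE task_instance (dag_id TEXT, run_id TEXT, task_id TEXT)")
    conn.executemany(
        "INSERT INTO task_instance VALUES (?, ?, ?)",
        [("dag_a", "run_1", "t1"), ("dag_b", "run_1", "t1")],
    )
    return conn


def fetch_buffered_tis(
    conn: sqlite3.Connection, dag_id: str, run_ids: list[str], scope_to_dag: bool = True
) -> list[tuple[str, str, str]]:
    """Fetch TI rows for a buffer of run_ids, optionally pinned to one DAG."""
    if not run_ids:
        return []
    # Only "?" placeholders are interpolated; the values are still bound as parameters.
    placeholders = ", ".join("?" for _ in run_ids)
    sql = f"SELECT dag_id, run_id, task_id FROM task_instance WHERE run_id IN ({placeholders})"
    params: list[str] = list(run_ids)
    if scope_to_dag:
        # run_id is only unique within a DAG, so also filter on dag_id.
        sql += " AND dag_id = ?"
        params.append(dag_id)
    return conn.execute(sql + " ORDER BY dag_id, run_id, task_id", params).fetchall()
```

With `scope_to_dag=False`, clearing `dag_a`'s `run_1` would also pick up `dag_b`'s task instance.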



##########
airflow-core/src/airflow/api_fastapi/core_api/security.py:
##########
@@ -823,6 +823,15 @@ def inner(
                 continue
             entity_methods.append((entity_dag_id, "PUT"))
 
+        partition_selectors_present = (
+            body.partition_key is not None
+            or body.partition_date_start is not None
+            or body.partition_date_end is not None
+        )

Review Comment:
   Extract this into a `BulkDAGRunClearBody` method or computed field to avoid repeating it here and in `partition_mode` above.
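A sketch of the suggested extraction, with a stdlib dataclass standing in for the pydantic model so it runs on its own; `BulkDAGRunClearBodySketch` and `has_partition_selector` are hypothetical names, and the `dag_runs` type is assumed. On the real pydantic model a plain `@property` should be enough; `@computed_field` would additionally include the value in serialized output.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime


@dataclass
class BulkDAGRunClearBodySketch:
    """Hypothetical stand-in for BulkDAGRunClearBody (a pydantic model in the real code)."""

    dag_runs: list[str] = field(default_factory=list)  # type assumed for illustration
    partition_key: str | None = None
    partition_date_start: datetime | None = None
    partition_date_end: datetime | None = None

    @property
    def has_partition_selector(self) -> bool:
        """True when any partition selector is set; one definition shared by all callers."""
        return (
            self.partition_key is not None
            or self.partition_date_start is not None
            or self.partition_date_end is not None
        )
```

`security.py` could then check `if not body.dag_runs and body.has_partition_selector:`, and `partition_mode` could reuse the same property.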



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
