This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7092cfdbbf Fix bad delete logic for dagruns (#32684)
7092cfdbbf is described below

commit 7092cfdbbfcfd3c03909229daa741a5bcd7ccc64
Author: Daniel Standish <[email protected]>
AuthorDate: Wed Jul 19 13:27:56 2023 -0700

    Fix bad delete logic for dagruns (#32684)
    
    Co-authored-by: Jed Cunningham <[email protected]>
---
 airflow/www/utils.py    | 12 ++++++++++--
 tests/www/test_utils.py | 47 +++++++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 55 insertions(+), 4 deletions(-)

diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 3624fbe841..bcf368ea20 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -50,6 +50,7 @@ from airflow.utils import timezone
 from airflow.utils.code_utils import get_python_source
 from airflow.utils.helpers import alchemy_to_dict
 from airflow.utils.json import WebEncoder
+from airflow.utils.sqlalchemy import tuple_in_condition
 from airflow.utils.state import State, TaskInstanceState
 from airflow.www.forms import DateTimeWithTimezoneField
 from airflow.www.widgets import AirflowDateTimePickerWidget
@@ -60,6 +61,8 @@ if TYPE_CHECKING:
 
     from airflow.www.fab_security.sqla.manager import SecurityManager
 
+TI = TaskInstance
+
 
 def datetime_to_string(value: DateTime | None) -> str | None:
     if value is None:
@@ -844,12 +847,17 @@ class DagRunCustomSQLAInterface(CustomSQLAInterface):
     """
 
     def delete(self, item: Model, raise_exception: bool = False) -> bool:
-        self.session.execute(delete(TaskInstance).where(TaskInstance.run_id == item.run_id))
+        self.session.execute(delete(TI).where(TI.dag_id == item.dag_id, TI.run_id == item.run_id))
         return super().delete(item, raise_exception=raise_exception)
 
     def delete_all(self, items: list[Model]) -> bool:
         self.session.execute(
-            delete(TaskInstance).where(TaskInstance.run_id.in_(item.run_id for item in items))
+            delete(TI).where(
+                tuple_in_condition(
+                    (TI.dag_id, TI.run_id),
+                    ((x.dag_id, x.run_id) for x in items),
+                )
+            )
         )
         return super().delete_all(items)
 
diff --git a/tests/www/test_utils.py b/tests/www/test_utils.py
index 12fe017c62..1dd1665a9b 100644
--- a/tests/www/test_utils.py
+++ b/tests/www/test_utils.py
@@ -17,17 +17,20 @@
 # under the License.
 from __future__ import annotations
 
+import itertools
 import re
 from datetime import datetime
 from unittest.mock import Mock
 from urllib.parse import parse_qs
 
+import pendulum
 from bs4 import BeautifulSoup
 from markupsafe import Markup
 
+from airflow.models import DagRun
 from airflow.utils import json as utils_json
 from airflow.www import utils
-from airflow.www.utils import json_f, wrapped_markdown
+from airflow.www.utils import DagRunCustomSQLAInterface, json_f, wrapped_markdown
 
 
 class TestUtils:
@@ -156,7 +159,6 @@ class TestUtils:
         assert "<script>alert(1)</script>" not in html
 
     def test_task_instance_link(self):
-
         from airflow.www.app import cached_app
 
         with cached_app(testing=True).test_request_context():
@@ -413,3 +415,44 @@ class TestWrappedMarkdown:
 </div>"""
             == rendered
         )
+
+
+def test_dag_run_custom_sqla_interface_delete_no_collateral_damage(dag_maker, session):
+    interface = DagRunCustomSQLAInterface(obj=DagRun, session=session)
+    dag_ids = (f"test_dag_{x}" for x in range(1, 4))
+    dates = (pendulum.datetime(2023, 1, x) for x in range(1, 4))
+    for dag_id, date in itertools.product(dag_ids, dates):
+        with dag_maker(dag_id=dag_id) as dag:
+            dag.create_dagrun(execution_date=date, state="running", run_type="scheduled")
+    dag_runs = session.query(DagRun).all()
+    assert len(dag_runs) == 9
+    assert len(set(x.run_id for x in dag_runs)) == 3
+    run_id_for_single_delete = "scheduled__2023-01-01T00:00:00+00:00"
+    # we have 3 runs with this same run_id
+    assert len(list(x for x in dag_runs if x.run_id == run_id_for_single_delete)) == 3
+    # each is a different dag
+
+    # if we delete one, it shouldn't delete the others
+    one_run = [x for x in dag_runs if x.run_id == run_id_for_single_delete][0]
+    assert interface.delete(item=one_run) is True
+    session.commit()
+    dag_runs = session.query(DagRun).all()
+    # we should have one fewer dag run now
+    assert len(dag_runs) == 8
+
+    # now let's try multi delete
+    run_id_for_multi_delete = "scheduled__2023-01-02T00:00:00+00:00"
+    # verify we have 3
+    runs_of_interest = [x for x in dag_runs if x.run_id == run_id_for_multi_delete]
+    assert len(runs_of_interest) == 3
+    # and that each is different dag
+    assert len(set(x.dag_id for x in dag_runs)) == 3
+
+    to_delete = runs_of_interest[:2]
+    # now try multi delete
+    assert interface.delete_all(items=to_delete) is True
+    session.commit()
+    dag_runs = session.query(DagRun).all()
+    assert len(dag_runs) == 6
+    assert len(set(x.dag_id for x in dag_runs)) == 3
+    assert len(set(x.run_id for x in dag_runs)) == 3

Reply via email to