This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-6-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 7f1039ed825771e2f15f7131e91013063bd7d0d7
Author: herlambang <[email protected]>
AuthorDate: Tue May 2 04:14:04 2023 +0700

    Fix inability to delete a DagRun or TaskInstance that has a note (#30987)
    
    * Define a cascade option on the note relationship in the DagRun and TaskInstance models
    
    ---------
    
    Co-authored-by: herlambang <[email protected]>
    (cherry picked from commit 0212b7c14c4ce6866d5da1ba9f25d3ecc5c2188f)
---
 airflow/models/dagrun.py          |  7 ++++++-
 airflow/models/taskinstance.py    |  8 ++++++--
 tests/models/test_dagrun.py       | 27 ++++++++++++++++++++++++++-
 tests/models/test_taskinstance.py | 30 +++++++++++++++++++++---------
 4 files changed, 59 insertions(+), 13 deletions(-)

diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index edb0ec78ac..ba0fd9fda3 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -173,7 +173,12 @@ class DagRun(Base, LoggingMixin):
         uselist=False,
         viewonly=True,
     )
-    dag_run_note = relationship("DagRunNote", back_populates="dag_run", uselist=False)
+    dag_run_note = relationship(
+        "DagRunNote",
+        back_populates="dag_run",
+        uselist=False,
+        cascade="all, delete, delete-orphan",
+    )
     note = association_proxy("dag_run_note", "content", creator=_creator_note)
 
     DEFAULT_DAGRUNS_TO_EXAMINE = airflow_conf.getint(
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 49b3e7ebad..87c1699402 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -474,7 +474,12 @@ class TaskInstance(Base, LoggingMixin):
     dag_run = relationship("DagRun", back_populates="task_instances", lazy="joined", innerjoin=True)
     rendered_task_instance_fields = relationship("RenderedTaskInstanceFields", lazy="noload", uselist=False)
     execution_date = association_proxy("dag_run", "execution_date")
-    task_instance_note = relationship("TaskInstanceNote", back_populates="task_instance", uselist=False)
+    task_instance_note = relationship(
+        "TaskInstanceNote",
+        back_populates="task_instance",
+        uselist=False,
+        cascade="all, delete, delete-orphan",
+    )
     note = association_proxy("task_instance_note", "content", creator=_creator_note)
     task: Operator  # Not always set...
 
@@ -1136,7 +1141,6 @@ class TaskInstance(Base, LoggingMixin):
         dep_context = dep_context or DepContext()
         for dep in dep_context.deps | self.task.deps:
             for dep_status in dep.get_dep_statuses(self, session, dep_context):
-
                 self.log.debug(
                     "%s dependency '%s' PASSED: %s, %s",
                     self,
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 8df2e9e0c1..ef92b28caa 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -39,6 +39,7 @@ from airflow.models import (
     clear_task_instances,
 )
 from airflow.models.baseoperator import BaseOperator
+from airflow.models.dagrun import DagRunNote
 from airflow.models.taskmap import TaskMap
 from airflow.operators.empty import EmptyOperator
 from airflow.operators.python import ShortCircuitOperator
@@ -1722,7 +1723,6 @@ def test_calls_to_verify_integrity_with_mapped_task_zero_length_at_runtime(dag_m
         session.merge(ti)
     session.flush()
     with caplog.at_level(logging.DEBUG):
-
         # Run verify_integrity as a whole and assert the tasks were removed
         dr.verify_integrity()
         tis = dr.get_task_instances()
@@ -2287,3 +2287,28 @@ def test_clearing_task_and_moving_from_non_mapped_to_mapped(dag_maker, session):
     dr1.task_instance_scheduling_decisions(session)
     for table in [TaskFail, XCom]:
         assert session.query(table).count() == 0
+
+
+def test_dagrun_with_note(dag_maker, session):
+    with dag_maker():
+
+        @task
+        def the_task():
+            print("Hi")
+
+        the_task()
+
+    dr: DagRun = dag_maker.create_dagrun()
+    dr.note = "dag run with note"
+
+    session.add(dr)
+    session.commit()
+
+    dr_note = session.query(DagRunNote).filter(DagRunNote.dag_run_id == dr.id).one()
+    assert dr_note.content == "dag run with note"
+
+    session.delete(dr)
+    session.commit()
+
+    assert session.query(DagRun).filter(DagRun.id == dr.id).one_or_none() is None
+    assert session.query(DagRunNote).filter(DagRunNote.dag_run_id == dr.id).one_or_none() is None
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index e9f95eb7c4..2fff467384 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -59,7 +59,7 @@ from airflow.models.pool import Pool
 from airflow.models.renderedtifields import RenderedTaskInstanceFields
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskfail import TaskFail
-from airflow.models.taskinstance import TaskInstance, TaskInstance as TI
+from airflow.models.taskinstance import TaskInstance, TaskInstance as TI, TaskInstanceNote
 from airflow.models.taskmap import TaskMap
 from airflow.models.taskreschedule import TaskReschedule
 from airflow.models.variable import Variable
@@ -347,7 +347,7 @@ class TestTaskInstance:
             assert ti.state == State.QUEUED
             dep_patch.return_value = TIDepStatus("mock_" + class_name, True, "mock")
 
-        for (dep_patch, method_patch) in patch_dict.values():
+        for dep_patch, method_patch in patch_dict.values():
             dep_patch.stop()
 
     def test_mark_non_runnable_task_as_success(self, create_task_instance):
@@ -846,7 +846,6 @@ class TestTaskInstance:
             return done
 
         with dag_maker(dag_id="test_reschedule_handling") as dag:
-
             task = PythonSensor.partial(
                 task_id="test_reschedule_handling_sensor",
                 mode="reschedule",
@@ -2039,7 +2038,6 @@ class TestTaskInstance:
 
     @pytest.mark.parametrize("schedule_interval, catchup", _prev_dates_param_list)
     def test_previous_ti(self, schedule_interval, catchup, dag_maker) -> None:
-
         scenario = [State.SUCCESS, State.FAILED, State.SUCCESS]
 
         ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario, dag_maker)
@@ -2052,7 +2050,6 @@ class TestTaskInstance:
 
     @pytest.mark.parametrize("schedule_interval, catchup", _prev_dates_param_list)
     def test_previous_ti_success(self, schedule_interval, catchup, dag_maker) -> None:
-
         scenario = [State.FAILED, State.SUCCESS, State.FAILED, State.SUCCESS]
 
         ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario, dag_maker)
@@ -2066,7 +2063,6 @@ class TestTaskInstance:
 
     @pytest.mark.parametrize("schedule_interval, catchup", _prev_dates_param_list)
     def test_previous_execution_date_success(self, schedule_interval, catchup, dag_maker) -> None:
-
         scenario = [State.FAILED, State.SUCCESS, State.FAILED, State.SUCCESS]
 
         ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario, dag_maker)
@@ -2081,7 +2077,6 @@ class TestTaskInstance:
 
     @pytest.mark.parametrize("schedule_interval, catchup", _prev_dates_param_list)
     def test_previous_start_date_success(self, schedule_interval, catchup, dag_maker) -> None:
-
         scenario = [State.FAILED, State.SUCCESS, State.FAILED, State.SUCCESS]
 
         ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario, dag_maker)
@@ -2702,7 +2697,6 @@ class TestTaskInstance:
 
     @provide_session
     def test_get_rendered_template_fields(self, dag_maker, session=None):
-
         with dag_maker("test-dag", session=session) as dag:
             task = BashOperator(task_id="op1", bash_command="{{ task.task_id }}")
         dag.fileloc = TEST_DAGS_FOLDER / "test_get_k8s_pod_yaml.py"
@@ -2876,7 +2870,6 @@ class TestTaskInstance:
             ), f"Key: {key} had different values. Make sure it loads it in the refresh refresh_from_db()"
 
     def test_operator_field_with_serialization(self, create_task_instance):
-
         ti = create_task_instance()
         assert ti.task.task_type == "EmptyOperator"
         assert ti.task.operator_name == "EmptyOperator"
@@ -3997,3 +3990,22 @@ def test_mini_scheduler_not_skip_mapped_downstream_until_all_upstreams_finish(da
     generate_ti.schedule_downstream_tasks(session=session)
     # Now downstreams can be skipped.
     assert dr.get_task_instances([TaskInstanceState.SKIPPED], session=session)
+
+
+def test_taskinstance_with_note(create_task_instance, session):
+    ti: TaskInstance = create_task_instance(session=session)
+    ti.note = "ti with note"
+
+    session.add(ti)
+    session.commit()
+
+    filter_kwargs = dict(dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id, map_index=ti.map_index)
+
+    ti_note: TaskInstanceNote = session.query(TaskInstanceNote).filter_by(**filter_kwargs).one()
+    assert ti_note.content == "ti with note"
+
+    session.delete(ti)
+    session.commit()
+
+    assert session.query(TaskInstance).filter_by(**filter_kwargs).one_or_none() is None
+    assert session.query(TaskInstanceNote).filter_by(**filter_kwargs).one_or_none() is None

Reply via email to