This is an automated email from the ASF dual-hosted git repository.

milton0825 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new b01d95e  Change DAG.clear to take dag_run_state (#9824)
b01d95e is described below

commit b01d95ec22b01ed79123178acd74ef40d57aaa7c
Author: Chao-Han Tsai <[email protected]>
AuthorDate: Wed Jul 15 13:08:18 2020 -0700

    Change DAG.clear to take dag_run_state (#9824)
    
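    * Change DAG.clear to take dag_run_state
    
      Replace the boolean `reset_dag_runs` parameter of DAG.clear (and
      DAG.clear_dags) with `dag_run_state`: the state that DagRuns between
      start_date and end_date are set to after their task instances are
      cleared. It defaults to State.RUNNING, and the DagRun state is now
      always updated when tasks are cleared. Previously the default reset
      DagRuns to State.NONE. Callers that relied on that behavior now pass
      dag_run_state=State.NONE explicitly: the `dags backfill` and
      `dags test` CLI commands and the Google example DAGs.
    
    * fix lint
    
    * fix tests
    
    * assign var
    
    * extend original clause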
---
 airflow/cli/commands/dag_command.py                |  4 +-
 airflow/models/dag.py                              | 55 ++++++++++----------
 .../cloud/example_dags/example_datafusion.py       |  3 +-
 .../google/cloud/example_dags/example_gcs.py       |  3 +-
 .../example_dags/example_campaign_manager.py       |  3 +-
 tests/cli/commands/test_dag_command.py             |  7 ++-
 tests/models/test_dag.py                           | 60 +++++++++++++++++++---
 7 files changed, 94 insertions(+), 41 deletions(-)

diff --git a/airflow/cli/commands/dag_command.py 
b/airflow/cli/commands/dag_command.py
index 685ca21..55c40b4 100644
--- a/airflow/cli/commands/dag_command.py
+++ b/airflow/cli/commands/dag_command.py
@@ -39,6 +39,7 @@ from airflow.utils import cli as cli_utils
 from airflow.utils.cli import get_dag, get_dag_by_file_location, 
process_subdir, sigint_handler
 from airflow.utils.dot_renderer import render_dag
 from airflow.utils.session import create_session, provide_session
+from airflow.utils.state import State
 
 
 def _tabulate_dag_runs(dag_runs: List[DagRun], tablefmt: str = "fancy_grid") 
-> str:
@@ -123,6 +124,7 @@ def dag_backfill(args, dag=None):
                 end_date=args.end_date,
                 confirm_prompt=not args.yes,
                 include_subdags=True,
+                dag_run_state=State.NONE,
             )
 
         dag.run(
@@ -381,7 +383,7 @@ def dag_list_dag_runs(args, dag=None):
 def dag_test(args, session=None):
     """Execute one single DagRun for a given DAG and execution date, using the 
DebugExecutor."""
     dag = get_dag(subdir=args.subdir, dag_id=args.dag_id)
-    dag.clear(start_date=args.execution_date, end_date=args.execution_date, 
reset_dag_runs=True)
+    dag.clear(start_date=args.execution_date, end_date=args.execution_date, 
dag_run_state=State.NONE)
     try:
         dag.run(executor=DebugExecutor(), start_date=args.execution_date, 
end_date=args.execution_date)
     except BackfillUnfinished as e:
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index e6aafd3..dfb6409 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -27,7 +27,7 @@ import traceback
 import warnings
 from collections import OrderedDict
 from datetime import datetime, timedelta
-from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, 
Optional, Set, Type, Union
+from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, 
Optional, Set, Type, Union, cast
 
 import jinja2
 import pendulum
@@ -297,7 +297,7 @@ class DAG(BaseDag, LoggingMixin):
             template_searchpath = [template_searchpath]
         self.template_searchpath = template_searchpath
         self.template_undefined = template_undefined
-        self.parent_dag = None  # Gets set when DAGs are loaded
+        self.parent_dag: Optional[DAG] = None  # Gets set when DAGs are loaded
         self.last_loaded = timezone.utcnow()
         self.safe_dag_id = dag_id.replace('.', '__dot__')
         self.max_active_runs = max_active_runs
@@ -966,7 +966,7 @@ class DAG(BaseDag, LoggingMixin):
             confirm_prompt=False,
             include_subdags=True,
             include_parentdag=True,
-            reset_dag_runs=True,
+            dag_run_state: str = State.RUNNING,
             dry_run=False,
             session=None,
             get_tis=False,
@@ -993,8 +993,7 @@ class DAG(BaseDag, LoggingMixin):
         :type include_subdags: bool
         :param include_parentdag: Clear tasks in the parent dag of the subdag.
         :type include_parentdag: bool
-        :param reset_dag_runs: Set state of dag to RUNNING
-        :type reset_dag_runs: bool
+        :param dag_run_state: state to set DagRun to
         :param dry_run: Find the tasks to clear but don't clear them.
         :type dry_run: bool
         :param session: The sqlalchemy session to use
@@ -1025,8 +1024,7 @@ class DAG(BaseDag, LoggingMixin):
             tis = session.query(TI).filter(TI.dag_id == self.dag_id)
             tis = tis.filter(TI.task_id.in_(self.task_ids))
 
-        if include_parentdag and self.is_subdag:
-
+        if include_parentdag and self.is_subdag and self.parent_dag is not 
None:
             p_dag = self.parent_dag.sub_dag(
                 task_regex=r"^{}$".format(self.dag_id.split('.')[1]),
                 include_upstream=False,
@@ -1039,7 +1037,7 @@ class DAG(BaseDag, LoggingMixin):
                 confirm_prompt=confirm_prompt,
                 include_subdags=include_subdags,
                 include_parentdag=False,
-                reset_dag_runs=reset_dag_runs,
+                dag_run_state=dag_run_state,
                 get_tis=True,
                 session=session,
                 recursion_depth=recursion_depth,
@@ -1065,12 +1063,13 @@ class DAG(BaseDag, LoggingMixin):
             instances = tis.all()
             for ti in instances:
                 if ti.operator == ExternalTaskMarker.__name__:
-                    ti.task = self.get_task(ti.task_id)
+                    task: ExternalTaskMarker = cast(ExternalTaskMarker, 
self.get_task(ti.task_id))
+                    ti.task = task
 
                     if recursion_depth == 0:
                         # Maximum recursion depth allowed is the 
recursion_depth of the first
                         # ExternalTaskMarker in the tasks to be cleared.
-                        max_recursion_depth = ti.task.recursion_depth
+                        max_recursion_depth = task.recursion_depth
 
                     if recursion_depth + 1 > max_recursion_depth:
                         # Prevent cycles or accidents.
@@ -1080,10 +1079,10 @@ class DAG(BaseDag, LoggingMixin):
                                                .format(max_recursion_depth,
                                                        
ExternalTaskMarker.__name__, ti.task_id))
                     ti.render_templates()
-                    external_tis = session.query(TI).filter(TI.dag_id == 
ti.task.external_dag_id,
-                                                            TI.task_id == 
ti.task.external_task_id,
+                    external_tis = session.query(TI).filter(TI.dag_id == 
task.external_dag_id,
+                                                            TI.task_id == 
task.external_task_id,
                                                             TI.execution_date 
==
-                                                            
pendulum.parse(ti.task.execution_date))
+                                                            
pendulum.parse(task.execution_date))
 
                     for tii in external_tis:
                         if not dag_bag:
@@ -1103,7 +1102,7 @@ class DAG(BaseDag, LoggingMixin):
                                                          
confirm_prompt=confirm_prompt,
                                                          
include_subdags=include_subdags,
                                                          
include_parentdag=False,
-                                                         
reset_dag_runs=reset_dag_runs,
+                                                         
dag_run_state=dag_run_state,
                                                          get_tis=True,
                                                          session=session,
                                                          
recursion_depth=recursion_depth + 1,
@@ -1134,16 +1133,18 @@ class DAG(BaseDag, LoggingMixin):
             do_it = utils.helpers.ask_yesno(question)
 
         if do_it:
-            clear_task_instances(tis,
-                                 session,
-                                 dag=self,
-                                 )
-            if reset_dag_runs:
-                self.set_dag_runs_state(session=session,
-                                        start_date=start_date,
-                                        end_date=end_date,
-                                        state=State.NONE,
-                                        )
+            clear_task_instances(
+                tis,
+                session,
+                dag=self,
+                activate_dag_runs=False,  # We will set DagRun state later.
+            )
+            self.set_dag_runs_state(
+                session=session,
+                start_date=start_date,
+                end_date=end_date,
+                state=dag_run_state,
+            )
         else:
             count = 0
             print("Bail. Nothing was cleared.")
@@ -1161,7 +1162,7 @@ class DAG(BaseDag, LoggingMixin):
             confirm_prompt=False,
             include_subdags=True,
             include_parentdag=False,
-            reset_dag_runs=True,
+            dag_run_state=State.RUNNING,
             dry_run=False,
     ):
         all_tis = []
@@ -1174,7 +1175,7 @@ class DAG(BaseDag, LoggingMixin):
                 confirm_prompt=False,
                 include_subdags=include_subdags,
                 include_parentdag=include_parentdag,
-                reset_dag_runs=reset_dag_runs,
+                dag_run_state=dag_run_state,
                 dry_run=True)
             all_tis.extend(tis)
 
@@ -1202,7 +1203,7 @@ class DAG(BaseDag, LoggingMixin):
                           only_running=only_running,
                           confirm_prompt=False,
                           include_subdags=include_subdags,
-                          reset_dag_runs=reset_dag_runs,
+                          dag_run_state=dag_run_state,
                           dry_run=False,
                           )
         else:
diff --git a/airflow/providers/google/cloud/example_dags/example_datafusion.py 
b/airflow/providers/google/cloud/example_dags/example_datafusion.py
index e2b686c..62ab1d4 100644
--- a/airflow/providers/google/cloud/example_dags/example_datafusion.py
+++ b/airflow/providers/google/cloud/example_dags/example_datafusion.py
@@ -29,6 +29,7 @@ from airflow.providers.google.cloud.operators.datafusion 
import (
     CloudDataFusionStopPipelineOperator, CloudDataFusionUpdateInstanceOperator,
 )
 from airflow.utils import dates
+from airflow.utils.state import State
 
 # [START howto_data_fusion_env_variables]
 LOCATION = "europe-north1"
@@ -227,5 +228,5 @@ with models.DAG(
     delete_pipeline >> delete_instance
 
 if __name__ == "__main__":
-    dag.clear(reset_dag_runs=True)
+    dag.clear(dag_run_state=State.NONE)
     dag.run()
diff --git a/airflow/providers/google/cloud/example_dags/example_gcs.py 
b/airflow/providers/google/cloud/example_dags/example_gcs.py
index 4cdac36..18f173f 100644
--- a/airflow/providers/google/cloud/example_dags/example_gcs.py
+++ b/airflow/providers/google/cloud/example_dags/example_gcs.py
@@ -32,6 +32,7 @@ from airflow.providers.google.cloud.transfers.gcs_to_gcs 
import GCSToGCSOperator
 from airflow.providers.google.cloud.transfers.gcs_to_local import 
GCSToLocalFilesystemOperator
 from airflow.providers.google.cloud.transfers.local_to_gcs import 
LocalFilesystemToGCSOperator
 from airflow.utils.dates import days_ago
+from airflow.utils.state import State
 
 default_args = {"start_date": days_ago(1)}
 
@@ -155,5 +156,5 @@ with models.DAG(
 
 
 if __name__ == '__main__':
-    dag.clear(reset_dag_runs=True)
+    dag.clear(dag_run_state=State.NONE)
     dag.run()
diff --git 
a/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py
 
b/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py
index ca82c93..74fb6d3 100644
--- 
a/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py
+++ 
b/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py
@@ -31,6 +31,7 @@ from 
airflow.providers.google.marketing_platform.sensors.campaign_manager import
     GoogleCampaignManagerReportSensor,
 )
 from airflow.utils import dates
+from airflow.utils.state import State
 
 PROFILE_ID = os.environ.get("MARKETING_PROFILE_ID", "123456789")
 FLOODLIGHT_ACTIVITY_ID = os.environ.get("FLOODLIGHT_ACTIVITY_ID", 12345)
@@ -157,5 +158,5 @@ with models.DAG(
     insert_conversion >> update_conversion
 
 if __name__ == "__main__":
-    dag.clear(reset_dag_runs=True)
+    dag.clear(dag_run_state=State.NONE)
     dag.run()
diff --git a/tests/cli/commands/test_dag_command.py 
b/tests/cli/commands/test_dag_command.py
index fa4a32e..6dda923 100644
--- a/tests/cli/commands/test_dag_command.py
+++ b/tests/cli/commands/test_dag_command.py
@@ -433,7 +433,8 @@ class TestCliDags(unittest.TestCase):
                 subdir=cli_args.subdir, dag_id='example_bash_operator'
             ),
             mock.call().clear(
-                start_date=cli_args.execution_date, 
end_date=cli_args.execution_date, reset_dag_runs=True
+                start_date=cli_args.execution_date, 
end_date=cli_args.execution_date,
+                dag_run_state=State.NONE,
             ),
             mock.call().run(
                 executor=mock_executor.return_value,
@@ -461,7 +462,9 @@ class TestCliDags(unittest.TestCase):
                 subdir=cli_args.subdir, dag_id='example_bash_operator'
             ),
             mock.call().clear(
-                start_date=cli_args.execution_date, 
end_date=cli_args.execution_date, reset_dag_runs=True
+                start_date=cli_args.execution_date,
+                end_date=cli_args.execution_date,
+                dag_run_state=State.NONE,
             ),
             mock.call().run(
                 executor=mock_executor.return_value,
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 8207b24..8891d56 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -25,6 +25,7 @@ import re
 import unittest
 from contextlib import redirect_stdout
 from tempfile import NamedTemporaryFile
+from typing import Optional
 from unittest import mock
 from unittest.mock import patch
 
@@ -55,6 +56,12 @@ from tests.test_utils.db import clear_db_dags, clear_db_runs
 
 class TestDag(unittest.TestCase):
 
+    def setUp(self) -> None:
+        clear_db_runs()
+
+    def tearDown(self) -> None:
+        clear_db_runs()
+
     @staticmethod
     def _clean_up(dag_id: str):
         with create_session() as session:
@@ -1355,8 +1362,14 @@ class TestDag(unittest.TestCase):
         dr = dag.create_dagrun(run_id="custom_is_set_to_manual", 
state=State.NONE)
         assert dr.run_type == DagRunType.MANUAL.value
 
-    def test_clear_reset_dagruns(self):
-        dag_id = 'test_clear_dag_reset_dagruns'
+    @parameterized.expand(
+        [
+            (State.NONE,),
+            (State.RUNNING,),
+        ]
+    )
+    def test_clear_set_dagrun_state(self, dag_run_state):
+        dag_id = 'test_clear_set_dagrun_state'
         self._clean_up(dag_id)
         task_id = 't1'
         dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1)
@@ -1365,7 +1378,7 @@ class TestDag(unittest.TestCase):
         session = settings.Session()
         dagrun_1 = dag.create_dagrun(
             run_type=DagRunType.BACKFILL_JOB,
-            state=State.RUNNING,
+            state=State.FAILED,
             start_date=DEFAULT_DATE,
             execution_date=DEFAULT_DATE,
         )
@@ -1378,7 +1391,7 @@ class TestDag(unittest.TestCase):
         dag.clear(
             start_date=DEFAULT_DATE,
             end_date=DEFAULT_DATE + datetime.timedelta(days=1),
-            reset_dag_runs=True,
+            dag_run_state=dag_run_state,
             include_subdags=False,
             include_parentdag=False,
             session=session,
@@ -1392,17 +1405,48 @@ class TestDag(unittest.TestCase):
 
         self.assertEqual(len(dagruns), 1)
         dagrun = dagruns[0]  # type: DagRun
-        self.assertEqual(dagrun.state, State.NONE)
+        self.assertEqual(dagrun.state, dag_run_state)
+
+    @parameterized.expand([
+        (state, State.NONE)
+        for state in State.task_states if state != State.RUNNING
+    ] + [(State.RUNNING, State.SHUTDOWN)])  # type: ignore
+    def test_clear_dag(self, ti_state_begin, ti_state_end: Optional[str]):
+        dag_id = 'test_clear_dag'
+        self._clean_up(dag_id)
+        task_id = 't1'
+        dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1)
+        t_1 = DummyOperator(task_id=task_id, dag=dag)
+
+        session = settings.Session()  # type: ignore
+        dagrun_1 = dag.create_dagrun(
+            run_type=DagRunType.BACKFILL_JOB,
+            state=State.RUNNING,
+            start_date=DEFAULT_DATE,
+            execution_date=DEFAULT_DATE,
+        )
+        session.merge(dagrun_1)
+
+        task_instance_1 = TI(t_1, execution_date=DEFAULT_DATE, 
state=ti_state_begin)
+        task_instance_1.job_id = 123
+        session.merge(task_instance_1)
+        session.commit()
+
+        dag.clear(
+            start_date=DEFAULT_DATE,
+            end_date=DEFAULT_DATE + datetime.timedelta(days=1),
+            session=session,
+        )
 
         task_instances = session.query(
-            DagRun,
+            TI,
         ).filter(
-            DagRun.dag_id == dag_id,
+            TI.dag_id == dag_id,
         ).all()
 
         self.assertEqual(len(task_instances), 1)
         task_instance = task_instances[0]  # type: TI
-        self.assertEqual(task_instance.state, State.NONE)
+        self.assertEqual(task_instance.state, ti_state_end)
         self._clean_up(dag_id)
 
 

Reply via email to