This is an automated email from the ASF dual-hosted git repository.
milton0825 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new b01d95e Change DAG.clear to take dag_run_state (#9824)
b01d95e is described below
commit b01d95ec22b01ed79123178acd74ef40d57aaa7c
Author: Chao-Han Tsai <[email protected]>
AuthorDate: Wed Jul 15 13:08:18 2020 -0700
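Change DAG.clear to take dag_run_state (#9824)
* Change DAG.clear to take dag_run_state instead of the reset_dag_runs flag
* Fix lint errors
* Fix tests
* Assign the cast(ExternalTaskMarker, ...) result to a local variable
* Extend the original subdag clause with a parent_dag is not None check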
---
airflow/cli/commands/dag_command.py | 4 +-
airflow/models/dag.py | 55 ++++++++++----------
.../cloud/example_dags/example_datafusion.py | 3 +-
.../google/cloud/example_dags/example_gcs.py | 3 +-
.../example_dags/example_campaign_manager.py | 3 +-
tests/cli/commands/test_dag_command.py | 7 ++-
tests/models/test_dag.py | 60 +++++++++++++++++++---
7 files changed, 94 insertions(+), 41 deletions(-)
diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py
index 685ca21..55c40b4 100644
--- a/airflow/cli/commands/dag_command.py
+++ b/airflow/cli/commands/dag_command.py
@@ -39,6 +39,7 @@ from airflow.utils import cli as cli_utils
from airflow.utils.cli import get_dag, get_dag_by_file_location, process_subdir, sigint_handler
from airflow.utils.dot_renderer import render_dag
from airflow.utils.session import create_session, provide_session
+from airflow.utils.state import State
def _tabulate_dag_runs(dag_runs: List[DagRun], tablefmt: str = "fancy_grid") -> str:
@@ -123,6 +124,7 @@ def dag_backfill(args, dag=None):
end_date=args.end_date,
confirm_prompt=not args.yes,
include_subdags=True,
+ dag_run_state=State.NONE,
)
dag.run(
@@ -381,7 +383,7 @@ def dag_list_dag_runs(args, dag=None):
def dag_test(args, session=None):
"""Execute one single DagRun for a given DAG and execution date, using the DebugExecutor."""
dag = get_dag(subdir=args.subdir, dag_id=args.dag_id)
- dag.clear(start_date=args.execution_date, end_date=args.execution_date, reset_dag_runs=True)
+ dag.clear(start_date=args.execution_date, end_date=args.execution_date, dag_run_state=State.NONE)
try:
dag.run(executor=DebugExecutor(), start_date=args.execution_date, end_date=args.execution_date)
except BackfillUnfinished as e:
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index e6aafd3..dfb6409 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -27,7 +27,7 @@ import traceback
import warnings
from collections import OrderedDict
from datetime import datetime, timedelta
-from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union
+from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union, cast
import jinja2
import pendulum
@@ -297,7 +297,7 @@ class DAG(BaseDag, LoggingMixin):
template_searchpath = [template_searchpath]
self.template_searchpath = template_searchpath
self.template_undefined = template_undefined
- self.parent_dag = None # Gets set when DAGs are loaded
+ self.parent_dag: Optional[DAG] = None # Gets set when DAGs are loaded
self.last_loaded = timezone.utcnow()
self.safe_dag_id = dag_id.replace('.', '__dot__')
self.max_active_runs = max_active_runs
@@ -966,7 +966,7 @@ class DAG(BaseDag, LoggingMixin):
confirm_prompt=False,
include_subdags=True,
include_parentdag=True,
- reset_dag_runs=True,
+ dag_run_state: str = State.RUNNING,
dry_run=False,
session=None,
get_tis=False,
@@ -993,8 +993,7 @@ class DAG(BaseDag, LoggingMixin):
:type include_subdags: bool
:param include_parentdag: Clear tasks in the parent dag of the subdag.
:type include_parentdag: bool
- :param reset_dag_runs: Set state of dag to RUNNING
- :type reset_dag_runs: bool
+ :param dag_run_state: state to set DagRun to
:param dry_run: Find the tasks to clear but don't clear them.
:type dry_run: bool
:param session: The sqlalchemy session to use
@@ -1025,8 +1024,7 @@ class DAG(BaseDag, LoggingMixin):
tis = session.query(TI).filter(TI.dag_id == self.dag_id)
tis = tis.filter(TI.task_id.in_(self.task_ids))
- if include_parentdag and self.is_subdag:
-
+ if include_parentdag and self.is_subdag and self.parent_dag is not None:
p_dag = self.parent_dag.sub_dag(
task_regex=r"^{}$".format(self.dag_id.split('.')[1]),
include_upstream=False,
@@ -1039,7 +1037,7 @@ class DAG(BaseDag, LoggingMixin):
confirm_prompt=confirm_prompt,
include_subdags=include_subdags,
include_parentdag=False,
- reset_dag_runs=reset_dag_runs,
+ dag_run_state=dag_run_state,
get_tis=True,
session=session,
recursion_depth=recursion_depth,
@@ -1065,12 +1063,13 @@ class DAG(BaseDag, LoggingMixin):
instances = tis.all()
for ti in instances:
if ti.operator == ExternalTaskMarker.__name__:
- ti.task = self.get_task(ti.task_id)
+ task: ExternalTaskMarker = cast(ExternalTaskMarker, self.get_task(ti.task_id))
+ ti.task = task
if recursion_depth == 0:
# Maximum recursion depth allowed is the recursion_depth of the first
# ExternalTaskMarker in the tasks to be cleared.
- max_recursion_depth = ti.task.recursion_depth
+ max_recursion_depth = task.recursion_depth
if recursion_depth + 1 > max_recursion_depth:
# Prevent cycles or accidents.
@@ -1080,10 +1079,10 @@ class DAG(BaseDag, LoggingMixin):
.format(max_recursion_depth,
ExternalTaskMarker.__name__, ti.task_id))
ti.render_templates()
- external_tis = session.query(TI).filter(TI.dag_id == ti.task.external_dag_id,
- TI.task_id == ti.task.external_task_id,
+ external_tis = session.query(TI).filter(TI.dag_id == task.external_dag_id,
+ TI.task_id == task.external_task_id,
TI.execution_date ==
- pendulum.parse(ti.task.execution_date))
+ pendulum.parse(task.execution_date))
for tii in external_tis:
if not dag_bag:
@@ -1103,7 +1102,7 @@ class DAG(BaseDag, LoggingMixin):
confirm_prompt=confirm_prompt,
include_subdags=include_subdags,
include_parentdag=False,
- reset_dag_runs=reset_dag_runs,
+ dag_run_state=dag_run_state,
get_tis=True,
session=session,
recursion_depth=recursion_depth + 1,
@@ -1134,16 +1133,18 @@ class DAG(BaseDag, LoggingMixin):
do_it = utils.helpers.ask_yesno(question)
if do_it:
- clear_task_instances(tis,
- session,
- dag=self,
- )
- if reset_dag_runs:
- self.set_dag_runs_state(session=session,
- start_date=start_date,
- end_date=end_date,
- state=State.NONE,
- )
+ clear_task_instances(
+ tis,
+ session,
+ dag=self,
+ activate_dag_runs=False, # We will set DagRun state later.
+ )
+ self.set_dag_runs_state(
+ session=session,
+ start_date=start_date,
+ end_date=end_date,
+ state=dag_run_state,
+ )
else:
count = 0
print("Bail. Nothing was cleared.")
@@ -1161,7 +1162,7 @@ class DAG(BaseDag, LoggingMixin):
confirm_prompt=False,
include_subdags=True,
include_parentdag=False,
- reset_dag_runs=True,
+ dag_run_state=State.RUNNING,
dry_run=False,
):
all_tis = []
@@ -1174,7 +1175,7 @@ class DAG(BaseDag, LoggingMixin):
confirm_prompt=False,
include_subdags=include_subdags,
include_parentdag=include_parentdag,
- reset_dag_runs=reset_dag_runs,
+ dag_run_state=dag_run_state,
dry_run=True)
all_tis.extend(tis)
@@ -1202,7 +1203,7 @@ class DAG(BaseDag, LoggingMixin):
only_running=only_running,
confirm_prompt=False,
include_subdags=include_subdags,
- reset_dag_runs=reset_dag_runs,
+ dag_run_state=dag_run_state,
dry_run=False,
)
else:
diff --git a/airflow/providers/google/cloud/example_dags/example_datafusion.py b/airflow/providers/google/cloud/example_dags/example_datafusion.py
index e2b686c..62ab1d4 100644
--- a/airflow/providers/google/cloud/example_dags/example_datafusion.py
+++ b/airflow/providers/google/cloud/example_dags/example_datafusion.py
@@ -29,6 +29,7 @@ from airflow.providers.google.cloud.operators.datafusion import (
CloudDataFusionStopPipelineOperator, CloudDataFusionUpdateInstanceOperator,
)
from airflow.utils import dates
+from airflow.utils.state import State
# [START howto_data_fusion_env_variables]
LOCATION = "europe-north1"
@@ -227,5 +228,5 @@ with models.DAG(
delete_pipeline >> delete_instance
if __name__ == "__main__":
- dag.clear(reset_dag_runs=True)
+ dag.clear(dag_run_state=State.NONE)
dag.run()
diff --git a/airflow/providers/google/cloud/example_dags/example_gcs.py b/airflow/providers/google/cloud/example_dags/example_gcs.py
index 4cdac36..18f173f 100644
--- a/airflow/providers/google/cloud/example_dags/example_gcs.py
+++ b/airflow/providers/google/cloud/example_dags/example_gcs.py
@@ -32,6 +32,7 @@ from airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator
from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator
from airflow.providers.google.cloud.transfers.local_to_gcs import LocalFilesystemToGCSOperator
from airflow.utils.dates import days_ago
+from airflow.utils.state import State
default_args = {"start_date": days_ago(1)}
@@ -155,5 +156,5 @@ with models.DAG(
if __name__ == '__main__':
- dag.clear(reset_dag_runs=True)
+ dag.clear(dag_run_state=State.NONE)
dag.run()
diff --git a/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py b/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py
index ca82c93..74fb6d3 100644
--- a/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py
+++ b/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py
@@ -31,6 +31,7 @@ from airflow.providers.google.marketing_platform.sensors.campaign_manager import
GoogleCampaignManagerReportSensor,
)
from airflow.utils import dates
+from airflow.utils.state import State
PROFILE_ID = os.environ.get("MARKETING_PROFILE_ID", "123456789")
FLOODLIGHT_ACTIVITY_ID = os.environ.get("FLOODLIGHT_ACTIVITY_ID", 12345)
@@ -157,5 +158,5 @@ with models.DAG(
insert_conversion >> update_conversion
if __name__ == "__main__":
- dag.clear(reset_dag_runs=True)
+ dag.clear(dag_run_state=State.NONE)
dag.run()
diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py
index fa4a32e..6dda923 100644
--- a/tests/cli/commands/test_dag_command.py
+++ b/tests/cli/commands/test_dag_command.py
@@ -433,7 +433,8 @@ class TestCliDags(unittest.TestCase):
subdir=cli_args.subdir, dag_id='example_bash_operator'
),
mock.call().clear(
- start_date=cli_args.execution_date, end_date=cli_args.execution_date, reset_dag_runs=True
+ start_date=cli_args.execution_date, end_date=cli_args.execution_date,
+ dag_run_state=State.NONE,
),
mock.call().run(
executor=mock_executor.return_value,
@@ -461,7 +462,9 @@ class TestCliDags(unittest.TestCase):
subdir=cli_args.subdir, dag_id='example_bash_operator'
),
mock.call().clear(
- start_date=cli_args.execution_date, end_date=cli_args.execution_date, reset_dag_runs=True
+ start_date=cli_args.execution_date,
+ end_date=cli_args.execution_date,
+ dag_run_state=State.NONE,
),
mock.call().run(
executor=mock_executor.return_value,
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 8207b24..8891d56 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -25,6 +25,7 @@ import re
import unittest
from contextlib import redirect_stdout
from tempfile import NamedTemporaryFile
+from typing import Optional
from unittest import mock
from unittest.mock import patch
@@ -55,6 +56,12 @@ from tests.test_utils.db import clear_db_dags, clear_db_runs
class TestDag(unittest.TestCase):
+ def setUp(self) -> None:
+ clear_db_runs()
+
+ def tearDown(self) -> None:
+ clear_db_runs()
+
@staticmethod
def _clean_up(dag_id: str):
with create_session() as session:
@@ -1355,8 +1362,14 @@ class TestDag(unittest.TestCase):
dr = dag.create_dagrun(run_id="custom_is_set_to_manual", state=State.NONE)
assert dr.run_type == DagRunType.MANUAL.value
- def test_clear_reset_dagruns(self):
- dag_id = 'test_clear_dag_reset_dagruns'
+ @parameterized.expand(
+ [
+ (State.NONE,),
+ (State.RUNNING,),
+ ]
+ )
+ def test_clear_set_dagrun_state(self, dag_run_state):
+ dag_id = 'test_clear_set_dagrun_state'
self._clean_up(dag_id)
task_id = 't1'
dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1)
@@ -1365,7 +1378,7 @@ class TestDag(unittest.TestCase):
session = settings.Session()
dagrun_1 = dag.create_dagrun(
run_type=DagRunType.BACKFILL_JOB,
- state=State.RUNNING,
+ state=State.FAILED,
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
)
@@ -1378,7 +1391,7 @@ class TestDag(unittest.TestCase):
dag.clear(
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=1),
- reset_dag_runs=True,
+ dag_run_state=dag_run_state,
include_subdags=False,
include_parentdag=False,
session=session,
@@ -1392,17 +1405,48 @@ class TestDag(unittest.TestCase):
self.assertEqual(len(dagruns), 1)
dagrun = dagruns[0] # type: DagRun
- self.assertEqual(dagrun.state, State.NONE)
+ self.assertEqual(dagrun.state, dag_run_state)
+
+ @parameterized.expand([
+ (state, State.NONE)
+ for state in State.task_states if state != State.RUNNING
+ ] + [(State.RUNNING, State.SHUTDOWN)]) # type: ignore
+ def test_clear_dag(self, ti_state_begin, ti_state_end: Optional[str]):
+ dag_id = 'test_clear_dag'
+ self._clean_up(dag_id)
+ task_id = 't1'
+ dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1)
+ t_1 = DummyOperator(task_id=task_id, dag=dag)
+
+ session = settings.Session() # type: ignore
+ dagrun_1 = dag.create_dagrun(
+ run_type=DagRunType.BACKFILL_JOB,
+ state=State.RUNNING,
+ start_date=DEFAULT_DATE,
+ execution_date=DEFAULT_DATE,
+ )
+ session.merge(dagrun_1)
+
+ task_instance_1 = TI(t_1, execution_date=DEFAULT_DATE, state=ti_state_begin)
+ task_instance_1.job_id = 123
+ session.merge(task_instance_1)
+ session.commit()
+
+ dag.clear(
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=1),
+ session=session,
+ )
task_instances = session.query(
- DagRun,
+ TI,
).filter(
- DagRun.dag_id == dag_id,
+ TI.dag_id == dag_id,
).all()
self.assertEqual(len(task_instances), 1)
task_instance = task_instances[0] # type: TI
- self.assertEqual(task_instance.state, State.NONE)
+ self.assertEqual(task_instance.state, ti_state_end)
self._clean_up(dag_id)