This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new aa7a3b24fd0 Remove logical_date from APIs and Functions, use run_id instead (#42404)
aa7a3b24fd0 is described below
commit aa7a3b24fd0bc267ffca67c740fbe09f892ae39f
Author: Ankit Chaurasia <[email protected]>
AuthorDate: Wed Nov 20 19:46:20 2024 +0545
Remove logical_date from APIs and Functions, use run_id instead (#42404)
Co-authored-by: Tzu-ping Chung <[email protected]>
---
airflow/api/common/mark_tasks.py | 57 --------
airflow/api/common/trigger_dag.py | 4 +-
.../endpoints/task_instance_endpoint.py | 16 +--
.../api_connexion/schemas/task_instance_schema.py | 10 +-
airflow/cli/commands/task_command.py | 70 +++++++---
airflow/exceptions.py | 11 +-
airflow/models/dag.py | 145 ++++++++++-----------
airflow/models/dagrun.py | 22 +---
airflow/www/views.py | 36 +++--
newsfragments/42404.significant.rst | 6 +
.../endpoints/test_task_instance_endpoint.py | 68 ++--------
.../schemas/test_task_instance_schema.py | 6 +-
tests/cli/commands/test_task_command.py | 26 ++--
tests/models/test_dag.py | 12 +-
tests/models/test_dagrun.py | 7 +-
tests/operators/test_trigger_dagrun.py | 19 +--
tests/sensors/test_external_task_sensor.py | 26 +++-
17 files changed, 231 insertions(+), 310 deletions(-)
diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index a170e6901a5..b57d25498d2 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -27,7 +27,6 @@ from sqlalchemy.orm import lazyload
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
-from airflow.utils.helpers import exactly_one
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
@@ -87,7 +86,6 @@ def set_state(
*,
tasks: Collection[Operator | tuple[Operator, int]],
run_id: str | None = None,
- logical_date: datetime | None = None,
upstream: bool = False,
downstream: bool = False,
future: bool = False,
@@ -107,7 +105,6 @@ def set_state(
:param tasks: the iterable of tasks or (task, map_index) tuples from which
to work.
``task.dag`` needs to be set
:param run_id: the run_id of the dagrun to start looking from
- :param logical_date: the logical date from which to start looking
(deprecated)
:param upstream: Mark all parents (upstream tasks)
:param downstream: Mark all siblings (downstream tasks) of task_id
:param future: Mark all future tasks on the interval of the dag up until
@@ -121,21 +118,12 @@ def set_state(
if not tasks:
return []
- if not exactly_one(logical_date, run_id):
- raise ValueError("Exactly one of dag_run_id and logical_date must be
set")
-
- if logical_date and not timezone.is_localized(logical_date):
- raise ValueError(f"Received non-localized date {logical_date}")
-
task_dags = {task[0].dag if isinstance(task, tuple) else task.dag for task
in tasks}
if len(task_dags) > 1:
raise ValueError(f"Received tasks from multiple DAGs: {task_dags}")
dag = next(iter(task_dags))
if dag is None:
raise ValueError("Received tasks with no DAG")
-
- if logical_date:
- run_id = dag.get_dagrun(logical_date=logical_date,
session=session).run_id
if not run_id:
raise ValueError("Received tasks with no run_id")
@@ -279,7 +267,6 @@ def _set_dag_run_state(dag_id: str, run_id: str, state:
DagRunState, session: SA
def set_dag_run_state_to_success(
*,
dag: DAG,
- logical_date: datetime | None = None,
run_id: str | None = None,
commit: bool = False,
session: SASession = NEW_SESSION,
@@ -290,7 +277,6 @@ def set_dag_run_state_to_success(
Set for a specific logical date and its task instances to success.
:param dag: the DAG of which to alter state
- :param logical_date: the logical date from which to start
looking(deprecated)
:param run_id: the run_id to start looking from
:param commit: commit DAG and tasks to be altered to the database
:param session: database session
@@ -298,19 +284,8 @@ def set_dag_run_state_to_success(
otherwise list of tasks that will be updated
:raises: ValueError if dag or logical_date is invalid
"""
- if not exactly_one(logical_date, run_id):
- return []
-
if not dag:
return []
-
- if logical_date:
- if not timezone.is_localized(logical_date):
- raise ValueError(f"Received non-localized date {logical_date}")
- dag_run = dag.get_dagrun(logical_date=logical_date)
- if not dag_run:
- raise ValueError(f"DagRun with logical_date: {logical_date} not
found")
- run_id = dag_run.run_id
if not run_id:
raise ValueError(f"Invalid dag_run_id: {run_id}")
# Mark the dag run to success.
@@ -333,7 +308,6 @@ def set_dag_run_state_to_success(
def set_dag_run_state_to_failed(
*,
dag: DAG,
- logical_date: datetime | None = None,
run_id: str | None = None,
commit: bool = False,
session: SASession = NEW_SESSION,
@@ -344,27 +318,14 @@ def set_dag_run_state_to_failed(
Set for a specific logical date and its task instances to failed.
:param dag: the DAG of which to alter state
- :param logical_date: the logical date from which to start
looking(deprecated)
:param run_id: the DAG run_id to start looking from
:param commit: commit DAG and tasks to be altered to the database
:param session: database session
:return: If commit is true, list of tasks that have been updated,
otherwise list of tasks that will be updated
- :raises: AssertionError if dag or logical_date is invalid
"""
- if not exactly_one(logical_date, run_id):
- return []
if not dag:
return []
-
- if logical_date:
- if not timezone.is_localized(logical_date):
- raise ValueError(f"Received non-localized date {logical_date}")
- dag_run = dag.get_dagrun(logical_date=logical_date)
- if not dag_run:
- raise ValueError(f"DagRun with logical_date: {logical_date} not
found")
- run_id = dag_run.run_id
-
if not run_id:
raise ValueError(f"Invalid dag_run_id: {run_id}")
@@ -429,7 +390,6 @@ def __set_dag_run_state_to_running_or_queued(
*,
new_state: DagRunState,
dag: DAG,
- logical_date: datetime | None = None,
run_id: str | None = None,
commit: bool = False,
session: SASession,
@@ -438,7 +398,6 @@ def __set_dag_run_state_to_running_or_queued(
Set the dag run for a specific logical date to running.
:param dag: the DAG of which to alter state
- :param logical_date: the logical date from which to start looking
:param run_id: the id of the DagRun
:param commit: commit DAG and tasks to be altered to the database
:param session: database session
@@ -446,20 +405,8 @@ def __set_dag_run_state_to_running_or_queued(
otherwise list of tasks that will be updated
"""
res: list[TaskInstance] = []
-
- if not exactly_one(logical_date, run_id):
- return res
-
if not dag:
return res
-
- if logical_date:
- if not timezone.is_localized(logical_date):
- raise ValueError(f"Received non-localized date {logical_date}")
- dag_run = dag.get_dagrun(logical_date=logical_date)
- if not dag_run:
- raise ValueError(f"DagRun with logical_date: {logical_date} not
found")
- run_id = dag_run.run_id
if not run_id:
raise ValueError(f"DagRun with run_id: {run_id} not found")
# Mark the dag run to running.
@@ -474,7 +421,6 @@ def __set_dag_run_state_to_running_or_queued(
def set_dag_run_state_to_running(
*,
dag: DAG,
- logical_date: datetime | None = None,
run_id: str | None = None,
commit: bool = False,
session: SASession = NEW_SESSION,
@@ -487,7 +433,6 @@ def set_dag_run_state_to_running(
return __set_dag_run_state_to_running_or_queued(
new_state=DagRunState.RUNNING,
dag=dag,
- logical_date=logical_date,
run_id=run_id,
commit=commit,
session=session,
@@ -498,7 +443,6 @@ def set_dag_run_state_to_running(
def set_dag_run_state_to_queued(
*,
dag: DAG,
- logical_date: datetime | None = None,
run_id: str | None = None,
commit: bool = False,
session: SASession = NEW_SESSION,
@@ -511,7 +455,6 @@ def set_dag_run_state_to_queued(
return __set_dag_run_state_to_running_or_queued(
new_state=DagRunState.QUEUED,
dag=dag,
- logical_date=logical_date,
run_id=run_id,
commit=commit,
session=session,
diff --git a/airflow/api/common/trigger_dag.py
b/airflow/api/common/trigger_dag.py
index 4a94f990191..6891cc1df78 100644
--- a/airflow/api/common/trigger_dag.py
+++ b/airflow/api/common/trigger_dag.py
@@ -85,10 +85,10 @@ def _trigger_dag(
run_id = run_id or dag.timetable.generate_run_id(
run_type=DagRunType.MANUAL, logical_date=coerced_logical_date,
data_interval=data_interval
)
- dag_run = DagRun.find_duplicate(dag_id=dag_id, run_id=run_id,
logical_date=logical_date)
+ dag_run = DagRun.find_duplicate(dag_id=dag_id, run_id=run_id)
if dag_run:
- raise DagRunAlreadyExists(dag_run, logical_date=logical_date,
run_id=run_id)
+ raise DagRunAlreadyExists(dag_run)
run_conf = None
if conf:
diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py
b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index 7f43c160f32..00eb51bae10 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -522,19 +522,10 @@ def post_set_task_instances_state(*, dag_id: str,
session: Session = NEW_SESSION
if not task:
error_message = f"Task ID {task_id} not found"
raise NotFound(error_message)
-
- logical_date = data.get("logical_date")
run_id = data.get("dag_run_id")
- if (
- logical_date
- and (
- session.scalars(
- select(TI).where(TI.task_id == task_id, TI.dag_id == dag_id,
TI.logical_date == logical_date)
- ).one_or_none()
- )
- is None
- ):
- raise NotFound(detail=f"Task instance not found for task {task_id!r}
on logical_date {logical_date}")
+ if not run_id:
+ error_message = f"Task instance not found for task {task_id!r} on DAG
run with ID {run_id!r}"
+ raise NotFound(detail=error_message)
select_stmt = select(TI).where(
TI.dag_id == dag_id, TI.task_id == task_id, TI.run_id == run_id,
TI.map_index == -1
@@ -547,7 +538,6 @@ def post_set_task_instances_state(*, dag_id: str, session:
Session = NEW_SESSION
tis = dag.set_task_instance_state(
task_id=task_id,
run_id=run_id,
- logical_date=logical_date,
state=data["new_state"],
upstream=data["include_upstream"],
downstream=data["include_downstream"],
diff --git a/airflow/api_connexion/schemas/task_instance_schema.py
b/airflow/api_connexion/schemas/task_instance_schema.py
index 3e864f18652..360ecdf277e 100644
--- a/airflow/api_connexion/schemas/task_instance_schema.py
+++ b/airflow/api_connexion/schemas/task_instance_schema.py
@@ -29,7 +29,6 @@ from airflow.api_connexion.schemas.job_schema import JobSchema
from airflow.api_connexion.schemas.trigger_schema import TriggerSchema
from airflow.models import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory
-from airflow.utils.helpers import exactly_one
from airflow.utils.state import TaskInstanceState
@@ -196,8 +195,7 @@ class SetTaskInstanceStateFormSchema(Schema):
dry_run = fields.Boolean(load_default=True)
task_id = fields.Str(required=True)
- logical_date = fields.DateTime(validate=validate_istimezone)
- dag_run_id = fields.Str()
+ dag_run_id = fields.Str(required=True)
include_upstream = fields.Boolean(required=True)
include_downstream = fields.Boolean(required=True)
include_future = fields.Boolean(required=True)
@@ -209,12 +207,6 @@ class SetTaskInstanceStateFormSchema(Schema):
),
)
- @validates_schema
- def validate_form(self, data, **kwargs):
- """Validate set task instance state form."""
- if not exactly_one(data.get("logical_date"), data.get("dag_run_id")):
- raise ValidationError("Exactly one of logical_date or dag_run_id
must be provided")
-
class SetSingleTaskInstanceStateFormSchema(Schema):
"""Schema for handling the request of updating state of a single task
instance."""
diff --git a/airflow/cli/commands/task_command.py
b/airflow/cli/commands/task_command.py
index 396186bf14d..2b5a6c18a80 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -91,11 +91,46 @@ def _generate_temporary_run_id() -> str:
return f"__airflow_temporary_run_{timezone.utcnow().isoformat()}__"
+def _fetch_dag_run_from_run_id_or_logical_date_string(
+ *,
+ dag_id: str,
+ value: str,
+ session: Session,
+) -> tuple[DagRun | DagRunPydantic, pendulum.DateTime | None]:
+ """
+ Try to find a DAG run with a given string value.
+
+ The string value may be a run ID, or a logical date in string form. We
first
+ try to use it as a run_id; if a run is found, it is returned as-is.
+
+ Otherwise, the string value is parsed into a datetime. If that works, it is
+ used to find a DAG run.
+
+ The return value is a two-tuple. The first item is the found DAG run (or
+ *None* if one cannot be found). The second is the parsed logical date. This
+ second value can be used to create a new run by the calling function when
+ one cannot be found here.
+ """
+ if dag_run := DAG.fetch_dagrun(dag_id=dag_id, run_id=value,
session=session):
+ return dag_run, dag_run.logical_date # type: ignore[return-value]
+ try:
+ logical_date = timezone.parse(value)
+ except (ParserError, TypeError):
+ return dag_run, None
+ dag_run = session.scalar(
+ select(DagRun)
+ .where(DagRun.dag_id == dag_id, DagRun.logical_date == logical_date)
+ .order_by(DagRun.id.desc())
+ .limit(1)
+ )
+ return dag_run, logical_date
+
+
def _get_dag_run(
*,
dag: DAG,
create_if_necessary: CreateIfNecessary,
- exec_date_or_run_id: str | None = None,
+ logical_date_or_run_id: str | None = None,
session: Session | None = None,
) -> tuple[DagRun | DagRunPydantic, bool]:
"""
@@ -103,7 +138,7 @@ def _get_dag_run(
This checks DAG runs like this:
- 1. If the input ``exec_date_or_run_id`` matches a DAG run ID, return the
run.
+ 1. If the input ``logical_date_or_run_id`` matches a DAG run ID, return
the run.
2. Try to parse the input as a date. If that works, and the resulting
date matches a DAG run's logical date, return the run.
3. If ``create_if_necessary`` is *False* and the input works for neither of
@@ -112,23 +147,22 @@ def _get_dag_run(
the logical date; otherwise use it as a run ID and set the logical date
to the current time.
"""
- if not exec_date_or_run_id and not create_if_necessary:
- raise ValueError("Must provide `exec_date_or_run_id` if not
`create_if_necessary`.")
- logical_date: pendulum.DateTime | None = None
- if exec_date_or_run_id:
- dag_run = DAG.fetch_dagrun(dag_id=dag.dag_id,
run_id=exec_date_or_run_id, session=session)
- if dag_run:
- return dag_run, False
- with suppress(ParserError, TypeError):
- logical_date = timezone.parse(exec_date_or_run_id)
- if logical_date:
- dag_run = DAG.fetch_dagrun(dag_id=dag.dag_id,
logical_date=logical_date, session=session)
- if dag_run:
+ if not logical_date_or_run_id and not create_if_necessary:
+ raise ValueError("Must provide `logical_date_or_run_id` if not
`create_if_necessary`.")
+
+ logical_date = None
+ if logical_date_or_run_id:
+ dag_run, logical_date =
_fetch_dag_run_from_run_id_or_logical_date_string(
+ dag_id=dag.dag_id,
+ value=logical_date_or_run_id,
+ session=session,
+ )
+ if dag_run is not None:
return dag_run, False
elif not create_if_necessary:
raise DagRunNotFound(
f"DagRun for {dag.dag_id} with run_id or logical_date "
- f"of {exec_date_or_run_id!r} not found"
+ f"of {logical_date_or_run_id!r} not found"
)
if logical_date is not None:
@@ -139,7 +173,7 @@ def _get_dag_run(
if create_if_necessary == "memory":
dag_run = DagRun(
dag_id=dag.dag_id,
- run_id=exec_date_or_run_id,
+ run_id=logical_date_or_run_id,
logical_date=dag_run_logical_date,
data_interval=dag.timetable.infer_manual_data_interval(run_after=dag_run_logical_date),
triggered_by=DagRunTriggeredByType.CLI,
@@ -178,7 +212,7 @@ def _get_ti_db_access(
raise ValueError(f"Provided task {task.task_id} is not in dag
'{dag.dag_id}.")
if not logical_date_or_run_id and not create_if_necessary:
- raise ValueError("Must provide `exec_date_or_run_id` if not
`create_if_necessary`.")
+ raise ValueError("Must provide `logical_date_or_run_id` if not
`create_if_necessary`.")
if task.get_needs_expansion():
if map_index < 0:
raise RuntimeError("No map_index passed to mapped task")
@@ -186,7 +220,7 @@ def _get_ti_db_access(
raise RuntimeError("map_index passed to non-mapped task")
dag_run, dr_created = _get_dag_run(
dag=dag,
- exec_date_or_run_id=logical_date_or_run_id,
+ logical_date_or_run_id=logical_date_or_run_id,
create_if_necessary=create_if_necessary,
session=session,
)
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index 3b07b9a6fda..fee0b5a671d 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -230,13 +230,9 @@ class DagRunNotFound(AirflowNotFoundException):
class DagRunAlreadyExists(AirflowBadRequest):
"""Raise when creating a DAG run for DAG which already has DAG run
entry."""
- def __init__(self, dag_run: DagRun, logical_date: datetime.datetime,
run_id: str) -> None:
- super().__init__(
- f"A DAG Run already exists for DAG {dag_run.dag_id} at
{logical_date} with run id {run_id}"
- )
+ def __init__(self, dag_run: DagRun) -> None:
+ super().__init__(f"A DAG Run already exists for DAG {dag_run.dag_id}
with run id {dag_run.run_id}")
self.dag_run = dag_run
- self.logical_date = logical_date
- self.run_id = run_id
def serialize(self):
cls = self.__class__
@@ -249,13 +245,12 @@ class DagRunAlreadyExists(AirflowBadRequest):
run_id=self.dag_run.run_id,
external_trigger=self.dag_run.external_trigger,
run_type=self.dag_run.run_type,
- logical_date=self.dag_run.logical_date,
)
dag_run.id = self.dag_run.id
return (
f"{cls.__module__}.{cls.__name__}",
(),
- {"dag_run": dag_run, "logical_date": self.logical_date, "run_id":
self.run_id},
+ {"dag_run": dag_run},
)
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index d898e1a52f4..8177025c2d8 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -27,6 +27,7 @@ import time
from collections import defaultdict
from contextlib import ExitStack
from datetime import datetime, timedelta
+from functools import cache
from typing import (
TYPE_CHECKING,
Any,
@@ -108,7 +109,6 @@ from airflow.timetables.simple import (
)
from airflow.utils import timezone
from airflow.utils.dag_cycle_tester import check_cycle
-from airflow.utils.helpers import exactly_one
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime, lock_rows,
tuple_in_condition, with_row_locks
@@ -879,38 +879,20 @@ class DAG(TaskSDKDag, LoggingMixin):
@staticmethod
@internal_api_call
@provide_session
- def fetch_dagrun(
- dag_id: str,
- logical_date: datetime | None = None,
- run_id: str | None = None,
- session: Session = NEW_SESSION,
- ) -> DagRun | DagRunPydantic:
+ def fetch_dagrun(dag_id: str, run_id: str, session: Session = NEW_SESSION)
-> DagRun | DagRunPydantic:
"""
- Return the dag run for a given logical date or run_id if it exists,
otherwise none.
+ Return the dag run for a given run_id if it exists, otherwise none.
:param dag_id: The dag_id of the DAG to find.
- :param logical_date: The logical date of the DagRun to find.
:param run_id: The run_id of the DagRun to find.
:param session:
:return: The DagRun if found, otherwise None.
"""
- if not (logical_date or run_id):
- raise TypeError("You must provide either the logical_date or the
run_id")
- query = select(DagRun)
- if logical_date:
- query = query.where(DagRun.dag_id == dag_id, DagRun.logical_date
== logical_date)
- if run_id:
- query = query.where(DagRun.dag_id == dag_id, DagRun.run_id ==
run_id)
- return session.scalar(query)
+ return session.scalar(select(DagRun).where(DagRun.dag_id == dag_id,
DagRun.run_id == run_id))
@provide_session
- def get_dagrun(
- self,
- logical_date: datetime | None = None,
- run_id: str | None = None,
- session: Session = NEW_SESSION,
- ) -> DagRun | DagRunPydantic:
- return DAG.fetch_dagrun(dag_id=self.dag_id, logical_date=logical_date,
run_id=run_id, session=session)
+ def get_dagrun(self, run_id: str, session: Session = NEW_SESSION) ->
DagRun | DagRunPydantic:
+ return DAG.fetch_dagrun(dag_id=self.dag_id, run_id=run_id,
session=session)
@provide_session
def get_dagruns_between(self, start_date, end_date, session=NEW_SESSION):
@@ -992,6 +974,7 @@ class DAG(TaskSDKDag, LoggingMixin):
state=state or (),
include_dependent_dags=False,
exclude_task_ids=(),
+ exclude_run_ids=None,
session=session,
)
return session.scalars(cast(Select,
query).order_by(DagRun.logical_date)).all()
@@ -1007,6 +990,7 @@ class DAG(TaskSDKDag, LoggingMixin):
state: TaskInstanceState | Sequence[TaskInstanceState],
include_dependent_dags: bool,
exclude_task_ids: Collection[str | tuple[str, int]] | None,
+ exclude_run_ids: frozenset[str] | None,
session: Session,
dag_bag: DagBag | None = ...,
) -> Iterable[TaskInstance]: ... # pragma: no cover
@@ -1023,6 +1007,7 @@ class DAG(TaskSDKDag, LoggingMixin):
state: TaskInstanceState | Sequence[TaskInstanceState],
include_dependent_dags: bool,
exclude_task_ids: Collection[str | tuple[str, int]] | None,
+ exclude_run_ids: frozenset[str] | None,
session: Session,
dag_bag: DagBag | None = ...,
recursion_depth: int = ...,
@@ -1041,6 +1026,7 @@ class DAG(TaskSDKDag, LoggingMixin):
state: TaskInstanceState | Sequence[TaskInstanceState],
include_dependent_dags: bool,
exclude_task_ids: Collection[str | tuple[str, int]] | None,
+ exclude_run_ids: frozenset[str] | None,
session: Session,
dag_bag: DagBag | None = None,
recursion_depth: int = 0,
@@ -1098,6 +1084,9 @@ class DAG(TaskSDKDag, LoggingMixin):
else:
tis = tis.where(TaskInstance.state.in_(state))
+ if exclude_run_ids:
+ tis = tis.where(not_(TaskInstance.run_id.in_(exclude_run_ids)))
+
if include_dependent_dags:
# Recursively find external tasks indicated by ExternalTaskMarker
from airflow.sensors.external_task import ExternalTaskMarker
@@ -1170,6 +1159,7 @@ class DAG(TaskSDKDag, LoggingMixin):
include_dependent_dags=include_dependent_dags,
as_pk_tuple=True,
exclude_task_ids=exclude_task_ids,
+ exclude_run_ids=exclude_run_ids,
dag_bag=dag_bag,
session=session,
recursion_depth=recursion_depth + 1,
@@ -1216,7 +1206,6 @@ class DAG(TaskSDKDag, LoggingMixin):
*,
task_id: str,
map_indexes: Collection[int] | None = None,
- logical_date: datetime | None = None,
run_id: str | None = None,
state: TaskInstanceState,
upstream: bool = False,
@@ -1232,7 +1221,6 @@ class DAG(TaskSDKDag, LoggingMixin):
:param task_id: Task ID of the TaskInstance
:param map_indexes: Only set TaskInstance if its map_index matches.
If None (default), all mapped TaskInstances of the task are set.
- :param logical_date: Logical date of the TaskInstance
:param run_id: The run_id of the TaskInstance
:param state: State to set the TaskInstance to
:param upstream: Include all upstream tasks of the given task_id
@@ -1243,9 +1231,6 @@ class DAG(TaskSDKDag, LoggingMixin):
"""
from airflow.api.common.mark_tasks import set_state
- if not exactly_one(logical_date, run_id):
- raise ValueError("Exactly one of logical_date or run_id must be
provided")
-
task = self.get_task(task_id)
task.dag = self
@@ -1257,7 +1242,6 @@ class DAG(TaskSDKDag, LoggingMixin):
altered = set_state(
tasks=tasks_to_set_state,
- logical_date=logical_date,
run_id=run_id,
upstream=upstream,
downstream=downstream,
@@ -1280,26 +1264,37 @@ class DAG(TaskSDKDag, LoggingMixin):
include_upstream=False,
)
- if logical_date is None:
- dag_run = session.scalars(
- select(DagRun).where(DagRun.run_id == run_id, DagRun.dag_id ==
self.dag_id)
- ).one() # Raises an error if not found
- resolve_logical_date = dag_run.logical_date
- else:
- resolve_logical_date = logical_date
-
- end_date = resolve_logical_date if not future else None
- start_date = resolve_logical_date if not past else None
-
- subdag.clear(
- start_date=start_date,
- end_date=end_date,
- only_failed=True,
- session=session,
- # Exclude the task itself from being cleared
- exclude_task_ids=frozenset({task_id}),
- )
-
+ # Raises an error if not found
+ dr_id, logical_date = session.execute(
+ select(DagRun.id, DagRun.logical_date).where(
+ DagRun.run_id == run_id, DagRun.dag_id == self.dag_id
+ )
+ ).one()
+
+ # Now we want to clear downstreams of tasks that had their state set...
+ clear_kwargs = {
+ "only_failed": True,
+ "session": session,
+ # Exclude the task itself from being cleared.
+ "exclude_task_ids": frozenset((task_id,)),
+ }
+ if not future and not past: # Simple case 1: we're only dealing with
exactly one run.
+ clear_kwargs["run_id"] = run_id
+ subdag.clear(**clear_kwargs)
+ elif future and past: # Simple case 2: we're clearing ALL runs.
+ subdag.clear(**clear_kwargs)
+ else: # Complex cases: we may have more than one run, based on a date
range.
+ # Make 'future' and 'past' make some sense when multiple runs exist
+ # for the same logical date. We order runs by their id and only
+ # clear runs have larger/smaller ids.
+ exclude_run_id_stmt =
select(DagRun.run_id).where(DagRun.logical_date == logical_date)
+ if future:
+ clear_kwargs["start_date"] = logical_date
+ exclude_run_id_stmt = exclude_run_id_stmt.where(DagRun.id >
dr_id)
+ else:
+ clear_kwargs["end_date"] = logical_date
+ exclude_run_id_stmt = exclude_run_id_stmt.where(DagRun.id <
dr_id)
+
subdag.clear(exclude_run_ids=frozenset(session.scalars(exclude_run_id_stmt)),
**clear_kwargs)
return altered
@provide_session
@@ -1307,7 +1302,6 @@ class DAG(TaskSDKDag, LoggingMixin):
self,
*,
group_id: str,
- logical_date: datetime | None = None,
run_id: str | None = None,
state: TaskInstanceState,
upstream: bool = False,
@@ -1321,7 +1315,6 @@ class DAG(TaskSDKDag, LoggingMixin):
Set TaskGroup to the given state and clear downstream tasks in failed
or upstream_failed state.
:param group_id: The group_id of the TaskGroup
- :param logical_date: Logical date of the TaskInstance
:param run_id: The run_id of the TaskInstance
:param state: State to set the TaskInstance to
:param upstream: Include all upstream tasks of the given task_id
@@ -1333,23 +1326,9 @@ class DAG(TaskSDKDag, LoggingMixin):
"""
from airflow.api.common.mark_tasks import set_state
- if not exactly_one(logical_date, run_id):
- raise ValueError("Exactly one of logical_date or run_id must be
provided")
-
tasks_to_set_state: list[BaseOperator | tuple[BaseOperator, int]] = []
task_ids: list[str] = []
- if logical_date is None:
- dag_run = session.scalars(
- select(DagRun).where(DagRun.run_id == run_id, DagRun.dag_id ==
self.dag_id)
- ).one() # Raises an error if not found
- resolve_logical_date = dag_run.logical_date
- else:
- resolve_logical_date = logical_date
-
- end_date = resolve_logical_date if not future else None
- start_date = resolve_logical_date if not past else None
-
task_group_dict = self.task_group.get_task_group_dict()
task_group = task_group_dict.get(group_id)
if task_group is None:
@@ -1357,18 +1336,25 @@ class DAG(TaskSDKDag, LoggingMixin):
tasks_to_set_state = [task for task in task_group.iter_tasks() if
isinstance(task, BaseOperator)]
task_ids = [task.task_id for task in task_group.iter_tasks()]
dag_runs_query = select(DagRun.id).where(DagRun.dag_id == self.dag_id)
- if start_date is None and end_date is None:
- dag_runs_query = dag_runs_query.where(DagRun.logical_date ==
start_date)
- else:
- if start_date is not None:
- dag_runs_query = dag_runs_query.where(DagRun.logical_date >=
start_date)
- if end_date is not None:
- dag_runs_query = dag_runs_query.where(DagRun.logical_date <=
end_date)
+
+ @cache
+ def get_logical_date() -> datetime:
+ stmt = select(DagRun.logical_date).where(DagRun.run_id == run_id,
DagRun.dag_id == self.dag_id)
+ return session.scalars(stmt).one() # Raises an error if not found
+
+ end_date = None if future else get_logical_date()
+ start_date = None if past else get_logical_date()
+
+ if future:
+ dag_runs_query = dag_runs_query.where(DagRun.logical_date <=
start_date)
+ if past:
+ dag_runs_query = dag_runs_query.where(DagRun.logical_date >=
end_date)
+ if not future and not past:
+ dag_runs_query = dag_runs_query.where(DagRun.run_id == run_id)
with lock_rows(dag_runs_query, session):
altered = set_state(
tasks=tasks_to_set_state,
- logical_date=logical_date,
run_id=run_id,
upstream=upstream,
downstream=downstream,
@@ -1416,6 +1402,7 @@ class DAG(TaskSDKDag, LoggingMixin):
session: Session = NEW_SESSION,
dag_bag: DagBag | None = None,
exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None =
frozenset(),
+ exclude_run_ids: frozenset[str] | None = frozenset(),
) -> list[TaskInstance]: ... # pragma: no cover
@overload
@@ -1433,12 +1420,15 @@ class DAG(TaskSDKDag, LoggingMixin):
session: Session = NEW_SESSION,
dag_bag: DagBag | None = None,
exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None =
frozenset(),
+ exclude_run_ids: frozenset[str] | None = frozenset(),
) -> int: ... # pragma: no cover
@provide_session
def clear(
self,
task_ids: Collection[str | tuple[str, int]] | None = None,
+ *,
+ run_id: str | None = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
only_failed: bool = False,
@@ -1449,7 +1439,8 @@ class DAG(TaskSDKDag, LoggingMixin):
session: Session = NEW_SESSION,
dag_bag: DagBag | None = None,
exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None =
frozenset(),
- ) -> int | list[TaskInstance]:
+ exclude_run_ids: frozenset[str] | None = frozenset(),
+ ) -> int | Iterable[TaskInstance]:
"""
Clear a set of task instances associated with the current dag for a
specified date range.
@@ -1466,6 +1457,7 @@ class DAG(TaskSDKDag, LoggingMixin):
:param dag_bag: The DagBag used to find the dags (Optional)
:param exclude_task_ids: A set of ``task_id`` or (``task_id``,
``map_index``)
tuples that should not be cleared
+ :param exclude_run_ids: A set of ``run_id`` or (``run_id``)
"""
state: list[TaskInstanceState] = []
if only_failed:
@@ -1478,12 +1470,13 @@ class DAG(TaskSDKDag, LoggingMixin):
task_ids=task_ids,
start_date=start_date,
end_date=end_date,
- run_id=None,
+ run_id=run_id,
state=state,
include_dependent_dags=True,
session=session,
dag_bag=dag_bag,
exclude_task_ids=exclude_task_ids,
+ exclude_run_ids=exclude_run_ids,
)
if dry_run:
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index a2327221ad5..b535d8729ea 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -591,31 +591,21 @@ class DagRun(Base, LoggingMixin):
@classmethod
@provide_session
- def find_duplicate(
- cls,
- dag_id: str,
- run_id: str,
- logical_date: datetime,
- session: Session = NEW_SESSION,
- ) -> DagRun | None:
+ def find_duplicate(cls, dag_id: str, run_id: str, *, session: Session =
NEW_SESSION) -> DagRun | None:
"""
- Return an existing run for the DAG with a specific run_id or logical
date.
+ Return an existing run for the DAG with a specific run_id.
+
+ *None* is returned if no such DAG run is found.
:param dag_id: the dag_id to find duplicates for
:param run_id: defines the run id for this dag run
- :param logical_date: the logical date
:param session: database session
"""
- return session.scalars(
- select(cls).where(
- cls.dag_id == dag_id,
- or_(cls.run_id == run_id, cls.logical_date == logical_date),
- )
- ).one_or_none()
+ return session.scalars(select(cls).where(cls.dag_id == dag_id,
cls.run_id == run_id)).one_or_none()
@staticmethod
def generate_run_id(run_type: DagRunType, logical_date: datetime) -> str:
- """Generate Run ID based on Run Type and Logical Date."""
+ """Generate Run ID based on Run Type and logical Date."""
# _Ensure_ run_type is a DagRunType, not just a string from user code
return DagRunType(run_type).generate_run_id(logical_date)
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 805a746fba5..e58055b9f1c 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1417,7 +1417,10 @@ class Airflow(AirflowBaseView):
logger.info("Retrieving rendered templates.")
dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id)
- dag_run = dag.get_dagrun(logical_date=dttm)
+ dag_run = dag.get_dagrun(
+ select(DagRun.run_id).where(DagRun.logical_date ==
dttm).order_by(DagRun.id.desc()).limit(1),
+ session=session,
+ )
raw_task = dag.get_task(task_id).prepare_for_execution()
no_dagrun = False
@@ -1550,10 +1553,11 @@ class Airflow(AirflowBaseView):
dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id)
task = dag.get_task(task_id)
+ run_id = session.scalar(
+ select(DagRun.run_id).where(DagRun.logical_date ==
dttm).order_by(DagRun.id.desc()).limit(1)
+ )
dag_run = dag.get_dagrun(
- run_id=session.scalar(
- select(DagRun.run_id).where(DagRun.logical_date ==
dttm).order_by(DagRun.id.desc()).limit(1)
- ),
+ run_id=run_id,
session=session,
)
ti = dag_run.get_task_instance(task_id=task.task_id,
map_index=map_index, session=session)
@@ -2144,7 +2148,7 @@ class Airflow(AirflowBaseView):
form=form,
)
- dr = DagRun.find_duplicate(dag_id=dag_id, run_id=run_id,
logical_date=logical_date)
+ dr = DagRun.find_duplicate(dag_id=dag_id, run_id=run_id,
session=session)
if dr:
if dr.run_id == run_id:
message = f"The run ID {run_id} already exists"
@@ -2408,7 +2412,7 @@ class Airflow(AirflowBaseView):
only_failed = request.form.get("only_failed") == "true"
dag = get_airflow_app().dag_bag.get_dag(dag_id)
- dr = dag.get_dagrun(run_id=dag_run_id)
+ dr = dag.get_dagrun(run_id=dag_run_id, session=session)
start_date = dr.logical_date
end_date = dr.logical_date
@@ -3060,8 +3064,16 @@ class Airflow(AirflowBaseView):
flash(f'DAG "{dag_id}" seems to be missing from DagBag.', "error")
return redirect(url_for("Airflow.index"))
dt_nr_dr_data = get_date_time_num_runs_dag_runs_form_data(request,
session, dag)
- dttm = dt_nr_dr_data["dttm"]
- dag_run = dag.get_dagrun(logical_date=dttm)
+ run_id = session.scalar(
+ select(DagRun.run_id)
+ .where(DagRun.logical_date == dt_nr_dr_data["dttm"])
+ .order_by(DagRun.id.desc())
+ .limit(1)
+ )
+ dag_run = dag.get_dagrun(
+ run_id=run_id,
+ session=session,
+ )
dag_run_id = dag_run.run_id if dag_run else None
kwargs = {
@@ -3136,7 +3148,13 @@ class Airflow(AirflowBaseView):
dag = get_airflow_app().dag_bag.get_dag(dag_id, session=session)
dt_nr_dr_data = get_date_time_num_runs_dag_runs_form_data(request,
session, dag)
dttm = dt_nr_dr_data["dttm"]
- dag_run = dag.get_dagrun(logical_date=dttm)
+ run_id = session.scalar(
+ select(DagRun.run_id).where(DagRun.logical_date ==
dttm).order_by(DagRun.id.desc()).limit(1)
+ )
+ dag_run = dag.get_dagrun(
+ run_id=run_id,
+ session=session,
+ )
dag_run_id = dag_run.run_id if dag_run else None
kwargs = {**sanitize_args(request.args), "dag_id": dag_id, "tab":
"gantt", "dag_run_id": dag_run_id}
diff --git a/newsfragments/42404.significant.rst b/newsfragments/42404.significant.rst
new file mode 100644
index 00000000000..47546b76ffa
--- /dev/null
+++ b/newsfragments/42404.significant.rst
@@ -0,0 +1,6 @@
+Removed ``logical_date`` arguments from functions and APIs for DAG run lookups to align with Airflow 3.0.
+
+The shift towards using ``run_id`` as the sole identifier for DAG runs eliminates the limitations of ``execution_date`` and ``logical_date``, particularly for dynamic DAG runs and cases where multiple runs occur at the same logical time. This change impacts database models, templates, and functions:
+
+- Removed ``logical_date`` arguments from public APIs and Python functions related to DAG run lookups.
+- ``run_id`` is now the exclusive identifier for DAG runs in these contexts.
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index a14a0be33a5..92ad62b6788 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -1721,6 +1721,7 @@ class
TestPostSetTaskInstanceState(TestTaskInstanceEndpoint):
@mock.patch("airflow.models.dag.DAG.set_task_instance_state")
def test_should_assert_call_mocked_api(self, mock_set_task_instance_state,
session):
self.create_task_instances(session)
+ run_id = "TEST_DAG_RUN_ID"
mock_set_task_instance_state.return_value = (
session.query(TaskInstance)
.join(TaskInstance.dag_run)
@@ -1734,7 +1735,7 @@ class
TestPostSetTaskInstanceState(TestTaskInstanceEndpoint):
json={
"dry_run": True,
"task_id": "print_the_context",
- "logical_date": DEFAULT_DATETIME_1.isoformat(),
+ "dag_run_id": run_id,
"include_upstream": True,
"include_downstream": True,
"include_future": True,
@@ -1757,8 +1758,7 @@ class
TestPostSetTaskInstanceState(TestTaskInstanceEndpoint):
mock_set_task_instance_state.assert_called_once_with(
commit=False,
downstream=True,
- run_id=None,
- logical_date=DEFAULT_DATETIME_1,
+ run_id=run_id,
future=True,
past=True,
state="failed",
@@ -1807,7 +1807,6 @@ class
TestPostSetTaskInstanceState(TestTaskInstanceEndpoint):
commit=False,
downstream=True,
run_id=run_id,
- logical_date=None,
future=True,
past=True,
state="failed",
@@ -1820,7 +1819,7 @@ class
TestPostSetTaskInstanceState(TestTaskInstanceEndpoint):
"error, code, payload",
[
[
- "{'_schema': ['Exactly one of logical_date or dag_run_id must
be provided']}",
+ "{'dag_run_id': ['Missing data for required field.']}",
400,
{
"dry_run": True,
@@ -1833,9 +1832,8 @@ class
TestPostSetTaskInstanceState(TestTaskInstanceEndpoint):
},
],
[
- "Task instance not found for task 'print_the_context' on
logical_date "
- "2021-01-01 00:00:00+00:00",
- 404,
+ "{'dag_run_id': ['Missing data for required field.'],
'logical_date': ['Unknown field.']}",
+ 400,
{
"dry_run": True,
"task_id": "print_the_context",
@@ -1862,7 +1860,7 @@ class
TestPostSetTaskInstanceState(TestTaskInstanceEndpoint):
},
],
[
- "{'_schema': ['Exactly one of logical_date or dag_run_id must
be provided']}",
+ "{'logical_date': ['Unknown field.']}",
400,
{
"dry_run": True,
@@ -1928,7 +1926,7 @@ class
TestPostSetTaskInstanceState(TestTaskInstanceEndpoint):
json={
"dry_run": True,
"task_id": "print_the_context",
- "logical_date": DEFAULT_DATETIME_1.isoformat(),
+ "dag_run_id": "random_run_id",
"include_upstream": True,
"include_downstream": True,
"include_future": True,
@@ -1941,14 +1939,14 @@ class
TestPostSetTaskInstanceState(TestTaskInstanceEndpoint):
@mock.patch("airflow.models.dag.DAG.set_task_instance_state")
def test_should_raise_not_found_if_run_id_is_wrong(self,
mock_set_task_instance_state, session):
self.create_task_instances(session)
- date = DEFAULT_DATETIME_1 + dt.timedelta(days=1)
+ run_id = "random_run_id"
response = self.client.post(
"/api/v1/dags/example_python_operator/updateTaskInstancesState",
environ_overrides={"REMOTE_USER": "test"},
json={
"dry_run": True,
"task_id": "print_the_context",
- "logical_date": date.isoformat(),
+ "dag_run_id": run_id,
"include_upstream": True,
"include_downstream": True,
"include_future": True,
@@ -1958,7 +1956,7 @@ class
TestPostSetTaskInstanceState(TestTaskInstanceEndpoint):
)
assert response.status_code == 404
assert response.json["detail"] == (
- f"Task instance not found for task 'print_the_context' on
logical_date {date}"
+ f"Task instance not found for task 'print_the_context' on DAG run
with ID '{run_id}'"
)
assert mock_set_task_instance_state.call_count == 0
@@ -1969,7 +1967,7 @@ class
TestPostSetTaskInstanceState(TestTaskInstanceEndpoint):
json={
"dry_run": True,
"task_id": "INVALID_TASK",
- "logical_date": DEFAULT_DATETIME_1.isoformat(),
+ "dag_run_id": "TEST_DAG_RUN_ID",
"include_upstream": True,
"include_downstream": True,
"include_future": True,
@@ -1979,48 +1977,6 @@ class
TestPostSetTaskInstanceState(TestTaskInstanceEndpoint):
)
assert response.status_code == 404
- @pytest.mark.parametrize(
- "payload, expected",
- [
- (
- {
- "dry_run": True,
- "task_id": "print_the_context",
- "logical_date": "2020-11-10T12:42:39.442973",
- "include_upstream": True,
- "include_downstream": True,
- "include_future": True,
- "include_past": True,
- "new_state": "failed",
- },
- "Naive datetime is disallowed",
- ),
- (
- {
- "dry_run": True,
- "task_id": "print_the_context",
- "logical_date": "2020-11-10T12:4opfo",
- "include_upstream": True,
- "include_downstream": True,
- "include_future": True,
- "include_past": True,
- "new_state": "failed",
- },
- "{'logical_date': ['Not a valid datetime.']}",
- ),
- ],
- )
- @provide_session
- def test_should_raise_400_for_naive_and_bad_datetime(self, payload,
expected, session):
- self.create_task_instances(session)
- response = self.client.post(
- "/api/v1/dags/example_python_operator/updateTaskInstancesState",
- environ_overrides={"REMOTE_USER": "test"},
- json=payload,
- )
- assert response.status_code == 400
- assert response.json["detail"] == expected
-
class TestPatchTaskInstance(TestTaskInstanceEndpoint):
ENDPOINT_URL = (
diff --git a/tests/api_connexion/schemas/test_task_instance_schema.py b/tests/api_connexion/schemas/test_task_instance_schema.py
index 5297830dca0..9080572314f 100644
--- a/tests/api_connexion/schemas/test_task_instance_schema.py
+++ b/tests/api_connexion/schemas/test_task_instance_schema.py
@@ -166,7 +166,7 @@ class TestSetTaskInstanceStateFormSchema:
current_input = {
"dry_run": True,
"task_id": "print_the_context",
- "logical_date": "2020-01-01T00:00:00+00:00",
+ "dag_run_id": "test_run_id",
"include_upstream": True,
"include_downstream": True,
"include_future": True,
@@ -178,7 +178,7 @@ class TestSetTaskInstanceStateFormSchema:
result = set_task_instance_state_form.load(self.current_input)
expected_result = {
"dry_run": True,
- "logical_date": dt.datetime(2020, 1, 1, 0, 0,
tzinfo=dt.timezone(dt.timedelta(0), "+0000")),
+ "dag_run_id": "test_run_id",
"include_downstream": True,
"include_future": True,
"include_past": True,
@@ -194,7 +194,7 @@ class TestSetTaskInstanceStateFormSchema:
result = set_task_instance_state_form.load(self.current_input)
expected_result = {
"dry_run": True,
- "logical_date": dt.datetime(2020, 1, 1, 0, 0,
tzinfo=dt.timezone(dt.timedelta(0), "+0000")),
+ "dag_run_id": "test_run_id",
"include_downstream": True,
"include_future": True,
"include_past": True,
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index ca8cadd9368..f78b53c5a76 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -34,7 +34,6 @@ from unittest.mock import sentinel
import pendulum
import pytest
-import sqlalchemy.exc
from airflow.cli import cli_parser
from airflow.cli.commands import task_command
@@ -228,18 +227,29 @@ class TestCliTasks:
assert ti.xcom_pull(ti.task_id) == new_file_path.as_posix()
@mock.patch("airflow.cli.commands.task_command.select")
- @mock.patch("sqlalchemy.orm.session.Session.scalars")
- @mock.patch("airflow.cli.commands.task_command.DagRun")
- def test_task_render_with_custom_timetable(self, mock_dagrun,
mock_scalars, mock_select):
+ @mock.patch("sqlalchemy.orm.session.Session.scalar")
+ def test_task_render_with_custom_timetable(self, mock_scalar, mock_select):
"""
- when calling `tasks render` on dag with custom timetable, the DagRun object should be created with
- data_intervals.
+ Test that the `tasks render` CLI command queries the database correctly
+ for a DAG with a custom timetable. Verifies that a query is executed to
+ fetch the appropriate DagRun and that the database interaction occurs as expected.
"""
- mock_scalars.side_effect = sqlalchemy.exc.NoResultFound
+ from sqlalchemy import select
+
+ from airflow.models.dagrun import DagRun
+
+ mock_query = (
+ select(DagRun).where(DagRun.dag_id ==
"example_workday_timetable").order_by(DagRun.id.desc())
+ )
+ mock_select.return_value = mock_query
+
+ mock_scalar.return_value = None
+
task_command.task_render(
self.parser.parse_args(["tasks", "render",
"example_workday_timetable", "run_this", "2022-01-01"])
)
- assert "data_interval" in mock_dagrun.call_args.kwargs
+
+ mock_select.assert_called_once()
@pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning")
def test_test_with_existing_dag_run(self, caplog):
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index f2128b205b4..1f946d027df 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -2613,13 +2613,10 @@ class TestDagDecorator:
@pytest.mark.parametrize(
- "run_id, logical_date",
- [
- (None, datetime_tz(2020, 1, 1)),
- ("test-run-id", None),
- ],
+ "run_id",
+ ["test-run-id"],
)
-def test_set_task_instance_state(run_id, logical_date, session, dag_maker):
+def test_set_task_instance_state(run_id, session, dag_maker):
"""Test that set_task_instance_state updates the TaskInstance state and
clear downstream failed"""
start_date = datetime_tz(2020, 1, 1)
@@ -2633,7 +2630,6 @@ def test_set_task_instance_state(run_id, logical_date,
session, dag_maker):
dagrun = dag_maker.create_dagrun(
run_id=run_id,
- logical_date=logical_date,
state=State.FAILED,
run_type=DagRunType.SCHEDULED,
)
@@ -2654,12 +2650,12 @@ def test_set_task_instance_state(run_id, logical_date,
session, dag_maker):
get_ti_from_db(task_3).state = State.UPSTREAM_FAILED
get_ti_from_db(task_4).state = State.FAILED
get_ti_from_db(task_5).state = State.SKIPPED
+
session.flush()
altered = dag.set_task_instance_state(
task_id=task_1.task_id,
run_id=run_id,
- logical_date=logical_date,
state=State.SUCCESS,
session=session,
)
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 39d6d56ef37..e26fe620c0b 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -220,10 +220,9 @@ class TestDagRun:
session.commit()
- assert DagRun.find_duplicate(dag_id=dag_id, run_id=dag_id,
logical_date=now) is not None
- assert DagRun.find_duplicate(dag_id=dag_id, run_id=dag_id,
logical_date=None) is not None
- assert DagRun.find_duplicate(dag_id=dag_id, run_id=None,
logical_date=now) is not None
- assert DagRun.find_duplicate(dag_id=dag_id, run_id=None,
logical_date=None) is None
+ assert DagRun.find_duplicate(dag_id=dag_id, run_id=dag_id) is not None
+ assert DagRun.find_duplicate(dag_id=dag_id, run_id=dag_id) is not None
+ assert DagRun.find_duplicate(dag_id=dag_id, run_id=None) is None
def test_dagrun_success_when_all_skipped(self, session):
"""
diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py
index ad1fca15ea4..85daeaed275 100644
--- a/tests/operators/test_trigger_dagrun.py
+++ b/tests/operators/test_trigger_dagrun.py
@@ -215,12 +215,14 @@ class TestDagRunOperator:
def test_trigger_dagrun_with_scheduled_dag_run(self, dag_maker):
"""Test TriggerDagRunOperator with custom logical_date and scheduled
dag_run."""
utc_now = timezone.utcnow()
+ run_id = f"scheduled__{utc_now.isoformat()}"
with dag_maker(
TEST_DAG_ID, default_args={"owner": "airflow", "start_date":
DEFAULT_DATE}, serialized=True
) as dag:
task = TriggerDagRunOperator(
task_id="test_trigger_dagrun_with_logical_date",
trigger_dag_id=TRIGGERED_DAG_ID,
+ trigger_run_id=run_id,
logical_date=utc_now,
poke_interval=1,
reset_dag_run=True,
@@ -496,23 +498,6 @@ class TestDagRunOperator:
triggered_dag_run = dagruns[1]
assert triggered_dag_run.state == State.QUEUED
- def test_trigger_dagrun_triggering_itself_with_logical_date(self,
dag_maker):
- """Test TriggerDagRunOperator that triggers itself with logical date,
- fails with DagRunAlreadyExists"""
- logical_date = DEFAULT_DATE
- with dag_maker(
- TEST_DAG_ID, default_args={"owner": "airflow", "start_date":
DEFAULT_DATE}, serialized=True
- ) as dag:
- task = TriggerDagRunOperator(
- task_id="test_task",
- trigger_dag_id=TEST_DAG_ID,
- logical_date=logical_date,
- )
- self.re_sync_triggered_dag_to_db(dag, dag_maker)
- dag_maker.create_dagrun()
- with pytest.raises(DagRunAlreadyExists):
- task.run(start_date=logical_date, end_date=logical_date)
-
@pytest.mark.skip_if_database_isolation_mode # Known to be broken in db
isolation mode
def test_trigger_dagrun_with_wait_for_completion_true_defer_false(self,
dag_maker):
"""Test TriggerDagRunOperator with wait_for_completion."""
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index 4cd7b5e5f8c..e03ceeed019 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -1295,9 +1295,10 @@ def run_tasks(
start_date=logical_date,
run_type=DagRunType.MANUAL,
session=session,
- data_interval=(DEFAULT_DATE, DEFAULT_DATE),
+ data_interval=(logical_date, logical_date),
triggered_by=triggered_by,
)
+ runs[dag.dag_id] = dagrun
# we use sorting by task_id here because for the test DAG structure of
ours
# this is equivalent to topological sort. It would not work in general
case
# but it works for our case because we specifically constructed test
DAGS
@@ -1373,9 +1374,22 @@ def
test_external_task_marker_clear_activate(dag_bag_parent_child, session):
run_tasks(dag_bag, logical_date=day_1)
run_tasks(dag_bag, logical_date=day_2)
+ from sqlalchemy import select
+
+ run_ids = []
# Assert that dagruns of all the affected dags are set to SUCCESS before
tasks are cleared.
for dag, logical_date in itertools.product(dag_bag.dags.values(), [day_1,
day_2]):
- dagrun = dag.get_dagrun(logical_date=logical_date, session=session)
+ run_id = (
+ select(DagRun.run_id)
+ .where(DagRun.logical_date == logical_date)
+ .order_by(DagRun.id.desc())
+ .limit(1)
+ )
+ run_ids.append(run_id)
+ dagrun = dag.get_dagrun(
+ run_id=run_id,
+ session=session,
+ )
dagrun.set_state(State.SUCCESS)
session.flush()
@@ -1385,10 +1399,10 @@ def
test_external_task_marker_clear_activate(dag_bag_parent_child, session):
# Assert that dagruns of all the affected dags are set to QUEUED after
tasks are cleared.
# Unaffected dagruns should be left as SUCCESS.
- dagrun_0_1 =
dag_bag.get_dag("parent_dag_0").get_dagrun(logical_date=day_1, session=session)
- dagrun_0_2 =
dag_bag.get_dag("parent_dag_0").get_dagrun(logical_date=day_2, session=session)
- dagrun_1_1 = dag_bag.get_dag("child_dag_1").get_dagrun(logical_date=day_1,
session=session)
- dagrun_1_2 = dag_bag.get_dag("child_dag_1").get_dagrun(logical_date=day_2,
session=session)
+ dagrun_0_1 = dag_bag.get_dag("parent_dag_0").get_dagrun(run_id=run_ids[0],
session=session)
+ dagrun_0_2 = dag_bag.get_dag("parent_dag_0").get_dagrun(run_id=run_ids[1],
session=session)
+ dagrun_1_1 = dag_bag.get_dag("child_dag_1").get_dagrun(run_id=run_ids[2],
session=session)
+ dagrun_1_2 = dag_bag.get_dag("child_dag_1").get_dagrun(run_id=run_ids[3],
session=session)
assert dagrun_0_1.state == State.QUEUED
assert dagrun_0_2.state == State.QUEUED