This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5c0fca6440 A couple of minor cleanups (#31890)
5c0fca6440 is described below
commit 5c0fca6440fae3ece915b365e1f06eb30db22d81
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Jun 28 18:52:07 2023 +0800
A couple of minor cleanups (#31890)
---
airflow/models/dag.py | 85 +++++++++++++++++++++++----------------------
airflow/utils/sqlalchemy.py | 21 +++++++++--
2 files changed, 61 insertions(+), 45 deletions(-)
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 892ff5cef4..22c6e69418 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -119,7 +119,14 @@ from airflow.utils.decorators import fixup_decorator_warning_stack
from airflow.utils.helpers import at_most_one, exactly_one, validate_key
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks
+from airflow.utils.sqlalchemy import (
+ Interval,
+ UtcDateTime,
+ lock_rows,
+ skip_locked,
+ tuple_in_condition,
+ with_row_locks,
+)
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
@@ -2003,7 +2010,6 @@ class DAG(LoggingMixin):
tasks_to_set_state: list[BaseOperator | tuple[BaseOperator, int]] = []
task_ids: list[str] = []
- locked_dag_run_ids: list[int] = []
if execution_date is None:
dag_run = session.scalars(
@@ -2022,57 +2028,52 @@ class DAG(LoggingMixin):
raise ValueError("TaskGroup {group_id} could not be found")
tasks_to_set_state = [task for task in task_group.iter_tasks() if isinstance(task, BaseOperator)]
task_ids = [task.task_id for task in task_group.iter_tasks()]
- dag_runs_query = session.query(DagRun.id).where(DagRun.dag_id == self.dag_id).with_for_update()
+ dag_runs_query = session.query(DagRun.id).where(DagRun.dag_id == self.dag_id)
if start_date is None and end_date is None:
dag_runs_query = dag_runs_query.where(DagRun.execution_date == start_date)
else:
if start_date is not None:
dag_runs_query = dag_runs_query.where(DagRun.execution_date >= start_date)
-
if end_date is not None:
dag_runs_query = dag_runs_query.where(DagRun.execution_date <= end_date)
- locked_dag_run_ids = dag_runs_query.all()
-
- altered = set_state(
- tasks=tasks_to_set_state,
- execution_date=execution_date,
- run_id=run_id,
- upstream=upstream,
- downstream=downstream,
- future=future,
- past=past,
- state=state,
- commit=commit,
- session=session,
- )
-
- if not commit:
- del locked_dag_run_ids
- return altered
-
- # Clear downstream tasks that are in failed/upstream_failed state to resume them.
- # Flush the session so that the tasks marked success are reflected in the db.
- session.flush()
- task_subset = self.partial_subset(
- task_ids_or_regex=task_ids,
- include_downstream=True,
- include_upstream=False,
- )
+ with lock_rows(dag_runs_query, session):
+ altered = set_state(
+ tasks=tasks_to_set_state,
+ execution_date=execution_date,
+ run_id=run_id,
+ upstream=upstream,
+ downstream=downstream,
+ future=future,
+ past=past,
+ state=state,
+ commit=commit,
+ session=session,
+ )
+ if not commit:
+ return altered
+
+ # Clear downstream tasks that are in failed/upstream_failed state to resume them.
+ # Flush the session so that the tasks marked success are reflected in the db.
+ session.flush()
+ task_subset = self.partial_subset(
+ task_ids_or_regex=task_ids,
+ include_downstream=True,
+ include_upstream=False,
+ )
- task_subset.clear(
- start_date=start_date,
- end_date=end_date,
- include_subdags=True,
- include_parentdag=True,
- only_failed=True,
- session=session,
- # Exclude the task from the current group from being cleared
- exclude_task_ids=frozenset(task_ids),
- )
+ task_subset.clear(
+ start_date=start_date,
+ end_date=end_date,
+ include_subdags=True,
+ include_parentdag=True,
+ only_failed=True,
+ session=session,
+ # Exclude the task from the current group from being cleared
+ exclude_task_ids=frozenset(task_ids),
+ )
- del locked_dag_run_ids
return altered
@property
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index f12f6f44c8..32a5a796d5 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -17,18 +17,18 @@
# under the License.
from __future__ import annotations
+import contextlib
import copy
import datetime
import json
import logging
-from typing import TYPE_CHECKING, Any, Iterable
+from typing import TYPE_CHECKING, Any, Generator, Iterable
import pendulum
from dateutil import relativedelta
from sqlalchemy import TIMESTAMP, PickleType, and_, event, false, nullsfirst, or_, true, tuple_
from sqlalchemy.dialects import mssql, mysql
from sqlalchemy.exc import OperationalError
-from sqlalchemy.orm.session import Session
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql.expression import ColumnOperators
from sqlalchemy.types import JSON, Text, TypeDecorator, TypeEngine, UnicodeText
@@ -39,6 +39,7 @@ from airflow.serialization.enums import Encoding
if TYPE_CHECKING:
from kubernetes.client.models.v1_pod import V1Pod
+ from sqlalchemy.orm import Query, Session
log = logging.getLogger(__name__)
@@ -411,7 +412,7 @@ def nulls_first(col, session: Session) -> dict[str, Any]:
USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_locking", fallback=True)
-def with_row_locks(query, session: Session, **kwargs):
+def with_row_locks(query: Query, session: Session, **kwargs) -> Query:
"""
Apply with_for_update to an SQLAlchemy query, if row level locking is in
use.
@@ -429,6 +430,20 @@ def with_row_locks(query, session: Session, **kwargs):
return query
+@contextlib.contextmanager
+def lock_rows(query: Query, session: Session) -> Generator[None, None, None]:
+ """Lock database rows during the context manager block.
+
+ This is a convenient method for ``with_row_locks`` when we don't need the
+ locked rows.
+
+ :meta private:
+ """
+ locked_rows = with_row_locks(query, session).all()
+ yield
+ del locked_rows
+
+
class CommitProhibitorGuard:
"""Context manager class that powers prohibit_commit."""