This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5c0fca6440 A couple of minor cleanups (#31890)
5c0fca6440 is described below

commit 5c0fca6440fae3ece915b365e1f06eb30db22d81
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Jun 28 18:52:07 2023 +0800

    A couple of minor cleanups (#31890)
---
 airflow/models/dag.py       | 85 +++++++++++++++++++++++----------------------
 airflow/utils/sqlalchemy.py | 21 +++++++++--
 2 files changed, 61 insertions(+), 45 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 892ff5cef4..22c6e69418 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -119,7 +119,14 @@ from airflow.utils.decorators import fixup_decorator_warning_stack
 from airflow.utils.helpers import at_most_one, exactly_one, validate_key
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks
+from airflow.utils.sqlalchemy import (
+    Interval,
+    UtcDateTime,
+    lock_rows,
+    skip_locked,
+    tuple_in_condition,
+    with_row_locks,
+)
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
 
@@ -2003,7 +2010,6 @@ class DAG(LoggingMixin):
 
         tasks_to_set_state: list[BaseOperator | tuple[BaseOperator, int]] = []
         task_ids: list[str] = []
-        locked_dag_run_ids: list[int] = []
 
         if execution_date is None:
             dag_run = session.scalars(
@@ -2022,57 +2028,52 @@ class DAG(LoggingMixin):
             raise ValueError("TaskGroup {group_id} could not be found")
 tasks_to_set_state = [task for task in task_group.iter_tasks() if isinstance(task, BaseOperator)]
         task_ids = [task.task_id for task in task_group.iter_tasks()]
-        dag_runs_query = session.query(DagRun.id).where(DagRun.dag_id == self.dag_id).with_for_update()
 
+        dag_runs_query = session.query(DagRun.id).where(DagRun.dag_id == self.dag_id)
         if start_date is None and end_date is None:
             dag_runs_query = dag_runs_query.where(DagRun.execution_date == start_date)
         else:
             if start_date is not None:
                 dag_runs_query = dag_runs_query.where(DagRun.execution_date >= start_date)
-
             if end_date is not None:
                 dag_runs_query = dag_runs_query.where(DagRun.execution_date <= end_date)
 
-        locked_dag_run_ids = dag_runs_query.all()
-
-        altered = set_state(
-            tasks=tasks_to_set_state,
-            execution_date=execution_date,
-            run_id=run_id,
-            upstream=upstream,
-            downstream=downstream,
-            future=future,
-            past=past,
-            state=state,
-            commit=commit,
-            session=session,
-        )
-
-        if not commit:
-            del locked_dag_run_ids
-            return altered
-
-        # Clear downstream tasks that are in failed/upstream_failed state to resume them.
-        # Flush the session so that the tasks marked success are reflected in the db.
-        session.flush()
-        task_subset = self.partial_subset(
-            task_ids_or_regex=task_ids,
-            include_downstream=True,
-            include_upstream=False,
-        )
+        with lock_rows(dag_runs_query, session):
+            altered = set_state(
+                tasks=tasks_to_set_state,
+                execution_date=execution_date,
+                run_id=run_id,
+                upstream=upstream,
+                downstream=downstream,
+                future=future,
+                past=past,
+                state=state,
+                commit=commit,
+                session=session,
+            )
+            if not commit:
+                return altered
+
+            # Clear downstream tasks that are in failed/upstream_failed state to resume them.
+            # Flush the session so that the tasks marked success are reflected in the db.
+            session.flush()
+            task_subset = self.partial_subset(
+                task_ids_or_regex=task_ids,
+                include_downstream=True,
+                include_upstream=False,
+            )
 
-        task_subset.clear(
-            start_date=start_date,
-            end_date=end_date,
-            include_subdags=True,
-            include_parentdag=True,
-            only_failed=True,
-            session=session,
-            # Exclude the task from the current group from being cleared
-            exclude_task_ids=frozenset(task_ids),
-        )
+            task_subset.clear(
+                start_date=start_date,
+                end_date=end_date,
+                include_subdags=True,
+                include_parentdag=True,
+                only_failed=True,
+                session=session,
+                # Exclude the task from the current group from being cleared
+                exclude_task_ids=frozenset(task_ids),
+            )
 
-        del locked_dag_run_ids
         return altered
 
     @property
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index f12f6f44c8..32a5a796d5 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -17,18 +17,18 @@
 # under the License.
 from __future__ import annotations
 
+import contextlib
 import copy
 import datetime
 import json
 import logging
-from typing import TYPE_CHECKING, Any, Iterable
+from typing import TYPE_CHECKING, Any, Generator, Iterable
 
 import pendulum
 from dateutil import relativedelta
 from sqlalchemy import TIMESTAMP, PickleType, and_, event, false, nullsfirst, or_, true, tuple_
 from sqlalchemy.dialects import mssql, mysql
 from sqlalchemy.exc import OperationalError
-from sqlalchemy.orm.session import Session
 from sqlalchemy.sql import ColumnElement
 from sqlalchemy.sql.expression import ColumnOperators
 from sqlalchemy.types import JSON, Text, TypeDecorator, TypeEngine, UnicodeText
@@ -39,6 +39,7 @@ from airflow.serialization.enums import Encoding
 
 if TYPE_CHECKING:
     from kubernetes.client.models.v1_pod import V1Pod
+    from sqlalchemy.orm import Query, Session
 
 log = logging.getLogger(__name__)
 
@@ -411,7 +412,7 @@ def nulls_first(col, session: Session) -> dict[str, Any]:
 USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_locking", fallback=True)
 
 
-def with_row_locks(query, session: Session, **kwargs):
+def with_row_locks(query: Query, session: Session, **kwargs) -> Query:
     """
     Apply with_for_update to an SQLAlchemy query, if row level locking is in use.
 
@@ -429,6 +430,20 @@ def with_row_locks(query, session: Session, **kwargs):
         return query
 
 
[email protected]
+def lock_rows(query: Query, session: Session) -> Generator[None, None, None]:
+    """Lock database rows during the context manager block.
+
+    This is a convenient method for ``with_row_locks`` when we don't need the
+    locked rows.
+
+    :meta private:
+    """
+    locked_rows = with_row_locks(query, session).all()
+    yield
+    del locked_rows
+
+
 class CommitProhibitorGuard:
     """Context manager class that powers prohibit_commit."""
 

Reply via email to