This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3873230a11d Remove tuple_in_condition helpers (#45201)
3873230a11d is described below
commit 3873230a11de8b9cc24d012ecdfe6848bc6ae0cf
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu Dec 26 11:04:00 2024 +0800
Remove tuple_in_condition helpers (#45201)
---
airflow/jobs/scheduler_job_runner.py | 35 ++++++--------
airflow/models/dag.py | 8 ++--
airflow/models/dagrun.py | 7 +--
airflow/models/skipmixin.py | 7 ++-
airflow/models/taskinstance.py | 16 +++----
airflow/utils/sqlalchemy.py | 54 ++--------------------
airflow/www/utils.py | 10 +---
.../providers/standard/utils/sensor_helper.py | 22 ++++-----
8 files changed, 44 insertions(+), 115 deletions(-)
diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index a0558e9040d..ece0d4c1cb6 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -34,7 +34,7 @@ from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable
from deprecated import deprecated
-from sqlalchemy import and_, delete, exists, func, not_, select, text, update
+from sqlalchemy import and_, delete, exists, func, select, text, tuple_, update
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import lazyload, load_only, make_transient, selectinload
from sqlalchemy.sql import expression
@@ -77,12 +77,7 @@ from airflow.utils.event_scheduler import EventScheduler
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries
from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.sqlalchemy import (
- is_lock_not_available_error,
- prohibit_commit,
- tuple_in_condition,
- with_row_locks,
-)
+from airflow.utils.sqlalchemy import is_lock_not_available_error, prohibit_commit, with_row_locks
from airflow.utils.state import DagRunState, JobState, State, TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
@@ -357,28 +352,25 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
.join(TI.dag_run)
.where(DR.state == DagRunState.RUNNING)
.join(TI.dag_model)
- .where(not_(DM.is_paused))
+ .where(~DM.is_paused)
.where(TI.state == TaskInstanceState.SCHEDULED)
.options(selectinload(TI.dag_model))
.order_by(-TI.priority_weight, DR.logical_date, TI.map_index)
)
if starved_pools:
- query = query.where(not_(TI.pool.in_(starved_pools)))
+ query = query.where(TI.pool.not_in(starved_pools))
if starved_dags:
- query = query.where(not_(TI.dag_id.in_(starved_dags)))
+ query = query.where(TI.dag_id.not_in(starved_dags))
if starved_tasks:
- task_filter = tuple_in_condition((TI.dag_id, TI.task_id), starved_tasks)
- query = query.where(not_(task_filter))
+ query = query.where(tuple_(TI.dag_id, TI.task_id).not_in(starved_tasks))
if starved_tasks_task_dagrun_concurrency:
- task_filter = tuple_in_condition(
- (TI.dag_id, TI.run_id, TI.task_id),
- starved_tasks_task_dagrun_concurrency,
+ query = query.where(
+ tuple_(TI.dag_id, TI.run_id, TI.task_id).not_in(starved_tasks_task_dagrun_concurrency)
)
- query = query.where(not_(task_filter))
query = query.limit(max_tis)
@@ -1314,9 +1306,8 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
existing_dagruns = (
session.execute(
select(DagRun.dag_id, DagRun.logical_date).where(
- tuple_in_condition(
- (DagRun.dag_id, DagRun.logical_date),
- ((dm.dag_id, dm.next_dagrun) for dm in dag_models),
+ tuple_(DagRun.dag_id, DagRun.logical_date).in_(
+ (dm.dag_id, dm.next_dagrun) for dm in dag_models
),
)
)
@@ -1402,7 +1393,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
existing_dagruns: set[tuple[str, timezone.DateTime]] = set(
session.execute(
select(DagRun.dag_id, DagRun.logical_date).where(
- tuple_in_condition((DagRun.dag_id, DagRun.logical_date), logical_dates.items())
+ tuple_(DagRun.dag_id, DagRun.logical_date).in_(logical_dates.items())
)
)
)
@@ -2188,7 +2179,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
if assets:
session.execute(
delete(AssetActive).where(
- tuple_in_condition((AssetActive.name, AssetActive.uri), ((a.name, a.uri) for a in assets))
+ tuple_(AssetActive.name, AssetActive.uri).in_((a.name, a.uri) for a in assets)
)
)
Stats.gauge("asset.orphaned", len(assets))
@@ -2201,7 +2192,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
active_assets = set(
session.execute(
select(AssetActive.name, AssetActive.uri).where(
- tuple_in_condition((AssetActive.name, AssetActive.uri), ((a.name, a.uri) for a in assets))
+ tuple_(AssetActive.name, AssetActive.uri).in_((a.name, a.uri) for a in assets)
)
)
)
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index f2090649301..d127914a8c5 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -58,9 +58,9 @@ from sqlalchemy import (
and_,
case,
func,
- not_,
or_,
select,
+ tuple_,
update,
)
from sqlalchemy.ext.associationproxy import association_proxy
@@ -108,7 +108,7 @@ from airflow.utils import timezone
from airflow.utils.dag_cycle_tester import check_cycle
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, lock_rows, tuple_in_condition, with_row_locks
+from airflow.utils.sqlalchemy import UtcDateTime, lock_rows, with_row_locks
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
@@ -1081,7 +1081,7 @@ class DAG(TaskSDKDag, LoggingMixin):
tis = tis.where(TaskInstance.state.in_(state))
if exclude_run_ids:
- tis = tis.where(not_(TaskInstance.run_id.in_(exclude_run_ids)))
+ tis = tis.where(TaskInstance.run_id.not_in(exclude_run_ids))
if include_dependent_dags:
# Recursively find external tasks indicated by ExternalTaskMarker
@@ -1192,7 +1192,7 @@ class DAG(TaskSDKDag, LoggingMixin):
elif isinstance(next(iter(exclude_task_ids), None), str):
tis = tis.where(TI.task_id.notin_(exclude_task_ids))
else:
- tis = tis.where(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids)))
+ tis = tis.where(tuple_(TI.task_id, TI.map_index).not_in(exclude_task_ids))
return tis
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 7278e88742e..a5bef7e589c 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -42,6 +42,7 @@ from sqlalchemy import (
not_,
or_,
text,
+ tuple_,
update,
)
from sqlalchemy.exc import IntegrityError
@@ -74,7 +75,7 @@ from airflow.utils.helpers import chunks, is_container, prune_dict
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.retries import retry_db_transaction
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, tuple_in_condition, with_row_locks
+from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, with_row_locks
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import NOTSET, DagRunTriggeredByType, DagRunType
@@ -1644,7 +1645,7 @@ class DagRun(Base, LoggingMixin):
.where(
TI.dag_id == self.dag_id,
TI.run_id == self.run_id,
- tuple_in_condition((TI.task_id, TI.map_index), schedulable_ti_ids_chunk),
+ tuple_(TI.task_id, TI.map_index).in_(schedulable_ti_ids_chunk),
)
.values(
state=TaskInstanceState.SCHEDULED,
@@ -1668,7 +1669,7 @@ class DagRun(Base, LoggingMixin):
.where(
TI.dag_id == self.dag_id,
TI.run_id == self.run_id,
- tuple_in_condition((TI.task_id, TI.map_index), dummy_ti_ids_chunk),
+ tuple_(TI.task_id, TI.map_index).in_(dummy_ti_ids_chunk),
)
.values(
state=TaskInstanceState.SUCCESS,
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index ad5c5d01539..8b59043ecef 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -21,18 +21,17 @@ from collections.abc import Iterable, Sequence
from types import GeneratorType
from typing import TYPE_CHECKING
-from sqlalchemy import update
+from sqlalchemy import tuple_, update
from airflow.exceptions import AirflowException
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import tuple_in_condition
from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
- from sqlalchemy import Session
+ from sqlalchemy.orm import Session
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
@@ -74,7 +73,7 @@ class SkipMixin(LoggingMixin):
.where(
TaskInstance.dag_id == dag_run.dag_id,
TaskInstance.run_id == dag_run.run_id,
- tuple_in_condition((TaskInstance.task_id, TaskInstance.map_index), tasks),
+ tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(tasks),
)
.values(state=TaskInstanceState.SKIPPED, start_date=now,
end_date=now)
.execution_options(synchronize_session=False)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 6ef4452834f..d519af41d90 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -55,12 +55,14 @@ from sqlalchemy import (
Text,
UniqueConstraint,
and_,
+ case,
delete,
extract,
false,
func,
inspect,
or_,
+ select,
text,
tuple_,
update,
@@ -71,7 +73,6 @@ from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import lazyload, reconstructor, relationship
from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
-from sqlalchemy.sql.expression import case, select
from sqlalchemy_utils import UUIDType
from airflow import settings
@@ -131,12 +132,7 @@ from airflow.utils.operator_helpers import ExecutionCallableRunner, context_to_a
from airflow.utils.platform import getuser
from airflow.utils.retries import run_with_db_retries
from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.sqlalchemy import (
- ExecutorConfigType,
- ExtendedJSON,
- UtcDateTime,
- tuple_in_condition,
-)
+from airflow.utils.sqlalchemy import ExecutorConfigType, ExtendedJSON, UtcDateTime
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.task_instance_session import
set_current_task_instance_session
@@ -3497,7 +3493,7 @@ class TaskInstance(Base, LoggingMixin):
if task_id_only:
filters.append(cls.task_id.in_(task_id_only))
if with_map_index:
- filters.append(tuple_in_condition((cls.task_id, cls.map_index), with_map_index))
+ filters.append(tuple_(cls.task_id, cls.map_index).in_(with_map_index))
if not filters:
return false()
@@ -3675,8 +3671,8 @@ class TaskInstance(Base, LoggingMixin):
AssetUniqueKey(name, uri)
for name, uri in session.execute(
select(AssetActive.name, AssetActive.uri).where(
- tuple_in_condition(
- (AssetActive.name, AssetActive.uri),
[attrs.astuple(key) for key in asset_unique_keys]
+ tuple_(AssetActive.name, AssetActive.uri).in_(
+ attrs.astuple(key) for key in asset_unique_keys
)
)
)
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index 23597f25a95..917af7c1f1d 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -23,7 +23,7 @@ import datetime
import logging
from collections.abc import Generator, Iterable
from importlib import metadata
-from typing import TYPE_CHECKING, Any, overload
+from typing import TYPE_CHECKING, Any
from packaging import version
from sqlalchemy import TIMESTAMP, PickleType, event, nullsfirst, tuple_
@@ -438,22 +438,6 @@ def is_lock_not_available_error(error: OperationalError):
return False
-@overload
-def tuple_in_condition(
- columns: tuple[ColumnElement, ...],
- collection: Iterable[Any],
-) -> ColumnOperators: ...
-
-
-@overload
-def tuple_in_condition(
- columns: tuple[ColumnElement, ...],
- collection: Select,
- *,
- session: Session,
-) -> ColumnOperators: ...
-
-
def tuple_in_condition(
columns: tuple[ColumnElement, ...],
collection: Iterable[Any] | Select,
@@ -463,46 +447,14 @@ def tuple_in_condition(
"""
Generate a tuple-in-collection operator to use in ``.where()``.
- For most SQL backends, this generates a simple ``([col, ...]) IN [condition]``
- clause.
+ Kept for backward compatibility. Remove when providers drop support for
+ apache-airflow<3.0.
:meta private:
"""
return tuple_(*columns).in_(collection)
-@overload
-def tuple_not_in_condition(
- columns: tuple[ColumnElement, ...],
- collection: Iterable[Any],
-) -> ColumnOperators: ...
-
-
-@overload
-def tuple_not_in_condition(
- columns: tuple[ColumnElement, ...],
- collection: Select,
- *,
- session: Session,
-) -> ColumnOperators: ...
-
-
-def tuple_not_in_condition(
- columns: tuple[ColumnElement, ...],
- collection: Iterable[Any] | Select,
- *,
- session: Session | None = None,
-) -> ColumnOperators:
- """
- Generate a tuple-not-in-collection operator to use in ``.where()``.
-
- This is similar to ``tuple_in_condition`` except generating ``NOT IN``.
-
- :meta private:
- """
- return tuple_(*columns).not_in(collection)
-
-
def get_orm_mapper():
"""Get the correct ORM mapper for the installed SQLAlchemy version."""
import sqlalchemy.orm.mapper
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 727139a9a6b..9c319424574 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -37,7 +37,7 @@ from markdown_it import MarkdownIt
from markupsafe import Markup
from pygments import highlight, lexers
from pygments.formatters import HtmlFormatter
-from sqlalchemy import delete, func, select, types
+from sqlalchemy import delete, func, select, tuple_, types
from sqlalchemy.ext.associationproxy import AssociationProxy
from airflow.api_fastapi.app import get_auth_manager
@@ -49,7 +49,6 @@ from airflow.utils import timezone
from airflow.utils.code_utils import get_python_source
from airflow.utils.helpers import alchemy_to_dict
from airflow.utils.json import WebEncoder
-from airflow.utils.sqlalchemy import tuple_in_condition
from airflow.utils.state import State, TaskInstanceState
from airflow.www.forms import DateTimeWithTimezoneField
from airflow.www.widgets import AirflowDateTimePickerWidget
@@ -867,12 +866,7 @@ class DagRunCustomSQLAInterface(CustomSQLAInterface):
def delete_all(self, items: list[Model]) -> bool:
self.session.execute(
- delete(TI).where(
- tuple_in_condition(
- (TI.dag_id, TI.run_id),
- ((x.dag_id, x.run_id) for x in items),
- )
- )
+ delete(TI).where(tuple_(TI.dag_id, TI.run_id).in_((x.dag_id, x.run_id) for x in items))
)
return super().delete_all(items)
diff --git a/providers/src/airflow/providers/standard/utils/sensor_helper.py b/providers/src/airflow/providers/standard/utils/sensor_helper.py
index 57d906da671..8c4524cba65 100644
--- a/providers/src/airflow/providers/standard/utils/sensor_helper.py
+++ b/providers/src/airflow/providers/standard/utils/sensor_helper.py
@@ -18,14 +18,14 @@ from __future__ import annotations
from typing import TYPE_CHECKING, cast
-from sqlalchemy import func, select
+from sqlalchemy import func, select, tuple_
from airflow.models import DagBag, DagRun, TaskInstance
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import tuple_in_condition
if TYPE_CHECKING:
- from sqlalchemy.orm import Query, Session
+ from sqlalchemy.orm import Session
+ from sqlalchemy.sql import Executable
@provide_session
@@ -55,9 +55,7 @@ def _get_count(
if external_task_ids:
count = (
session.scalar(
- _count_query(TI, states, dttm_filter, external_dag_id, session).filter(
- TI.task_id.in_(external_task_ids)
- )
+ _count_stmt(TI, states, dttm_filter, external_dag_id).where(TI.task_id.in_(external_task_ids))
)
) / len(external_task_ids)
elif external_task_group_id:
@@ -69,17 +67,17 @@ def _get_count(
else:
count = (
session.scalar(
- _count_query(TI, states, dttm_filter, external_dag_id, session).filter(
- tuple_in_condition((TI.task_id, TI.map_index), external_task_group_task_ids)
+ _count_stmt(TI, states, dttm_filter, external_dag_id).where(
+ tuple_(TI.task_id, TI.map_index).in_(external_task_group_task_ids)
)
)
) / len(external_task_group_task_ids)
else:
- count = session.scalar(_count_query(DR, states, dttm_filter, external_dag_id, session))
+ count = session.scalar(_count_stmt(DR, states, dttm_filter, external_dag_id))
return cast(int, count)
-def _count_query(model, states, dttm_filter, external_dag_id, session: Session) -> Query:
+def _count_stmt(model, states, dttm_filter, external_dag_id) -> Executable:
"""
Get the count of records against dttm filter and states.
@@ -87,12 +85,10 @@ def _count_query(model, states, dttm_filter, external_dag_id, session: Session)
:param states: task or dag states
:param dttm_filter: date time filter for logical date
:param external_dag_id: The ID of the external DAG.
- :param session: airflow session object
"""
- query = select(func.count()).filter(
+ return select(func.count()).where(
model.dag_id == external_dag_id, model.state.in_(states),
model.logical_date.in_(dttm_filter)
)
- return query
def _get_external_task_group_task_ids(dttm_filter, external_task_group_id,
external_dag_id, session):