This is an automated email from the ASF dual-hosted git repository. bbovenzi pushed a commit to branch mapped-instance-actions in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 469092494da6b8baa6cfe145b76e40eaa495635e Author: Tzu-ping Chung <[email protected]> AuthorDate: Tue Apr 19 18:01:55 2022 +0800 Introduce tuple_().in_() shim for MSSQL compat --- airflow/api/common/mark_tasks.py | 5 +++-- airflow/models/dag.py | 8 ++++---- airflow/utils/sqlalchemy.py | 8 +++----- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py index 349b935e82..1d4709fb82 100644 --- a/airflow/api/common/mark_tasks.py +++ b/airflow/api/common/mark_tasks.py @@ -20,7 +20,7 @@ from datetime import datetime from typing import TYPE_CHECKING, Collection, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union -from sqlalchemy import or_, tuple_ +from sqlalchemy import or_ from sqlalchemy.orm import contains_eager from sqlalchemy.orm.session import Session as SASession @@ -32,6 +32,7 @@ from airflow.operators.subdag import SubDagOperator from airflow.utils import timezone from airflow.utils.helpers import exactly_one from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.sqlalchemy import tuple_in_condition from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import DagRunType @@ -203,7 +204,7 @@ def get_all_dag_task_query( if is_string_list: qry_dag = qry_dag.filter(TaskInstance.task_id.in_(task_ids)) else: - qry_dag = qry_dag.filter(tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_ids)) + qry_dag = qry_dag.filter(tuple_in_condition((TaskInstance.task_id, TaskInstance.map_index), task_ids)) qry_dag = qry_dag.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)).options( contains_eager(TaskInstance.dag_run) ) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 9c93bcef13..83860ba591 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -52,7 +52,7 @@ import jinja2 import pendulum from dateutil.relativedelta import relativedelta from pendulum.tz.timezone import Timezone -from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, or_, tuple_ +from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, not_, or_ from sqlalchemy.orm import backref, joinedload, relationship from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session @@ -85,7 +85,7 @@ from airflow.utils.file import correct_maybe_zipped from airflow.utils.helpers import exactly_one, validate_key from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, with_row_locks +from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType @@ -1451,7 +1451,7 @@ class DAG(LoggingMixin): elif isinstance(next(iter(task_ids), None), str): tis = tis.filter(TI.task_id.in_(task_ids)) else: - tis = tis.filter(tuple_(TI.task_id, TI.map_index).in_(task_ids)) + tis = tis.filter(tuple_in_condition((TI.task_id, TI.map_index), task_ids)) # This allows allow_trigger_in_future config to take affect, rather than mandating exec_date <= UTC if end_date or not self.allow_future_exec_dates: @@ -1611,7 +1611,7 @@ class DAG(LoggingMixin): elif isinstance(next(iter(exclude_task_ids), None), str): tis = tis.filter(TI.task_id.notin_(exclude_task_ids)) else: - tis = tis.filter(tuple_(TI.task_id, TI.map_index).notin_(exclude_task_ids)) + tis = tis.filter(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids))) return tis diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py index de4ad01e69..5c36d826b2 100644 --- a/airflow/utils/sqlalchemy.py +++ b/airflow/utils/sqlalchemy.py @@ -19,11 +19,12 @@ import datetime import json import logging +from operator import and_, or_ from typing import Any, Dict, Iterable, Tuple import pendulum from dateutil import relativedelta -from sqlalchemy import and_, event, false, nullsfirst, or_, tuple_ +from sqlalchemy import event, nullsfirst, tuple_ from sqlalchemy.exc import OperationalError from sqlalchemy.orm.session import Session from sqlalchemy.sql import ColumnElement @@ -338,7 +339,4 @@ def tuple_in_condition( """ if settings.engine.dialect.name != "mssql": return tuple_(*columns).in_(collection) - clauses = [and_(*(c == v for c, v in zip(columns, values))) for values in collection] - if not clauses: - return false() - return or_(*clauses) + return or_(*(and_(*(c == v for c, v in zip(columns, values))) for values in collection))
