This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d7b440bf431 SQLA2: fix mypy issue with getting the dialect name (#56941)
d7b440bf431 is described below
commit d7b440bf4318454a1d7472f77c3b29d9576e0983
Author: Dev-iL <[email protected]>
AuthorDate: Tue Oct 21 23:38:02 2025 +0300
SQLA2: fix mypy issue with getting the dialect name (#56941)
* SQLA: add util to get the dialect name type-safely
* Replace unsafe dialect retrievals with `get_dialect_name`
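For context, a minimal sketch of how the new helper behaves (assuming an
environment with airflow-core installed; the in-memory SQLite engine is purely
illustrative):

    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    from airflow.utils.sqlalchemy import get_dialect_name

    # Any engine-bound session works; SQLite is just the cheapest to create.
    engine = create_engine("sqlite://")
    session = sessionmaker(bind=engine)()

    # Returns the dialect name as a plain string; raises ValueError if the
    # session has no bind, and returns None if the dialect has no `name`.
    assert get_dialect_name(session) == "sqlite"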
---
.../api_fastapi/core_api/services/ui/calendar.py | 3 +-
airflow-core/src/airflow/assets/manager.py | 3 +-
.../src/airflow/dag_processing/collection.py | 4 +-
airflow-core/src/airflow/models/dagrun.py | 11 ++++-
airflow-core/src/airflow/models/deadline.py | 4 +-
airflow-core/src/airflow/models/serialized_dag.py | 7 +--
airflow-core/src/airflow/models/trigger.py | 4 +-
airflow-core/src/airflow/utils/sqlalchemy.py | 9 +++-
airflow-core/tests/unit/utils/test_sqlalchemy.py | 57 ++++++++++++++++------
9 files changed, 73 insertions(+), 29 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/calendar.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/calendar.py
index 912029e4845..eed47b2c4db 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/calendar.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/calendar.py
@@ -39,6 +39,7 @@ from airflow.serialization.serialized_objects import SerializedDAG
from airflow.timetables._cron import CronMixin
from airflow.timetables.base import DataInterval, TimeRestriction
from airflow.timetables.simple import ContinuousTimetable
+from airflow.utils.sqlalchemy import get_dialect_name
log = structlog.get_logger(logger_name=__name__)
@@ -92,7 +93,7 @@ class CalendarService:
granularity: Literal["hourly", "daily"],
) -> tuple[list[CalendarTimeRangeResponse], Sequence[Row]]:
"""Get historical DAG runs from the database."""
- dialect = session.bind.dialect.name
+ dialect = get_dialect_name(session)
time_expression = self._get_time_truncation_expression(DagRun.logical_date, granularity, dialect)
diff --git a/airflow-core/src/airflow/assets/manager.py b/airflow-core/src/airflow/assets/manager.py
index a00c7cae27d..3ab4f31e380 100644
--- a/airflow-core/src/airflow/assets/manager.py
+++ b/airflow-core/src/airflow/assets/manager.py
@@ -38,6 +38,7 @@ from airflow.models.asset import (
)
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.sqlalchemy import get_dialect_name
if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
@@ -245,7 +246,7 @@ class AssetManager(LoggingMixin):
if not dags_to_queue:
return
- if session.bind.dialect.name == "postgresql":
+ if get_dialect_name(session) == "postgresql":
return cls._postgres_queue_dagruns(asset_id, dags_to_queue, session)
return cls._slow_path_queue_dagruns(asset_id, dags_to_queue, session)
diff --git a/airflow-core/src/airflow/dag_processing/collection.py b/airflow-core/src/airflow/dag_processing/collection.py
index 8a9bba79738..80432a87966 100644
--- a/airflow-core/src/airflow/dag_processing/collection.py
+++ b/airflow-core/src/airflow/dag_processing/collection.py
@@ -60,7 +60,7 @@ from airflow.serialization.enums import Encoding
from airflow.serialization.serialized_objects import BaseSerialization, LazyDeserializedDAG, SerializedDAG
from airflow.triggers.base import BaseEventTrigger
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
-from airflow.utils.sqlalchemy import with_row_locks
+from airflow.utils.sqlalchemy import get_dialect_name, with_row_locks
from airflow.utils.types import DagRunType
if TYPE_CHECKING:
@@ -756,7 +756,7 @@ class AssetModelOperation(NamedTuple):
there's a conflict. The scheduler makes a more comprehensive pass
through all assets in ``_update_asset_orphanage``.
"""
- if session.bind is not None and (dialect_name := session.bind.dialect.name) == "postgresql":
+ if (dialect_name := get_dialect_name(session)) == "postgresql":
from sqlalchemy.dialects.postgresql import insert as postgresql_insert
stmt: Any = postgresql_insert(AssetActive).on_conflict_do_nothing()
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index b4ae52f07d0..43d7f3803bd 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -79,7 +79,14 @@ from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.retries import retry_db_transaction
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.span_status import SpanStatus
-from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, mapped_column, nulls_first, with_row_locks
+from airflow.utils.sqlalchemy import (
+ ExtendedJSON,
+ UtcDateTime,
+ get_dialect_name,
+ mapped_column,
+ nulls_first,
+ with_row_locks,
+)
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.strings import get_random_string
from airflow.utils.thread_safe_dict import ThreadSafeDict
@@ -399,7 +406,7 @@ class DagRun(Base, LoggingMixin):
@duration.expression # type: ignore[no-redef]
@provide_session
def duration(cls, session: Session = NEW_SESSION) -> Case:
- dialect_name = session.bind.dialect.name
+ dialect_name = get_dialect_name(session)
if dialect_name == "mysql":
return func.timestampdiff(text("SECOND"), cls.start_date, cls.end_date)
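The MySQL branch above builds a generic SQL function call. A small sketch of
what it renders to (the `column()` objects are stand-ins for the mapped
`start_date`/`end_date` attributes):

    from sqlalchemy import column, func, text
    from sqlalchemy.dialects import mysql

    expr = func.timestampdiff(text("SECOND"), column("start_date"), column("end_date"))
    # Prints: timestampdiff(SECOND, start_date, end_date)
    print(expr.compile(dialect=mysql.dialect()))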
diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py
index c3845857915..2078e16181f 100644
--- a/airflow-core/src/airflow/models/deadline.py
+++ b/airflow-core/src/airflow/models/deadline.py
@@ -40,7 +40,7 @@ from airflow.stats import Stats
from airflow.triggers.deadline import PAYLOAD_BODY_KEY, PAYLOAD_STATUS_KEY, DeadlineCallbackTrigger
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
+from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name, mapped_column
if TYPE_CHECKING:
from sqlalchemy.orm import Session
@@ -411,7 +411,7 @@ class ReferenceModels:
dag_id = kwargs["dag_id"]
# Get database dialect to use appropriate time difference calculation
- dialect = getattr(session.bind.dialect, "name", None)
+ dialect = get_dialect_name(session)
# Create database-specific expression for calculating duration in seconds
if dialect == "postgresql":
diff --git a/airflow-core/src/airflow/models/serialized_dag.py b/airflow-core/src/airflow/models/serialized_dag.py
index 7067a89c787..7184a29effe 100644
--- a/airflow-core/src/airflow/models/serialized_dag.py
+++ b/airflow-core/src/airflow/models/serialized_dag.py
@@ -49,7 +49,7 @@ from airflow.serialization.serialized_objects import LazyDeserializedDAG, Serial
from airflow.settings import COMPRESS_SERIALIZED_DAGS, json
from airflow.utils.hashlib_wrapper import md5
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
-from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name, mapped_column
if TYPE_CHECKING:
from sqlalchemy.orm import Session
@@ -591,12 +591,13 @@ class SerializedDagModel(Base):
"""
load_json: Callable | None
if COMPRESS_SERIALIZED_DAGS is False:
- if session.bind.dialect.name in ["sqlite", "mysql"]:
+ dialect = get_dialect_name(session)
+ if dialect in ["sqlite", "mysql"]:
data_col_to_select = func.json_extract(cls._data, "$.dag.dag_dependencies")
def load_json(deps_data):
return json.loads(deps_data) if deps_data else []
- elif session.bind.dialect.name == "postgresql":
+ elif dialect == "postgresql":
# Use #> operator which works for both JSON and JSONB types
# Returns the JSON sub-object at the specified path
data_col_to_select = cls._data.op("#>")(literal('{"dag","dag_dependencies"}'))
diff --git a/airflow-core/src/airflow/models/trigger.py b/airflow-core/src/airflow/models/trigger.py
index 44a04a4123e..a0d93210870 100644
--- a/airflow-core/src/airflow/models/trigger.py
+++ b/airflow-core/src/airflow/models/trigger.py
@@ -37,7 +37,7 @@ from airflow.models.taskinstance import TaskInstance
from airflow.triggers.base import BaseTaskEndEvent
from airflow.utils.retries import run_with_db_retries
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, mapped_column, with_row_locks
+from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name, mapped_column, with_row_locks
from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
@@ -231,7 +231,7 @@ class Trigger(Base):
.group_by(cls.id)
.having(func.count(TaskInstance.trigger_id) == 0)
)
- if session.bind.dialect.name == "mysql":
+ if get_dialect_name(session) == "mysql":
# MySQL doesn't support DELETE with JOIN, so we need to do it in two steps
ids = session.scalars(ids).all()
session.execute(
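A minimal sketch of the two-step pattern the comment describes, assuming
`session` is a live ORM session and that the `ids` query mirrors the hunk
above (the outer-join condition is inferred, not taken from the diff):

    from sqlalchemy import delete, func, select

    from airflow.models.taskinstance import TaskInstance
    from airflow.models.trigger import Trigger

    # Step 1: materialize the orphaned trigger ids on the client side.
    orphaned_ids = session.scalars(
        select(Trigger.id)
        .outerjoin(TaskInstance, TaskInstance.trigger_id == Trigger.id)
        .group_by(Trigger.id)
        .having(func.count(TaskInstance.trigger_id) == 0)
    ).all()

    # Step 2: delete by primary key, so the DELETE itself needs no JOIN.
    session.execute(delete(Trigger).where(Trigger.id.in_(orphaned_ids)))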
diff --git a/airflow-core/src/airflow/utils/sqlalchemy.py b/airflow-core/src/airflow/utils/sqlalchemy.py
index 98bb68843e8..f22709cd410 100644
--- a/airflow-core/src/airflow/utils/sqlalchemy.py
+++ b/airflow-core/src/airflow/utils/sqlalchemy.py
@@ -58,6 +58,13 @@ except ImportError:
return Column(*args, **kwargs)
+def get_dialect_name(session: Session) -> str | None:
+ """Safely get the name of the dialect associated with the given session."""
+ if (bind := session.get_bind()) is None:
+ raise ValueError("No bind/engine is associated with the provided Session")
+ return getattr(bind.dialect, "name", None)
+
+
class UtcDateTime(TypeDecorator):
"""
Similar to :class:`~sqlalchemy.types.TIMESTAMP` with ``timezone=True`` option, with some differences.
@@ -312,7 +319,7 @@ def nulls_first(col, session: Session) -> dict[str, Any]:
Other databases do not need it since NULL values are considered lower than
any other values, and appear first when the order is ASC (ascending).
"""
- if session.bind.dialect.name == "postgresql":
+ if get_dialect_name(session) == "postgresql":
return nullsfirst(col)
return col
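For illustration, what the PostgreSQL branch of `nulls_first` evaluates to
(a sketch with a stand-in column):

    from sqlalchemy import column, nullsfirst
    from sqlalchemy.dialects import postgresql

    expr = nullsfirst(column("start_date").asc())
    # Prints: start_date ASC NULLS FIRST
    print(expr.compile(dialect=postgresql.dialect()))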
diff --git a/airflow-core/tests/unit/utils/test_sqlalchemy.py b/airflow-core/tests/unit/utils/test_sqlalchemy.py
index 051c91f05b1..c0a4761a492 100644
--- a/airflow-core/tests/unit/utils/test_sqlalchemy.py
+++ b/airflow-core/tests/unit/utils/test_sqlalchemy.py
@@ -21,7 +21,6 @@ import datetime
import pickle
from copy import deepcopy
from unittest import mock
-from unittest.mock import MagicMock
import pytest
from kubernetes.client import models as k8s
@@ -37,6 +36,7 @@ from airflow.settings import Session
from airflow.utils.sqlalchemy import (
ExecutorConfigType,
ensure_pod_is_valid_after_unpickling,
+ get_dialect_name,
is_sqlalchemy_v1,
prohibit_commit,
with_row_locks,
@@ -52,13 +52,40 @@ pytestmark = pytest.mark.db_test
TEST_POD = k8s.V1Pod(spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base")]))
+class TestGetDialectName:
+ def test_returns_dialect_name_when_present(self, mocker):
+ mock_session = mocker.Mock()
+ mock_bind = mocker.Mock()
+ mock_bind.dialect.name = "postgresql"
+ mock_session.get_bind.return_value = mock_bind
+
+ assert get_dialect_name(mock_session) == "postgresql"
+
+ def test_raises_when_no_bind(self, mocker):
+ mock_session = mocker.Mock()
+ mock_session.get_bind.return_value = None
+
+ with pytest.raises(ValueError, match="No bind/engine is associated"):
+ get_dialect_name(mock_session)
+
+ def test_returns_none_when_dialect_has_no_name(self, mocker):
+ mock_session = mocker.Mock()
+ mock_bind = mocker.Mock()
+ # Simulate a dialect object without a `name` attribute; deleting the
+ # attribute makes subsequent `getattr` calls fall back to the default.
+ mock_bind.dialect = mocker.Mock()
+ delattr(mock_bind.dialect, "name")
+ mock_session.get_bind.return_value = mock_bind
+
+ assert get_dialect_name(mock_session) is None
+
+
class TestSqlAlchemyUtils:
def setup_method(self):
session = Session()
# make sure NOT to run in UTC. Only postgres supports storing
# timezone information in the datetime field
- if session.bind.dialect.name == "postgresql":
+ if get_dialect_name(session) == "postgresql":
session.execute(text("SET timezone='Europe/Amsterdam'"))
self.session = session
@@ -124,7 +151,7 @@ class TestSqlAlchemyUtils:
dag.clear()
@pytest.mark.parametrize(
- "dialect, supports_for_update_of, use_row_level_lock_conf,
expected_use_row_level_lock",
+ ("dialect", "supports_for_update_of", "use_row_level_lock_conf",
"expected_use_row_level_lock"),
[
("postgresql", True, True, True),
("postgresql", True, False, False),
@@ -192,7 +219,7 @@ class TestSqlAlchemyUtils:
class TestExecutorConfigType:
@pytest.mark.parametrize(
- "input, expected",
+ ("input", "expected"),
[
("anything", "anything"),
(
@@ -206,13 +233,13 @@ class TestExecutorConfigType:
),
],
)
- def test_bind_processor(self, input, expected):
+ def test_bind_processor(self, input, expected, mocker):
"""
The returned bind processor should pickle the object as is, unless it is a dictionary with
a pod_override node, in which case it should run it through BaseSerialization.
"""
config_type = ExecutorConfigType()
- mock_dialect = MagicMock()
+ mock_dialect = mocker.MagicMock()
mock_dialect.dbapi = None
process = config_type.bind_processor(mock_dialect)
assert pickle.loads(process(input)) == expected
@@ -239,13 +266,13 @@ class TestExecutorConfigType:
),
],
)
- def test_result_processor(self, input):
+ def test_result_processor(self, input, mocker):
"""
The returned result processor should unpickle the object as is, unless it is a dictionary with
a pod_override node whose value was serialized with BaseSerialization.
"""
config_type = ExecutorConfigType()
- mock_dialect = MagicMock()
+ mock_dialect = mocker.MagicMock()
mock_dialect.dbapi = None
process = config_type.result_processor(mock_dialect, None)
result = process(input)
@@ -277,7 +304,7 @@ class TestExecutorConfigType:
assert instance.compare_values(a, a) is False
assert instance.compare_values("a", "a") is True
- def test_result_processor_bad_pickled_obj(self):
+ def test_result_processor_bad_pickled_obj(self, mocker):
"""
If unpickled obj is missing attrs that curr lib expects
"""
@@ -309,7 +336,7 @@ class TestExecutorConfigType:
# get the result processor method
config_type = ExecutorConfigType()
- mock_dialect = MagicMock()
+ mock_dialect = mocker.MagicMock()
mock_dialect.dbapi = None
process = config_type.result_processor(mock_dialect, None)
@@ -322,13 +349,13 @@ class TestExecutorConfigType:
@pytest.mark.parametrize(
- "mock_version, expected_result",
+ ("mock_version", "expected_result"),
[
("1.0.0", True), # Test 1: v1 identified as v1
("2.3.4", False), # Test 2: v2 not identified as v1
],
)
-def test_is_sqlalchemy_v1(mock_version, expected_result):
- with mock.patch("airflow.utils.sqlalchemy.metadata") as mock_metadata:
- mock_metadata.version.return_value = mock_version
- assert is_sqlalchemy_v1() == expected_result
+def test_is_sqlalchemy_v1(mock_version, expected_result, mocker):
+ mock_metadata = mocker.patch("airflow.utils.sqlalchemy.metadata")
+ mock_metadata.version.return_value = mock_version
+ assert is_sqlalchemy_v1() == expected_result