This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d7b440bf431 SQLA2: fix mypy issue with getting the dialect name (#56941)
d7b440bf431 is described below

commit d7b440bf4318454a1d7472f77c3b29d9576e0983
Author: Dev-iL <[email protected]>
AuthorDate: Tue Oct 21 23:38:02 2025 +0300

    SQLA2: fix mypy issue with getting the dialect name (#56941)
    
    * SQLA: add util to get the dialect name type-safely
    
    * Replace unsafe dialect retrievals with `get_dialect_name`
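
    For context, a minimal sketch of the mypy problem being fixed: under the
    SQLAlchemy 2 stubs, ``Session.bind`` is optional, so dereferencing it
    directly is flagged, while ``Session.get_bind()`` plus an explicit check
    type-checks cleanly. The two standalone functions below are hypothetical
    illustrations; the actual helper they mirror is in the diff:

        from __future__ import annotations

        from sqlalchemy.orm import Session


        def unsafe_dialect_name(session: Session) -> str:
            # flagged by mypy: `session.bind` may be None, making `.dialect`
            # attribute access on an optional value
            return session.bind.dialect.name


        def safe_dialect_name(session: Session) -> str | None:
            # mirrors the new `get_dialect_name` helper: check the bind once,
            # then read the dialect name defensively
            if (bind := session.get_bind()) is None:
                raise ValueError("No bind/engine is associated with the provided Session")
            return getattr(bind.dialect, "name", None)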
---
 .../api_fastapi/core_api/services/ui/calendar.py   |  3 +-
 airflow-core/src/airflow/assets/manager.py         |  3 +-
 .../src/airflow/dag_processing/collection.py       |  4 +-
 airflow-core/src/airflow/models/dagrun.py          | 11 ++++-
 airflow-core/src/airflow/models/deadline.py        |  4 +-
 airflow-core/src/airflow/models/serialized_dag.py  |  7 +--
 airflow-core/src/airflow/models/trigger.py         |  4 +-
 airflow-core/src/airflow/utils/sqlalchemy.py       |  9 +++-
 airflow-core/tests/unit/utils/test_sqlalchemy.py   | 57 ++++++++++++++++------
 9 files changed, 73 insertions(+), 29 deletions(-)

diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/calendar.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/calendar.py
index 912029e4845..eed47b2c4db 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/calendar.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/calendar.py
@@ -39,6 +39,7 @@ from airflow.serialization.serialized_objects import SerializedDAG
 from airflow.timetables._cron import CronMixin
 from airflow.timetables.base import DataInterval, TimeRestriction
 from airflow.timetables.simple import ContinuousTimetable
+from airflow.utils.sqlalchemy import get_dialect_name
 
 log = structlog.get_logger(logger_name=__name__)
 
@@ -92,7 +93,7 @@ class CalendarService:
         granularity: Literal["hourly", "daily"],
     ) -> tuple[list[CalendarTimeRangeResponse], Sequence[Row]]:
         """Get historical DAG runs from the database."""
-        dialect = session.bind.dialect.name
+        dialect = get_dialect_name(session)
 
         time_expression = self._get_time_truncation_expression(DagRun.logical_date, granularity, dialect)
 
diff --git a/airflow-core/src/airflow/assets/manager.py b/airflow-core/src/airflow/assets/manager.py
index a00c7cae27d..3ab4f31e380 100644
--- a/airflow-core/src/airflow/assets/manager.py
+++ b/airflow-core/src/airflow/assets/manager.py
@@ -38,6 +38,7 @@ from airflow.models.asset import (
 )
 from airflow.stats import Stats
 from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.sqlalchemy import get_dialect_name
 
 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
@@ -245,7 +246,7 @@ class AssetManager(LoggingMixin):
         if not dags_to_queue:
             return
 
-        if session.bind.dialect.name == "postgresql":
+        if get_dialect_name(session) == "postgresql":
             return cls._postgres_queue_dagruns(asset_id, dags_to_queue, session)
         return cls._slow_path_queue_dagruns(asset_id, dags_to_queue, session)
 
diff --git a/airflow-core/src/airflow/dag_processing/collection.py b/airflow-core/src/airflow/dag_processing/collection.py
index 8a9bba79738..80432a87966 100644
--- a/airflow-core/src/airflow/dag_processing/collection.py
+++ b/airflow-core/src/airflow/dag_processing/collection.py
@@ -60,7 +60,7 @@ from airflow.serialization.enums import Encoding
 from airflow.serialization.serialized_objects import BaseSerialization, LazyDeserializedDAG, SerializedDAG
 from airflow.triggers.base import BaseEventTrigger
 from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
-from airflow.utils.sqlalchemy import with_row_locks
+from airflow.utils.sqlalchemy import get_dialect_name, with_row_locks
 from airflow.utils.types import DagRunType
 
 if TYPE_CHECKING:
@@ -756,7 +756,7 @@ class AssetModelOperation(NamedTuple):
         there's a conflict. The scheduler makes a more comprehensive pass
         through all assets in ``_update_asset_orphanage``.
         """
-        if session.bind is not None and (dialect_name := session.bind.dialect.name) == "postgresql":
+        if (dialect_name := get_dialect_name(session)) == "postgresql":
             from sqlalchemy.dialects.postgresql import insert as postgresql_insert
 
             stmt: Any = postgresql_insert(AssetActive).on_conflict_do_nothing()
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index b4ae52f07d0..43d7f3803bd 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -79,7 +79,14 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.retries import retry_db_transaction
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.span_status import SpanStatus
-from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, mapped_column, nulls_first, with_row_locks
+from airflow.utils.sqlalchemy import (
+    ExtendedJSON,
+    UtcDateTime,
+    get_dialect_name,
+    mapped_column,
+    nulls_first,
+    with_row_locks,
+)
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.strings import get_random_string
 from airflow.utils.thread_safe_dict import ThreadSafeDict
@@ -399,7 +406,7 @@ class DagRun(Base, LoggingMixin):
     @duration.expression  # type: ignore[no-redef]
     @provide_session
     def duration(cls, session: Session = NEW_SESSION) -> Case:
-        dialect_name = session.bind.dialect.name
+        dialect_name = get_dialect_name(session)
         if dialect_name == "mysql":
             return func.timestampdiff(text("SECOND"), cls.start_date, cls.end_date)
 
diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py
index c3845857915..2078e16181f 100644
--- a/airflow-core/src/airflow/models/deadline.py
+++ b/airflow-core/src/airflow/models/deadline.py
@@ -40,7 +40,7 @@ from airflow.stats import Stats
 from airflow.triggers.deadline import PAYLOAD_BODY_KEY, PAYLOAD_STATUS_KEY, DeadlineCallbackTrigger
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
+from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name, mapped_column
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -411,7 +411,7 @@ class ReferenceModels:
             dag_id = kwargs["dag_id"]
 
             # Get database dialect to use appropriate time difference calculation
-            dialect = getattr(session.bind.dialect, "name", None)
+            dialect = get_dialect_name(session)
 
             # Create database-specific expression for calculating duration in seconds
             if dialect == "postgresql":
diff --git a/airflow-core/src/airflow/models/serialized_dag.py b/airflow-core/src/airflow/models/serialized_dag.py
index 7067a89c787..7184a29effe 100644
--- a/airflow-core/src/airflow/models/serialized_dag.py
+++ b/airflow-core/src/airflow/models/serialized_dag.py
@@ -49,7 +49,7 @@ from airflow.serialization.serialized_objects import LazyDeserializedDAG, Serial
 from airflow.settings import COMPRESS_SERIALIZED_DAGS, json
 from airflow.utils.hashlib_wrapper import md5
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
+from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name, mapped_column
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -591,12 +591,13 @@ class SerializedDagModel(Base):
         """
         load_json: Callable | None
         if COMPRESS_SERIALIZED_DAGS is False:
-            if session.bind.dialect.name in ["sqlite", "mysql"]:
+            dialect = get_dialect_name(session)
+            if dialect in ["sqlite", "mysql"]:
+                data_col_to_select = func.json_extract(cls._data, "$.dag.dag_dependencies")
 
                 def load_json(deps_data):
                     return json.loads(deps_data) if deps_data else []
-            elif session.bind.dialect.name == "postgresql":
+            elif dialect == "postgresql":
                 # Use #> operator which works for both JSON and JSONB types
                 # Returns the JSON sub-object at the specified path
                 data_col_to_select = cls._data.op("#>")(literal('{"dag","dag_dependencies"}'))
diff --git a/airflow-core/src/airflow/models/trigger.py b/airflow-core/src/airflow/models/trigger.py
index 44a04a4123e..a0d93210870 100644
--- a/airflow-core/src/airflow/models/trigger.py
+++ b/airflow-core/src/airflow/models/trigger.py
@@ -37,7 +37,7 @@ from airflow.models.taskinstance import TaskInstance
 from airflow.triggers.base import BaseTaskEndEvent
 from airflow.utils.retries import run_with_db_retries
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, mapped_column, with_row_locks
+from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name, mapped_column, with_row_locks
 from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
@@ -231,7 +231,7 @@ class Trigger(Base):
             .group_by(cls.id)
             .having(func.count(TaskInstance.trigger_id) == 0)
         )
-        if session.bind.dialect.name == "mysql":
+        if get_dialect_name(session) == "mysql":
             # MySQL doesn't support DELETE with JOIN, so we need to do it in two steps
             ids = session.scalars(ids).all()
         session.execute(
diff --git a/airflow-core/src/airflow/utils/sqlalchemy.py b/airflow-core/src/airflow/utils/sqlalchemy.py
index 98bb68843e8..f22709cd410 100644
--- a/airflow-core/src/airflow/utils/sqlalchemy.py
+++ b/airflow-core/src/airflow/utils/sqlalchemy.py
@@ -58,6 +58,13 @@ except ImportError:
         return Column(*args, **kwargs)
 
 
+def get_dialect_name(session: Session) -> str | None:
+    """Safely get the name of the dialect associated with the given session."""
+    if (bind := session.get_bind()) is None:
+        raise ValueError("No bind/engine is associated with the provided Session")
+    return getattr(bind.dialect, "name", None)
+
+
 class UtcDateTime(TypeDecorator):
     """
     Similar to :class:`~sqlalchemy.types.TIMESTAMP` with ``timezone=True`` option, with some differences.
@@ -312,7 +319,7 @@ def nulls_first(col, session: Session) -> dict[str, Any]:
     Other databases do not need it since NULL values are considered lower than
     any other values, and appear first when the order is ASC (ascending).
     """
-    if session.bind.dialect.name == "postgresql":
+    if get_dialect_name(session) == "postgresql":
         return nullsfirst(col)
     return col
 
diff --git a/airflow-core/tests/unit/utils/test_sqlalchemy.py b/airflow-core/tests/unit/utils/test_sqlalchemy.py
index 051c91f05b1..c0a4761a492 100644
--- a/airflow-core/tests/unit/utils/test_sqlalchemy.py
+++ b/airflow-core/tests/unit/utils/test_sqlalchemy.py
@@ -21,7 +21,6 @@ import datetime
 import pickle
 from copy import deepcopy
 from unittest import mock
-from unittest.mock import MagicMock
 
 import pytest
 from kubernetes.client import models as k8s
@@ -37,6 +36,7 @@ from airflow.settings import Session
 from airflow.utils.sqlalchemy import (
     ExecutorConfigType,
     ensure_pod_is_valid_after_unpickling,
+    get_dialect_name,
     is_sqlalchemy_v1,
     prohibit_commit,
     with_row_locks,
@@ -52,13 +52,40 @@ pytestmark = pytest.mark.db_test
 TEST_POD = k8s.V1Pod(spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base")]))
 
 
+class TestGetDialectName:
+    def test_returns_dialect_name_when_present(self, mocker):
+        mock_session = mocker.Mock()
+        mock_bind = mocker.Mock()
+        mock_bind.dialect.name = "postgresql"
+        mock_session.get_bind.return_value = mock_bind
+
+        assert get_dialect_name(mock_session) == "postgresql"
+
+    def test_raises_when_no_bind(self, mocker):
+        mock_session = mocker.Mock()
+        mock_session.get_bind.return_value = None
+
+        with pytest.raises(ValueError, match="No bind/engine is associated"):
+            get_dialect_name(mock_session)
+
+    def test_returns_none_when_dialect_has_no_name(self, mocker):
+        mock_session = mocker.Mock()
+        mock_bind = mocker.Mock()
+        # spec=[] creates a mock with no auto-created attributes, simulating
+        # a dialect object that lacks a `name` attribute
+        mock_bind.dialect = mocker.Mock(spec=[])
+        mock_session.get_bind.return_value = mock_bind
+
+        assert get_dialect_name(mock_session) is None
+
+
 class TestSqlAlchemyUtils:
     def setup_method(self):
         session = Session()
 
         # make sure NOT to run in UTC. Only postgres supports storing
         # timezone information in the datetime field
-        if session.bind.dialect.name == "postgresql":
+        if get_dialect_name(session) == "postgresql":
             session.execute(text("SET timezone='Europe/Amsterdam'"))
 
         self.session = session
@@ -124,7 +151,7 @@ class TestSqlAlchemyUtils:
         dag.clear()
 
     @pytest.mark.parametrize(
-        "dialect, supports_for_update_of, use_row_level_lock_conf, 
expected_use_row_level_lock",
+        ("dialect", "supports_for_update_of", "use_row_level_lock_conf", 
"expected_use_row_level_lock"),
         [
             ("postgresql", True, True, True),
             ("postgresql", True, False, False),
@@ -192,7 +219,7 @@ class TestSqlAlchemyUtils:
 
 class TestExecutorConfigType:
     @pytest.mark.parametrize(
-        "input, expected",
+        ("input", "expected"),
         [
             ("anything", "anything"),
             (
@@ -206,13 +233,13 @@ class TestExecutorConfigType:
             ),
         ],
     )
-    def test_bind_processor(self, input, expected):
+    def test_bind_processor(self, input, expected, mocker):
         """
         The returned bind processor should pickle the object as is, unless it is a dictionary with
         a pod_override node, in which case it should run it through BaseSerialization.
         """
         config_type = ExecutorConfigType()
-        mock_dialect = MagicMock()
+        mock_dialect = mocker.MagicMock()
         mock_dialect.dbapi = None
         process = config_type.bind_processor(mock_dialect)
         assert pickle.loads(process(input)) == expected
@@ -239,13 +266,13 @@ class TestExecutorConfigType:
             ),
         ],
     )
-    def test_result_processor(self, input):
+    def test_result_processor(self, input, mocker):
         """
         The returned bind processor should pickle the object as is, unless it is a dictionary with
         a pod_override node whose value was serialized with BaseSerialization.
         """
         config_type = ExecutorConfigType()
-        mock_dialect = MagicMock()
+        mock_dialect = mocker.MagicMock()
         mock_dialect.dbapi = None
         process = config_type.result_processor(mock_dialect, None)
         result = process(input)
@@ -277,7 +304,7 @@ class TestExecutorConfigType:
         assert instance.compare_values(a, a) is False
         assert instance.compare_values("a", "a") is True
 
-    def test_result_processor_bad_pickled_obj(self):
+    def test_result_processor_bad_pickled_obj(self, mocker):
         """
         If unpickled obj is missing attrs that curr lib expects
         """
@@ -309,7 +336,7 @@ class TestExecutorConfigType:
 
         # get the result processor method
         config_type = ExecutorConfigType()
-        mock_dialect = MagicMock()
+        mock_dialect = mocker.MagicMock()
         mock_dialect.dbapi = None
         process = config_type.result_processor(mock_dialect, None)
 
@@ -322,13 +349,13 @@ class TestExecutorConfigType:
 
 
 @pytest.mark.parametrize(
-    "mock_version, expected_result",
+    ("mock_version", "expected_result"),
     [
         ("1.0.0", True),  # Test 1: v1 identified as v1
         ("2.3.4", False),  # Test 2: v2 not identified as v1
     ],
 )
-def test_is_sqlalchemy_v1(mock_version, expected_result):
-    with mock.patch("airflow.utils.sqlalchemy.metadata") as mock_metadata:
-        mock_metadata.version.return_value = mock_version
-        assert is_sqlalchemy_v1() == expected_result
+def test_is_sqlalchemy_v1(mock_version, expected_result, mocker):
+    mock_metadata = mocker.patch("airflow.utils.sqlalchemy.metadata")
+    mock_metadata.version.return_value = mock_version
+    assert is_sqlalchemy_v1() == expected_result
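
A minimal usage sketch of the new helper (the in-memory SQLite engine below is
an illustrative assumption, not part of the commit):

    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session

    from airflow.utils.sqlalchemy import get_dialect_name

    engine = create_engine("sqlite://")
    with Session(engine) as session:
        # the bound engine's dialect name is returned as a plain string
        assert get_dialect_name(session) == "sqlite"
    # a session constructed without any bind would instead raise ValueError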
