This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d2e8dd49271 Fix MyPy errors in settings.py (#57368)
d2e8dd49271 is described below

commit d2e8dd49271707728c58959fd6058178fe3e6d39
Author: LI,JHE-CHEN <[email protected]>
AuthorDate: Mon Nov 3 09:35:32 2025 -0500

    Fix MyPy errors in settings.py (#57368)
---
 .../src/airflow/cli/commands/db_command.py         | 10 ++---
 .../src/airflow/jobs/scheduler_job_runner.py       |  3 +-
 airflow-core/src/airflow/models/taskinstance.py    |  5 ++-
 airflow-core/src/airflow/settings.py               | 52 +++++++++++++++++-----
 airflow-core/src/airflow/utils/db.py               |  6 +--
 .../src/airflow/utils/task_instance_session.py     |  2 +-
 airflow-core/tests/unit/models/test_dag.py         |  2 +-
 7 files changed, 56 insertions(+), 24 deletions(-)

diff --git a/airflow-core/src/airflow/cli/commands/db_command.py b/airflow-core/src/airflow/cli/commands/db_command.py
index ea3241320fd..f9ec7365f57 100644
--- a/airflow-core/src/airflow/cli/commands/db_command.py
+++ b/airflow-core/src/airflow/cli/commands/db_command.py
@@ -44,7 +44,7 @@ log = logging.getLogger(__name__)
 @providers_configuration_loaded
 def resetdb(args):
     """Reset the metadata database."""
-    print(f"DB: {settings.engine.url!r}")
+    print(f"DB: {settings.get_engine().url!r}")
     if not (args.yes or input("This will drop existing tables if they exist. 
Proceed? (y/n)").upper() == "Y"):
         raise SystemExit("Cancelled")
     db.resetdb(skip_init=args.skip_init)
@@ -94,7 +94,7 @@ def run_db_migrate_command(args, command, revision_heads_map: dict[str, str]):
 
     :meta private:
     """
-    print(f"DB: {settings.engine.url!r}")
+    print(f"DB: {settings.get_engine().url!r}")
     if args.to_revision and args.to_version:
         raise SystemExit("Cannot supply both `--to-revision` and 
`--to-version`.")
     if args.from_version and args.from_revision:
@@ -128,7 +128,7 @@ def run_db_migrate_command(args, command, revision_heads_map: dict[str, str]):
         to_revision = args.to_revision
 
     if not args.show_sql_only:
-        print(f"Performing upgrade to the metadata database 
{settings.engine.url!r}")
+        print(f"Performing upgrade to the metadata database 
{settings.get_engine().url!r}")
     else:
         print("Generating sql for upgrade -- upgrade commands will *not* be 
submitted.")
     command(
@@ -172,7 +172,7 @@ def run_db_downgrade_command(args, command, revision_heads_map: dict[str, str]):
     elif args.to_revision:
         to_revision = args.to_revision
     if not args.show_sql_only:
-        print(f"Performing downgrade with database {settings.engine.url!r}")
+        print(f"Performing downgrade with database 
{settings.get_engine().url!r}")
     else:
         print("Generating sql for downgrade -- downgrade commands will *not* 
be submitted.")
 
@@ -231,7 +231,7 @@ def _quote_mysql_password_for_cnf(password: str | None) -> str:
 @providers_configuration_loaded
 def shell(args):
     """Run a shell that allows to access metadata database."""
-    url = settings.engine.url
+    url = settings.get_engine().url
     print(f"DB: {url!r}")
 
     if url.get_backend_name() == "mysql":
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index ba026848f71..b86c12c9ff4 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -1089,7 +1089,8 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
 
             self._run_scheduler_loop()
 
-            settings.Session.remove()
+            if settings.Session is not None:
+                settings.Session.remove()
         except Exception:
             self.log.exception("Exception when executing 
SchedulerJob._run_scheduler_loop")
             raise
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 49a4ddd95ec..81c44eb2b0b 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -1170,7 +1170,8 @@ class TaskInstance(Base, LoggingMixin):
 
         # Closing all pooled connections to prevent
         # "max number of connections reached"
-        settings.engine.dispose()
+        if settings.engine is not None:
+            settings.engine.dispose()
         if verbose:
             if mark_success:
                 cls.logger().info("Marking success for %s on %s", ti.task, 
ti.logical_date)
@@ -1638,7 +1639,7 @@ class TaskInstance(Base, LoggingMixin):
         """
         # Do not use provide_session here -- it expunges everything on exit!
         if not session:
-            session = settings.Session()
+            session = settings.get_session()()
 
         from airflow.exceptions import NotMapped
         from airflow.models.mappedoperator import get_mapped_ti_count
diff --git a/airflow-core/src/airflow/settings.py b/airflow-core/src/airflow/settings.py
index e2d96d96f52..503fb93acdf 100644
--- a/airflow-core/src/airflow/settings.py
+++ b/airflow-core/src/airflow/settings.py
@@ -31,8 +31,18 @@ from typing import TYPE_CHECKING, Any, Literal
 import pluggy
 from packaging.version import Version
 from sqlalchemy import create_engine
-from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession as 
SAAsyncSession, create_async_engine
+from sqlalchemy.ext.asyncio import (
+    AsyncEngine,
+    AsyncSession as SAAsyncSession,
+    create_async_engine,
+)
 from sqlalchemy.orm import scoped_session, sessionmaker
+
+try:
+    from sqlalchemy.ext.asyncio import async_sessionmaker
+except ImportError:
+    async_sessionmaker = sessionmaker  # type: ignore[assignment,misc]
+
 from sqlalchemy.pool import NullPool
 
 from airflow import __version__ as airflow_version, policies
@@ -111,15 +121,30 @@ Mapping of sync scheme to async scheme.
 :meta private:
 """
 
-engine: Engine
-Session: scoped_session
+engine: Engine | None = None
+Session: scoped_session | None = None
 # NonScopedSession creates global sessions and is not safe to use in 
multi-threaded environment without
 # additional precautions. The only use case is when the session lifecycle needs
 # custom handling. Most of the time we only want one unique thread local 
session object,
 # this is achieved by the Session factory above.
-NonScopedSession: sessionmaker
-async_engine: AsyncEngine
-AsyncSession: Callable[..., SAAsyncSession]
+NonScopedSession: sessionmaker | None = None
+async_engine: AsyncEngine | None = None
+AsyncSession: Callable[..., SAAsyncSession] | None = None
+
+
+def get_engine():
+    """Get the configured engine, raising an error if not configured."""
+    if engine is None:
+        raise RuntimeError("Engine not configured. Call configure_orm() 
first.")
+    return engine
+
+
+def get_session():
+    """Get the configured Session, raising an error if not configured."""
+    if Session is None:
+        raise RuntimeError("Session not configured. Call configure_orm() 
first.")
+    return Session
+
 
 # The JSON library to use for DAG Serialization and De-Serialization
 json = json_lib
@@ -353,19 +378,22 @@ def _configure_async_session() -> None:
     this does not work well with Pytest and you can end up with issues when the
     session and runs in a different event loop from the test itself.
     """
-    global AsyncSession
-    global async_engine
+    global AsyncSession, async_engine
+
+    if not SQL_ALCHEMY_CONN_ASYNC:
+        async_engine = None
+        AsyncSession = None
+        return
 
     async_engine = create_async_engine(
         SQL_ALCHEMY_CONN_ASYNC,
         connect_args=_get_connect_args("async"),
         future=True,
     )
-    AsyncSession = sessionmaker(
+    AsyncSession = async_sessionmaker(
         bind=async_engine,
-        autocommit=False,
-        autoflush=False,
         class_=SAAsyncSession,
+        autoflush=False,
         expire_on_commit=False,
     )
 
@@ -420,6 +448,8 @@ def configure_orm(disable_connection_pool=False, pool_class=None):
             autoflush=False,
             expire_on_commit=False,
         )
+    if engine is None:
+        raise RuntimeError("Engine must be initialized before creating a 
session")
     NonScopedSession = _session_maker(engine)
     Session = scoped_session(NonScopedSession)
 
diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py
index 50cceed94cf..d969bccc43b 100644
--- a/airflow-core/src/airflow/utils/db.py
+++ b/airflow-core/src/airflow/utils/db.py
@@ -828,7 +828,7 @@ def _configured_alembic_environment() -> Generator[EnvironmentContext, None, Non
             config,
             script,
         ) as env,
-        settings.engine.connect() as connection,
+        settings.get_engine().connect() as connection,
     ):
         alembic_logger = logging.getLogger("alembic")
         level = alembic_logger.level
@@ -1044,7 +1044,7 @@ def _revisions_above_min_for_offline(config, revisions) -> None:
     :param revisions: list of Alembic revision ids
     :return: None
     """
-    dbname = settings.engine.dialect.name
+    dbname = settings.get_engine().dialect.name
     if dbname == "sqlite":
         raise SystemExit("Offline migration not supported for SQLite.")
     min_version, min_revision = ("2.7.0", "937cbd173ca1")
@@ -1257,7 +1257,7 @@ def _handle_fab_downgrade(*, session: Session) -> None:
             fab_version,
         )
         return
-    connection = settings.engine.connect()
+    connection = settings.get_engine().connect()
     insp = inspect(connection)
     if not fab_version and insp.has_table("ab_user"):
         log.info(
diff --git a/airflow-core/src/airflow/utils/task_instance_session.py b/airflow-core/src/airflow/utils/task_instance_session.py
index bb9741bf525..019a752c773 100644
--- a/airflow-core/src/airflow/utils/task_instance_session.py
+++ b/airflow-core/src/airflow/utils/task_instance_session.py
@@ -41,7 +41,7 @@ def get_current_task_instance_session() -> Session:
             log.warning('File: "%s", %s , in %s', filename, line_number, name)
             if line:
                 log.warning("  %s", line.strip())
-        __current_task_instance_session = settings.Session()
+        __current_task_instance_session = settings.get_session()()
     return __current_task_instance_session
 
 
diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py
index 3e341dfc528..7f7e30fa45a 100644
--- a/airflow-core/tests/unit/models/test_dag.py
+++ b/airflow-core/tests/unit/models/test_dag.py
@@ -1430,7 +1430,7 @@ my_postgres_conn:
         ) as dag:
             EmptyOperator(task_id=task_id)
 
-        session = settings.Session()
+        session = settings.get_session()()
         dagrun_1 = dag_maker.create_dagrun(
             run_id="backfill",
             run_type=DagRunType.BACKFILL_JOB,

Reply via email to