This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d2e8dd49271 Fix MyPy errors in settings.py (#57368)
d2e8dd49271 is described below
commit d2e8dd49271707728c58959fd6058178fe3e6d39
Author: LI,JHE-CHEN <[email protected]>
AuthorDate: Mon Nov 3 09:35:32 2025 -0500
Fix MyPy errors in settings.py (#57368)
---
.../src/airflow/cli/commands/db_command.py | 10 ++---
.../src/airflow/jobs/scheduler_job_runner.py | 3 +-
airflow-core/src/airflow/models/taskinstance.py | 5 ++-
airflow-core/src/airflow/settings.py | 52 +++++++++++++++++-----
airflow-core/src/airflow/utils/db.py | 6 +--
.../src/airflow/utils/task_instance_session.py | 2 +-
airflow-core/tests/unit/models/test_dag.py | 2 +-
7 files changed, 56 insertions(+), 24 deletions(-)
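For context: the MyPy errors stem from module-level globals in settings.py (engine, Session, and friends) that are only bound inside configure_orm(). Making that explicit means annotating them as Optional, and every bare use then fails strict checking until it is narrowed. A minimal caller-side sketch of the problem and the accessor-based fix (the diagnostic shown is mypy's union-attr error, reproduced here from memory):

    from airflow import settings

    # With `engine: Engine | None = None`, bare attribute access no longer type-checks:
    url = settings.engine.url  # error: Item "None" of "Engine | None" has no attribute "url"

    # The new accessor narrows the Optional in one place and fails fast if the
    # ORM was never configured:
    url = settings.get_engine().url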
diff --git a/airflow-core/src/airflow/cli/commands/db_command.py b/airflow-core/src/airflow/cli/commands/db_command.py
index ea3241320fd..f9ec7365f57 100644
--- a/airflow-core/src/airflow/cli/commands/db_command.py
+++ b/airflow-core/src/airflow/cli/commands/db_command.py
@@ -44,7 +44,7 @@ log = logging.getLogger(__name__)
@providers_configuration_loaded
def resetdb(args):
"""Reset the metadata database."""
- print(f"DB: {settings.engine.url!r}")
+ print(f"DB: {settings.get_engine().url!r}")
if not (args.yes or input("This will drop existing tables if they exist. Proceed? (y/n)").upper() == "Y"):
raise SystemExit("Cancelled")
db.resetdb(skip_init=args.skip_init)
@@ -94,7 +94,7 @@ def run_db_migrate_command(args, command, revision_heads_map: dict[str, str]):
:meta private:
"""
- print(f"DB: {settings.engine.url!r}")
+ print(f"DB: {settings.get_engine().url!r}")
if args.to_revision and args.to_version:
raise SystemExit("Cannot supply both `--to-revision` and
`--to-version`.")
if args.from_version and args.from_revision:
@@ -128,7 +128,7 @@ def run_db_migrate_command(args, command, revision_heads_map: dict[str, str]):
to_revision = args.to_revision
if not args.show_sql_only:
- print(f"Performing upgrade to the metadata database
{settings.engine.url!r}")
+ print(f"Performing upgrade to the metadata database
{settings.get_engine().url!r}")
else:
print("Generating sql for upgrade -- upgrade commands will *not* be
submitted.")
command(
@@ -172,7 +172,7 @@ def run_db_downgrade_command(args, command, revision_heads_map: dict[str, str]):
elif args.to_revision:
to_revision = args.to_revision
if not args.show_sql_only:
- print(f"Performing downgrade with database {settings.engine.url!r}")
+ print(f"Performing downgrade with database
{settings.get_engine().url!r}")
else:
print("Generating sql for downgrade -- downgrade commands will *not*
be submitted.")
@@ -231,7 +231,7 @@ def _quote_mysql_password_for_cnf(password: str | None) -> str:
@providers_configuration_loaded
def shell(args):
"""Run a shell that allows to access metadata database."""
- url = settings.engine.url
+ url = settings.get_engine().url
print(f"DB: {url!r}")
if url.get_backend_name() == "mysql":
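A note on the pattern applied throughout this file: without the accessor, every call site would have to narrow the Optional itself, roughly as in the hypothetical sketch below; get_engine() folds that guard and its error message into one place.

    from airflow import settings

    # hypothetical per-call-site narrowing that the accessor makes unnecessary
    if settings.engine is None:
        raise RuntimeError("Engine not configured. Call configure_orm() first.")
    print(f"DB: {settings.engine.url!r}")

    # equivalent, via the new accessor
    print(f"DB: {settings.get_engine().url!r}")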
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index ba026848f71..b86c12c9ff4 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -1089,7 +1089,8 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
self._run_scheduler_loop()
- settings.Session.remove()
+ if settings.Session is not None:
+ settings.Session.remove()
except Exception:
self.log.exception("Exception when executing
SchedulerJob._run_scheduler_loop")
raise
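The guard is there because Session is now Optional, but the call itself is worth a word: scoped_session.remove() closes the thread-local session and clears the registry so the next call builds a fresh one. A minimal sketch of that lifecycle in plain SQLAlchemy, independent of Airflow:

    from sqlalchemy import create_engine
    from sqlalchemy.orm import scoped_session, sessionmaker

    engine = create_engine("sqlite://")
    Session = scoped_session(sessionmaker(bind=engine))

    s1 = Session()               # creates and memoizes this thread's session
    assert Session() is s1       # repeated calls return the same object
    Session.remove()             # closes it and clears the registry
    assert Session() is not s1   # the next call builds a new session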
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 49a4ddd95ec..81c44eb2b0b 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -1170,7 +1170,8 @@ class TaskInstance(Base, LoggingMixin):
# Closing all pooled connections to prevent
# "max number of connections reached"
- settings.engine.dispose()
+ if settings.engine is not None:
+ settings.engine.dispose()
if verbose:
if mark_success:
cls.logger().info("Marking success for %s on %s", ti.task,
ti.logical_date)
@@ -1638,7 +1639,7 @@ class TaskInstance(Base, LoggingMixin):
"""
# Do not use provide_session here -- it expunges everything on exit!
if not session:
- session = settings.Session()
+ session = settings.get_session()()
from airflow.exceptions import NotMapped
from airflow.models.mappedoperator import get_mapped_ti_count
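The doubled parentheses in settings.get_session()() are deliberate: get_session() returns the scoped_session factory (raising if the ORM is unconfigured), and calling that factory yields the actual thread-local Session. A short sketch:

    from airflow import settings

    Session = settings.get_session()   # the scoped_session factory
    session = Session()                # the thread-local sqlalchemy.orm.Session

    # so this one-liner is the typed equivalent of the old `settings.Session()`:
    session = settings.get_session()()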
diff --git a/airflow-core/src/airflow/settings.py b/airflow-core/src/airflow/settings.py
index e2d96d96f52..503fb93acdf 100644
--- a/airflow-core/src/airflow/settings.py
+++ b/airflow-core/src/airflow/settings.py
@@ -31,8 +31,18 @@ from typing import TYPE_CHECKING, Any, Literal
import pluggy
from packaging.version import Version
from sqlalchemy import create_engine
-from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession as SAAsyncSession, create_async_engine
+from sqlalchemy.ext.asyncio import (
+ AsyncEngine,
+ AsyncSession as SAAsyncSession,
+ create_async_engine,
+)
from sqlalchemy.orm import scoped_session, sessionmaker
+
+try:
+ from sqlalchemy.ext.asyncio import async_sessionmaker
+except ImportError:
+ async_sessionmaker = sessionmaker # type: ignore[assignment,misc]
+
from sqlalchemy.pool import NullPool
from airflow import __version__ as airflow_version, policies
@@ -111,15 +121,30 @@ Mapping of sync scheme to async scheme.
:meta private:
"""
-engine: Engine
-Session: scoped_session
+engine: Engine | None = None
+Session: scoped_session | None = None
# NonScopedSession creates global sessions and is not safe to use in multi-threaded environment without
# additional precautions. The only use case is when the session lifecycle needs
# custom handling. Most of the time we only want one unique thread local session object,
# this is achieved by the Session factory above.
-NonScopedSession: sessionmaker
-async_engine: AsyncEngine
-AsyncSession: Callable[..., SAAsyncSession]
+NonScopedSession: sessionmaker | None = None
+async_engine: AsyncEngine | None = None
+AsyncSession: Callable[..., SAAsyncSession] | None = None
+
+
+def get_engine():
+ """Get the configured engine, raising an error if not configured."""
+ if engine is None:
+ raise RuntimeError("Engine not configured. Call configure_orm()
first.")
+ return engine
+
+
+def get_session():
+ """Get the configured Session, raising an error if not configured."""
+ if Session is None:
+ raise RuntimeError("Session not configured. Call configure_orm()
first.")
+ return Session
+
# The JSON library to use for DAG Serialization and De-Serialization
json = json_lib
@@ -353,19 +378,22 @@ def _configure_async_session() -> None:
this does not work well with Pytest and you can end up with issues when the
session and runs in a different event loop from the test itself.
"""
- global AsyncSession
- global async_engine
+ global AsyncSession, async_engine
+
+ if not SQL_ALCHEMY_CONN_ASYNC:
+ async_engine = None
+ AsyncSession = None
+ return
async_engine = create_async_engine(
SQL_ALCHEMY_CONN_ASYNC,
connect_args=_get_connect_args("async"),
future=True,
)
- AsyncSession = sessionmaker(
+ AsyncSession = async_sessionmaker(
bind=async_engine,
- autocommit=False,
- autoflush=False,
class_=SAAsyncSession,
+ autoflush=False,
expire_on_commit=False,
)
@@ -420,6 +448,8 @@ def configure_orm(disable_connection_pool=False, pool_class=None):
autoflush=False,
expire_on_commit=False,
)
+ if engine is None:
+ raise RuntimeError("Engine must be initialized before creating a
session")
NonScopedSession = _session_maker(engine)
Session = scoped_session(NonScopedSession)
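Two details in the settings.py hunks are easy to miss. First, async_sessionmaker was only added in SQLAlchemy 2.0, hence the try/except fallback to plain sessionmaker, which still produces async sessions via class_=SAAsyncSession; the dropped autocommit=False argument reflects SQLAlchemy 2.0 removing that parameter outright. Second, AsyncSession may now legitimately be None when no async DSN is configured, so callers should guard it. A hedged usage sketch (ping_db is a hypothetical helper, not part of this commit):

    from sqlalchemy import text

    from airflow import settings

    async def ping_db() -> int:
        # AsyncSession is None when SQL_ALCHEMY_CONN_ASYNC is unset
        if settings.AsyncSession is None:
            raise RuntimeError("No async SQLAlchemy connection configured")
        async with settings.AsyncSession() as session:
            result = await session.execute(text("SELECT 1"))
            return result.scalar_one()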
diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py
index 50cceed94cf..d969bccc43b 100644
--- a/airflow-core/src/airflow/utils/db.py
+++ b/airflow-core/src/airflow/utils/db.py
@@ -828,7 +828,7 @@ def _configured_alembic_environment() -> Generator[EnvironmentContext, None, Non
config,
script,
) as env,
- settings.engine.connect() as connection,
+ settings.get_engine().connect() as connection,
):
alembic_logger = logging.getLogger("alembic")
level = alembic_logger.level
@@ -1044,7 +1044,7 @@ def _revisions_above_min_for_offline(config, revisions) -> None:
:param revisions: list of Alembic revision ids
:return: None
"""
- dbname = settings.engine.dialect.name
+ dbname = settings.get_engine().dialect.name
if dbname == "sqlite":
raise SystemExit("Offline migration not supported for SQLite.")
min_version, min_revision = ("2.7.0", "937cbd173ca1")
@@ -1257,7 +1257,7 @@ def _handle_fab_downgrade(*, session: Session) -> None:
fab_version,
)
return
- connection = settings.engine.connect()
+ connection = settings.get_engine().connect()
insp = inspect(connection)
if not fab_version and insp.has_table("ab_user"):
log.info(
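Beside the accessor swap, the surrounding code in _handle_fab_downgrade shows the standard SQLAlchemy runtime-inspection pattern: inspect() on a live Connection returns an Inspector that can probe the schema. A standalone sketch:

    from sqlalchemy import create_engine, inspect

    engine = create_engine("sqlite://")
    with engine.connect() as connection:
        insp = inspect(connection)
        print(insp.has_table("ab_user"))  # False on an empty database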
diff --git a/airflow-core/src/airflow/utils/task_instance_session.py b/airflow-core/src/airflow/utils/task_instance_session.py
index bb9741bf525..019a752c773 100644
--- a/airflow-core/src/airflow/utils/task_instance_session.py
+++ b/airflow-core/src/airflow/utils/task_instance_session.py
@@ -41,7 +41,7 @@ def get_current_task_instance_session() -> Session:
log.warning('File: "%s", %s , in %s', filename, line_number, name)
if line:
log.warning(" %s", line.strip())
- __current_task_instance_session = settings.Session()
+ __current_task_instance_session = settings.get_session()()
return __current_task_instance_session
diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py
index 3e341dfc528..7f7e30fa45a 100644
--- a/airflow-core/tests/unit/models/test_dag.py
+++ b/airflow-core/tests/unit/models/test_dag.py
@@ -1430,7 +1430,7 @@ my_postgres_conn:
) as dag:
EmptyOperator(task_id=task_id)
- session = settings.Session()
+ session = settings.get_session()()
dagrun_1 = dag_maker.create_dagrun(
run_id="backfill",
run_type=DagRunType.BACKFILL_JOB,