This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c4409597235 Disable ORM access from Tasks, DAG processing and Triggers
(#47320)
c4409597235 is described below
commit c4409597235c3d9858964718a3f6d93ae23f1d80
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Thu Mar 6 13:22:00 2025 +0000
Disable ORM access from Tasks, DAG processing and Triggers (#47320)
All of these use the Workload supervisor from the TaskSDK and the main paths
(XCom, Variables and Secrets) have all been ported to use the Execution API,
so it's about time we disabled DB access.
---
airflow/dag_processing/processor.py | 6 ---
airflow/settings.py | 50 ++++++++-----------
providers/celery/provider.yaml | 7 ---
.../airflow/providers/celery/cli/celery_command.py | 10 ++--
.../airflow/providers/celery/get_provider_info.py | 7 ---
.../tests/unit/celery/cli/test_celery_command.py | 34 -------------
.../src/airflow/sdk/execution_time/supervisor.py | 58 ++++++++++++++++++++++
task_sdk/tests/conftest.py | 5 ++
task_sdk/tests/execution_time/test_supervisor.py | 3 ++
tests/dag_processing/test_manager.py | 11 ++--
tests/jobs/test_triggerer_job.py | 6 ++-
11 files changed, 104 insertions(+), 93 deletions(-)
diff --git a/airflow/dag_processing/processor.py
b/airflow/dag_processing/processor.py
index bba6b6857d8..c5420d0b21f 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -63,14 +63,8 @@ def _parse_file_entrypoint():
import structlog
from airflow.sdk.execution_time import task_runner
- from airflow.settings import configure_orm
# Parse DAG file, send JSON back up!
-
- # We need to reconfigure the orm here, as DagFileProcessorManager does db
queries for bundles, and
- # the session across forks blows things up.
- configure_orm()
-
comms_decoder = task_runner.CommsDecoder[ToDagProcessor, ToManager](
input=sys.stdin,
decoder=TypeAdapter[ToDagProcessor](ToDagProcessor),
diff --git a/airflow/settings.py b/airflow/settings.py
index 307ee1e668a..6ae462a1e09 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any, Callable
import pluggy
from packaging.version import Version
-from sqlalchemy import create_engine, exc, text
+from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession as
SAAsyncSession, create_async_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.pool import NullPool
@@ -46,7 +46,6 @@ from airflow.utils.timezone import local_timezone,
parse_timezone, utc
if TYPE_CHECKING:
from sqlalchemy.engine import Engine
- from sqlalchemy.orm import Session as SASession
log = logging.getLogger(__name__)
@@ -101,12 +100,12 @@ Mapping of sync scheme to async scheme.
"""
engine: Engine
-Session: Callable[..., SASession]
+Session: scoped_session
 # NonScopedSession creates global sessions and is not safe to use in a
multi-threaded environment without
# additional precautions. The only use case is when the session lifecycle needs
# custom handling. Most of the time we only want one unique thread local
session object,
# this is achieved by the Session factory above.
-NonScopedSession: Callable[..., SASession]
+NonScopedSession: sessionmaker
async_engine: AsyncEngine
AsyncSession: Callable[..., SAAsyncSession]
@@ -389,6 +388,12 @@ def configure_orm(disable_connection_pool=False,
pool_class=None):
NonScopedSession = _session_maker(engine)
Session = scoped_session(NonScopedSession)
+ from sqlalchemy.orm.session import close_all_sessions
+
+ os.register_at_fork(after_in_child=close_all_sessions)
+ #
https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
+ os.register_at_fork(after_in_child=lambda: engine.dispose(close=False))
+
DEFAULT_ENGINE_ARGS = {
"postgresql": {
@@ -479,14 +484,23 @@ def prepare_engine_args(disable_connection_pool=False,
pool_class=None):
def dispose_orm():
"""Properly close pooled database connections."""
+ global Session, engine, NonScopedSession
+
+ _globals = globals()
+ if "engine" not in _globals and "Session" not in _globals:
+ return
+
log.debug("Disposing DB connection pool (PID %s)", os.getpid())
- global engine
- global Session
- if Session is not None: # type: ignore[truthy-function]
+ if "Session" in _globals and Session is not None:
+ from sqlalchemy.orm.session import close_all_sessions
+
Session.remove()
Session = None
- if engine:
+ NonScopedSession = None
+ close_all_sessions()
+
+ if "engine" in _globals:
engine.dispose()
engine = None
@@ -529,26 +543,6 @@ def configure_adapters():
pass
-def validate_session():
- """Validate ORM Session."""
- global engine
-
- worker_precheck = conf.getboolean("celery", "worker_precheck")
- if not worker_precheck:
- return True
- else:
- check_session = sessionmaker(bind=engine)
- session = check_session()
- try:
- session.execute(text("select 1"))
- conn_status = True
- except exc.DBAPIError as err:
- log.error(err)
- conn_status = False
- session.close()
- return conn_status
-
-
def configure_action_logging() -> None:
"""Any additional configuration (register callback) for
airflow.utils.action_loggers module."""
diff --git a/providers/celery/provider.yaml b/providers/celery/provider.yaml
index 0f50ee7c7b4..7d44bf2989c 100644
--- a/providers/celery/provider.yaml
+++ b/providers/celery/provider.yaml
@@ -308,13 +308,6 @@ config:
type: integer
example: ~
default: "3"
- worker_precheck:
- description: |
- Worker initialisation check to validate Metadata Database connection
- version_added: ~
- type: string
- example: ~
- default: "False"
extra_celery_config:
description: |
Extra celery configs to include in the celery worker.
diff --git
a/providers/celery/src/airflow/providers/celery/cli/celery_command.py
b/providers/celery/src/airflow/providers/celery/cli/celery_command.py
index 464cd830184..8381886f702 100644
--- a/providers/celery/src/airflow/providers/celery/cli/celery_command.py
+++ b/providers/celery/src/airflow/providers/celery/cli/celery_command.py
@@ -197,11 +197,11 @@ def worker(args):
from airflow.sdk.log import configure_logging
configure_logging(output=sys.stdout.buffer)
-
- # Disable connection pool so that celery worker does not hold an
unnecessary db connection
- settings.reconfigure_orm(disable_connection_pool=True)
- if not settings.validate_session():
- raise SystemExit("Worker exiting, database connection precheck
failed.")
+ else:
+ # Disable connection pool so that celery worker does not hold an
unnecessary db connection
+ settings.reconfigure_orm(disable_connection_pool=True)
+ if not settings.validate_session():
+ raise SystemExit("Worker exiting, database connection precheck
failed.")
autoscale = args.autoscale
skip_serve_logs = args.skip_serve_logs
diff --git a/providers/celery/src/airflow/providers/celery/get_provider_info.py
b/providers/celery/src/airflow/providers/celery/get_provider_info.py
index 41f8543336a..625cf867f10 100644
--- a/providers/celery/src/airflow/providers/celery/get_provider_info.py
+++ b/providers/celery/src/airflow/providers/celery/get_provider_info.py
@@ -266,13 +266,6 @@ def get_provider_info():
"example": None,
"default": "3",
},
- "worker_precheck": {
- "description": "Worker initialisation check to
validate Metadata Database connection\n",
- "version_added": None,
- "type": "string",
- "example": None,
- "default": "False",
- },
"extra_celery_config": {
"description": 'Extra celery configs to include in the
celery worker.\nAny of the celery config can be added to this config and
it\nwill be applied while starting the celery worker. e.g.
{"worker_max_tasks_per_child": 10}\nSee
also:\nhttps://docs.celeryq.dev/en/stable/userguide/configuration.html#configuration-and-defaults\n',
"version_added": None,
diff --git a/providers/celery/tests/unit/celery/cli/test_celery_command.py
b/providers/celery/tests/unit/celery/cli/test_celery_command.py
index 93b8606aa54..06f96876a1f 100644
--- a/providers/celery/tests/unit/celery/cli/test_celery_command.py
+++ b/providers/celery/tests/unit/celery/cli/test_celery_command.py
@@ -19,14 +19,11 @@ from __future__ import annotations
import importlib
import os
-from argparse import Namespace
from unittest import mock
from unittest.mock import patch
import pytest
-import sqlalchemy
-import airflow
from airflow.cli import cli_parser
from airflow.configuration import conf
from airflow.executors import executor_loader
@@ -39,37 +36,6 @@ from tests_common.test_utils.version_compat import
AIRFLOW_V_2_10_PLUS, AIRFLOW_
pytestmark = pytest.mark.db_test
-@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
-class TestWorkerPrecheck:
- @mock.patch("airflow.settings.validate_session")
- def test_error(self, mock_validate_session):
- """
- Test to verify the exit mechanism of airflow-worker cli
- by mocking validate_session method
- """
- mock_validate_session.return_value = False
- with pytest.raises(SystemExit) as ctx, conf_vars({("core",
"executor"): "CeleryExecutor"}):
- celery_command.worker(Namespace(queues=1, concurrency=1))
- assert str(ctx.value) == "Worker exiting, database connection precheck
failed."
-
- @conf_vars({("celery", "worker_precheck"): "False"})
- def test_worker_precheck_exception(self):
- """
- Test to check the behaviour of validate_session method
- when worker_precheck is absent in airflow configuration
- """
- assert airflow.settings.validate_session()
-
- @mock.patch("sqlalchemy.orm.session.Session.execute")
- @conf_vars({("celery", "worker_precheck"): "True"})
- def test_validate_session_dbapi_exception(self, mock_session):
- """
- Test to validate connection failure scenario on SELECT 1 query
- """
- mock_session.side_effect = sqlalchemy.exc.OperationalError("m1", "m2",
"m3", "m4")
- assert airflow.settings.validate_session() is False
-
-
@pytest.mark.backend("mysql", "postgres")
@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
class TestCeleryStopCommand:
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index cfc9b0d8fef..a93fae2f0b3 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -206,6 +206,62 @@ def _get_last_chance_stderr() -> TextIO:
return stream
+class BlockedDBSession:
+ """:meta private:""" # noqa: D400
+
+ def __init__(self):
+ raise RuntimeError("Direct database access via the ORM is not allowed
in Airflow 3.0")
+
+ def remove(*args, **kwargs):
+ pass
+
+ def get_bind(
+ self,
+ mapper=None,
+ clause=None,
+ bind=None,
+ _sa_skip_events=None,
+ _sa_skip_for_implicit_returning=False,
+ ):
+ pass
+
+
+def block_orm_access():
+ """
+ Disable direct DB access as best as possible from task code.
+
+ While we still don't have 100% code separation between TaskSDK and "core"
Airflow, it is still possible to
+ import the models and use them. This does what it can to disable that if
it is not blocked at the network
+ level
+ """
+    # A fake URL scheme that might give users some clue what's going on.
Hopefully
+ conn = "airflow-db-not-allowed:///"
+ if "airflow.settings" in sys.modules:
+ from airflow import settings
+ from airflow.configuration import conf
+
+ settings.dispose_orm()
+
+ for attr in ("engine", "async_engine", "Session", "AsyncSession",
"NonScopedSession"):
+ if hasattr(settings, attr):
+ delattr(settings, attr)
+
+ def configure_orm(*args, **kwargs):
+ raise RuntimeError("Database access is disabled from DAGs and
Triggers")
+
+ settings.configure_orm = configure_orm
+ settings.Session = BlockedDBSession
+ if conf.has_section("database"):
+ conf.set("database", "sql_alchemy_conn", conn)
+ conf.set("database", "sql_alchemy_conn_cmd", "/bin/false")
+ conf.set("database", "sql_alchemy_conn_secret",
"db-access-blocked")
+
+ settings.SQL_ALCHEMY_CONN = conn
+ settings.SQL_ALCHEMY_CONN_ASYNC = conn
+
+ os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_CONN"] = conn
+
+
def _fork_main(
child_stdin: socket,
child_stdout: socket,
@@ -261,6 +317,8 @@ def _fork_main(
base_exit(n)
try:
+ block_orm_access()
+
target()
exit(0)
except SystemExit as e:
diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py
index 3196051da7b..b075d2e73ad 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/tests/conftest.py
@@ -28,6 +28,7 @@ pytest_plugins = "tests_common.pytest_plugin"
# Task SDK does not need access to the Airflow database
os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
+os.environ["_AIRFLOW__AS_LIBRARY"] = "true"
if TYPE_CHECKING:
from datetime import datetime
@@ -56,6 +57,10 @@ def pytest_configure(config: pytest.Config) -> None:
# Always skip looking for tests in these folders!
config.addinivalue_line("norecursedirs", "tests/test_dags")
+ import airflow.settings
+
+ airflow.settings.configure_policy_plugin_manager()
+
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_setup(item):
diff --git a/task_sdk/tests/execution_time/test_supervisor.py
b/task_sdk/tests/execution_time/test_supervisor.py
index 04097589770..ea639d9bdff 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -24,6 +24,7 @@ import os
import selectors
import signal
import sys
+import time
from io import BytesIO
from operator import attrgetter
from pathlib import Path
@@ -850,7 +851,9 @@ class TestWatchedSubprocessKill:
client=MagicMock(spec=sdk_client.Client),
target=subprocess_main,
)
+
    # Ensure we get one normal run, to give the proc time to register its
custom sighandler
+ time.sleep(0.1)
proc._service_subprocess(max_wait_time=1)
proc.kill(signal_to_send=signal_to_send, escalation_delay=0.5,
force=True)
diff --git a/tests/dag_processing/test_manager.py
b/tests/dag_processing/test_manager.py
index 19d5c3f32a7..4e605e02da3 100644
--- a/tests/dag_processing/test_manager.py
+++ b/tests/dag_processing/test_manager.py
@@ -165,16 +165,18 @@ class TestDagFileProcessorManager:
processor_timeout=365 * 86_400,
)
- with create_session() as session:
- manager.run()
+ manager.run()
+ with create_session() as session:
import_errors = session.query(ParseImportError).all()
assert len(import_errors) == 1
path_to_parse.unlink()
- # Rerun the parser once the dag file has been removed
- manager.run()
+ # Rerun the parser once the dag file has been removed
+ manager.run()
+
+ with create_session() as session:
import_errors = session.query(ParseImportError).all()
assert len(import_errors) == 0
@@ -658,6 +660,7 @@ class TestDagFileProcessorManager:
shutil.copy(source_location, zip_dag_path)
with configure_testing_dag_bundle(bundle_path):
+ session.commit()
manager = DagFileProcessorManager(max_runs=1)
manager.run()
diff --git a/tests/jobs/test_triggerer_job.py b/tests/jobs/test_triggerer_job.py
index 7d065dd2562..9471789b2a8 100644
--- a/tests/jobs/test_triggerer_job.py
+++ b/tests/jobs/test_triggerer_job.py
@@ -179,11 +179,13 @@ def test_trigger_lifecycle(spy_agency: SpyAgency,
session):
trigger = TimeDeltaTrigger(datetime.timedelta(days=7))
dag_model, run, trigger_orm, task_instance = create_trigger_in_db(session,
trigger)
# Make a TriggererJobRunner and have it retrieve DB tasks
- trigger_runner_supervisor = TriggerRunnerSupervisor.start(job=Job(),
capacity=10)
+ trigger_runner_supervisor =
TriggerRunnerSupervisor.start(job=Job(id=12345), capacity=10)
try:
# Spy on it so we can see what gets send, but also call the original.
send_spy = spy_agency.spy_on(TriggerRunnerSupervisor._send,
owner=TriggerRunnerSupervisor)
+
+ trigger_runner_supervisor._service_subprocess(0.1)
trigger_runner_supervisor.load_triggers()
# Make sure it turned up in TriggerRunner's queue
assert trigger_runner_supervisor.running_triggers == {1}
@@ -431,7 +433,7 @@ def test_trigger_create_race_condition_18392(session,
supervisor_builder, spy_ag
@pytest.mark.execution_timeout(5)
-def test_trigger_runner_exception_stops_triggerer(session):
+def test_trigger_runner_exception_stops_triggerer():
"""
Checks that if an exception occurs when creating triggers, that the
triggerer
process stops