This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c4409597235 Disable ORM access from Tasks, DAG processing and Triggers (#47320)
c4409597235 is described below

commit c4409597235c3d9858964718a3f6d93ae23f1d80
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Thu Mar 6 13:22:00 2025 +0000

    Disable ORM access from Tasks, DAG processing and Triggers (#47320)
    
    All of these use the Workload supervisor from the TaskSDK and the main paths
    (XCom, Variables and Secrets) have all been ported to use the Execution API,
    so it's about time we disabled DB access.
---
 airflow/dag_processing/processor.py                |  6 ---
 airflow/settings.py                                | 50 ++++++++-----------
 providers/celery/provider.yaml                     |  7 ---
 .../airflow/providers/celery/cli/celery_command.py | 10 ++--
 .../airflow/providers/celery/get_provider_info.py  |  7 ---
 .../tests/unit/celery/cli/test_celery_command.py   | 34 -------------
 .../src/airflow/sdk/execution_time/supervisor.py   | 58 ++++++++++++++++++++++
 task_sdk/tests/conftest.py                         |  5 ++
 task_sdk/tests/execution_time/test_supervisor.py   |  3 ++
 tests/dag_processing/test_manager.py               | 11 ++--
 tests/jobs/test_triggerer_job.py                   |  6 ++-
 11 files changed, 104 insertions(+), 93 deletions(-)

diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
index bba6b6857d8..c5420d0b21f 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -63,14 +63,8 @@ def _parse_file_entrypoint():
     import structlog
 
     from airflow.sdk.execution_time import task_runner
-    from airflow.settings import configure_orm
 
     # Parse DAG file, send JSON back up!
-
-    # We need to reconfigure the orm here, as DagFileProcessorManager does db queries for bundles, and
-    # the session across forks blows things up.
-    configure_orm()
-
     comms_decoder = task_runner.CommsDecoder[ToDagProcessor, ToManager](
         input=sys.stdin,
         decoder=TypeAdapter[ToDagProcessor](ToDagProcessor),
diff --git a/airflow/settings.py b/airflow/settings.py
index 307ee1e668a..6ae462a1e09 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any, Callable
 
 import pluggy
 from packaging.version import Version
-from sqlalchemy import create_engine, exc, text
+from sqlalchemy import create_engine
 from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession as SAAsyncSession, create_async_engine
 from sqlalchemy.orm import scoped_session, sessionmaker
 from sqlalchemy.pool import NullPool
@@ -46,7 +46,6 @@ from airflow.utils.timezone import local_timezone, parse_timezone, utc
 
 if TYPE_CHECKING:
     from sqlalchemy.engine import Engine
-    from sqlalchemy.orm import Session as SASession
 
 log = logging.getLogger(__name__)
 
@@ -101,12 +100,12 @@ Mapping of sync scheme to async scheme.
 """
 
 engine: Engine
-Session: Callable[..., SASession]
+Session: scoped_session
 # NonScopedSession creates global sessions and is not safe to use in multi-threaded environment without
 # additional precautions. The only use case is when the session lifecycle needs
 # custom handling. Most of the time we only want one unique thread local session object,
 # this is achieved by the Session factory above.
-NonScopedSession: Callable[..., SASession]
+NonScopedSession: sessionmaker
 async_engine: AsyncEngine
 AsyncSession: Callable[..., SAAsyncSession]
 
@@ -389,6 +388,12 @@ def configure_orm(disable_connection_pool=False, pool_class=None):
     NonScopedSession = _session_maker(engine)
     Session = scoped_session(NonScopedSession)
 
+    from sqlalchemy.orm.session import close_all_sessions
+
+    os.register_at_fork(after_in_child=close_all_sessions)
+    # https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
+    os.register_at_fork(after_in_child=lambda: engine.dispose(close=False))
+
 
 DEFAULT_ENGINE_ARGS = {
     "postgresql": {
@@ -479,14 +484,23 @@ def prepare_engine_args(disable_connection_pool=False, pool_class=None):
 
 def dispose_orm():
     """Properly close pooled database connections."""
+    global Session, engine, NonScopedSession
+
+    _globals = globals()
+    if "engine" not in _globals and "Session" not in _globals:
+        return
+
     log.debug("Disposing DB connection pool (PID %s)", os.getpid())
-    global engine
-    global Session
 
-    if Session is not None:  # type: ignore[truthy-function]
+    if "Session" in _globals and Session is not None:
+        from sqlalchemy.orm.session import close_all_sessions
+
         Session.remove()
         Session = None
-    if engine:
+        NonScopedSession = None
+        close_all_sessions()
+
+    if "engine" in _globals:
         engine.dispose()
         engine = None
 
@@ -529,26 +543,6 @@ def configure_adapters():
             pass
 
 
-def validate_session():
-    """Validate ORM Session."""
-    global engine
-
-    worker_precheck = conf.getboolean("celery", "worker_precheck")
-    if not worker_precheck:
-        return True
-    else:
-        check_session = sessionmaker(bind=engine)
-        session = check_session()
-        try:
-            session.execute(text("select 1"))
-            conn_status = True
-        except exc.DBAPIError as err:
-            log.error(err)
-            conn_status = False
-        session.close()
-        return conn_status
-
-
 def configure_action_logging() -> None:
     """Any additional configuration (register callback) for airflow.utils.action_loggers module."""
 
diff --git a/providers/celery/provider.yaml b/providers/celery/provider.yaml
index 0f50ee7c7b4..7d44bf2989c 100644
--- a/providers/celery/provider.yaml
+++ b/providers/celery/provider.yaml
@@ -308,13 +308,6 @@ config:
         type: integer
         example: ~
         default: "3"
-      worker_precheck:
-        description: |
-          Worker initialisation check to validate Metadata Database connection
-        version_added: ~
-        type: string
-        example: ~
-        default: "False"
       extra_celery_config:
         description: |
           Extra celery configs to include in the celery worker.
diff --git a/providers/celery/src/airflow/providers/celery/cli/celery_command.py b/providers/celery/src/airflow/providers/celery/cli/celery_command.py
index 464cd830184..8381886f702 100644
--- a/providers/celery/src/airflow/providers/celery/cli/celery_command.py
+++ b/providers/celery/src/airflow/providers/celery/cli/celery_command.py
@@ -197,11 +197,11 @@ def worker(args):
         from airflow.sdk.log import configure_logging
 
         configure_logging(output=sys.stdout.buffer)
-
-    # Disable connection pool so that celery worker does not hold an unnecessary db connection
-    settings.reconfigure_orm(disable_connection_pool=True)
-    if not settings.validate_session():
-        raise SystemExit("Worker exiting, database connection precheck failed.")
+    else:
+        # Disable connection pool so that celery worker does not hold an unnecessary db connection
+        settings.reconfigure_orm(disable_connection_pool=True)
+        if not settings.validate_session():
+            raise SystemExit("Worker exiting, database connection precheck failed.")
 
     autoscale = args.autoscale
     skip_serve_logs = args.skip_serve_logs
diff --git a/providers/celery/src/airflow/providers/celery/get_provider_info.py b/providers/celery/src/airflow/providers/celery/get_provider_info.py
index 41f8543336a..625cf867f10 100644
--- a/providers/celery/src/airflow/providers/celery/get_provider_info.py
+++ b/providers/celery/src/airflow/providers/celery/get_provider_info.py
@@ -266,13 +266,6 @@ def get_provider_info():
                         "example": None,
                         "default": "3",
                     },
-                    "worker_precheck": {
-                        "description": "Worker initialisation check to validate Metadata Database connection\n",
-                        "version_added": None,
-                        "type": "string",
-                        "example": None,
-                        "default": "False",
-                    },
                     "extra_celery_config": {
                         "description": 'Extra celery configs to include in the celery worker.\nAny of the celery config can be added to this config and it\nwill be applied while starting the celery worker. e.g. {"worker_max_tasks_per_child": 10}\nSee also:\nhttps://docs.celeryq.dev/en/stable/userguide/configuration.html#configuration-and-defaults\n',
                         "version_added": None,
diff --git a/providers/celery/tests/unit/celery/cli/test_celery_command.py b/providers/celery/tests/unit/celery/cli/test_celery_command.py
index 93b8606aa54..06f96876a1f 100644
--- a/providers/celery/tests/unit/celery/cli/test_celery_command.py
+++ b/providers/celery/tests/unit/celery/cli/test_celery_command.py
@@ -19,14 +19,11 @@ from __future__ import annotations
 
 import importlib
 import os
-from argparse import Namespace
 from unittest import mock
 from unittest.mock import patch
 
 import pytest
-import sqlalchemy
 
-import airflow
 from airflow.cli import cli_parser
 from airflow.configuration import conf
 from airflow.executors import executor_loader
@@ -39,37 +36,6 @@ from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_
 pytestmark = pytest.mark.db_test
 
 
-@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
-class TestWorkerPrecheck:
-    @mock.patch("airflow.settings.validate_session")
-    def test_error(self, mock_validate_session):
-        """
-        Test to verify the exit mechanism of airflow-worker cli
-        by mocking validate_session method
-        """
-        mock_validate_session.return_value = False
-        with pytest.raises(SystemExit) as ctx, conf_vars({("core", "executor"): "CeleryExecutor"}):
-            celery_command.worker(Namespace(queues=1, concurrency=1))
-        assert str(ctx.value) == "Worker exiting, database connection precheck failed."
-
-    @conf_vars({("celery", "worker_precheck"): "False"})
-    def test_worker_precheck_exception(self):
-        """
-        Test to check the behaviour of validate_session method
-        when worker_precheck is absent in airflow configuration
-        """
-        assert airflow.settings.validate_session()
-
-    @mock.patch("sqlalchemy.orm.session.Session.execute")
-    @conf_vars({("celery", "worker_precheck"): "True"})
-    def test_validate_session_dbapi_exception(self, mock_session):
-        """
-        Test to validate connection failure scenario on SELECT 1 query
-        """
-        mock_session.side_effect = sqlalchemy.exc.OperationalError("m1", "m2", "m3", "m4")
-        assert airflow.settings.validate_session() is False
-
-
 @pytest.mark.backend("mysql", "postgres")
 @conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
 class TestCeleryStopCommand:
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index cfc9b0d8fef..a93fae2f0b3 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -206,6 +206,62 @@ def _get_last_chance_stderr() -> TextIO:
         return stream
 
 
+class BlockedDBSession:
+    """:meta private:"""  # noqa: D400
+
+    def __init__(self):
+        raise RuntimeError("Direct database access via the ORM is not allowed in Airflow 3.0")
+
+    def remove(*args, **kwargs):
+        pass
+
+    def get_bind(
+        self,
+        mapper=None,
+        clause=None,
+        bind=None,
+        _sa_skip_events=None,
+        _sa_skip_for_implicit_returning=False,
+    ):
+        pass
+
+
+def block_orm_access():
+    """
+    Disable direct DB access as best as possible from task code.
+
+    While we still don't have 100% code separation between TaskSDK and "core" Airflow, it is still possible to
+    import the models and use them. This does what it can to disable that if it is not blocked at the network
+    level
+    """
+    # A fake URL schema that might give users some clue what's going on. Hopefully
+    conn = "airflow-db-not-allowed:///"
+    if "airflow.settings" in sys.modules:
+        from airflow import settings
+        from airflow.configuration import conf
+
+        settings.dispose_orm()
+
+        for attr in ("engine", "async_engine", "Session", "AsyncSession", "NonScopedSession"):
+            if hasattr(settings, attr):
+                delattr(settings, attr)
+
+        def configure_orm(*args, **kwargs):
+            raise RuntimeError("Database access is disabled from DAGs and Triggers")
+
+        settings.configure_orm = configure_orm
+        settings.Session = BlockedDBSession
+        if conf.has_section("database"):
+            conf.set("database", "sql_alchemy_conn", conn)
+            conf.set("database", "sql_alchemy_conn_cmd", "/bin/false")
+            conf.set("database", "sql_alchemy_conn_secret", "db-access-blocked")
+
+        settings.SQL_ALCHEMY_CONN = conn
+        settings.SQL_ALCHEMY_CONN_ASYNC = conn
+
+    os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_CONN"] = conn
+
+
 def _fork_main(
     child_stdin: socket,
     child_stdout: socket,
@@ -261,6 +317,8 @@ def _fork_main(
                 base_exit(n)
 
     try:
+        block_orm_access()
+
         target()
         exit(0)
     except SystemExit as e:
diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py
index 3196051da7b..b075d2e73ad 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/tests/conftest.py
@@ -28,6 +28,7 @@ pytest_plugins = "tests_common.pytest_plugin"
 
 # Task SDK does not need access to the Airflow database
 os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
+os.environ["_AIRFLOW__AS_LIBRARY"] = "true"
 
 if TYPE_CHECKING:
     from datetime import datetime
@@ -56,6 +57,10 @@ def pytest_configure(config: pytest.Config) -> None:
     # Always skip looking for tests in these folders!
     config.addinivalue_line("norecursedirs", "tests/test_dags")
 
+    import airflow.settings
+
+    airflow.settings.configure_policy_plugin_manager()
+
 
 @pytest.hookimpl(tryfirst=True)
 def pytest_runtest_setup(item):
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index 04097589770..ea639d9bdff 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -24,6 +24,7 @@ import os
 import selectors
 import signal
 import sys
+import time
 from io import BytesIO
 from operator import attrgetter
 from pathlib import Path
@@ -850,7 +851,9 @@ class TestWatchedSubprocessKill:
             client=MagicMock(spec=sdk_client.Client),
             target=subprocess_main,
         )
+
         # Ensure we get one normal run, to give the proc time to register it's custom sighandler
+        time.sleep(0.1)
         proc._service_subprocess(max_wait_time=1)
         proc.kill(signal_to_send=signal_to_send, escalation_delay=0.5, force=True)
 
diff --git a/tests/dag_processing/test_manager.py b/tests/dag_processing/test_manager.py
index 19d5c3f32a7..4e605e02da3 100644
--- a/tests/dag_processing/test_manager.py
+++ b/tests/dag_processing/test_manager.py
@@ -165,16 +165,18 @@ class TestDagFileProcessorManager:
                 processor_timeout=365 * 86_400,
             )
 
-            with create_session() as session:
-                manager.run()
+            manager.run()
 
+            with create_session() as session:
                 import_errors = session.query(ParseImportError).all()
                 assert len(import_errors) == 1
 
                 path_to_parse.unlink()
 
-                # Rerun the parser once the dag file has been removed
-                manager.run()
+            # Rerun the parser once the dag file has been removed
+            manager.run()
+
+            with create_session() as session:
                 import_errors = session.query(ParseImportError).all()
 
                 assert len(import_errors) == 0
@@ -658,6 +660,7 @@ class TestDagFileProcessorManager:
         shutil.copy(source_location, zip_dag_path)
 
         with configure_testing_dag_bundle(bundle_path):
+            session.commit()
             manager = DagFileProcessorManager(max_runs=1)
             manager.run()
 
diff --git a/tests/jobs/test_triggerer_job.py b/tests/jobs/test_triggerer_job.py
index 7d065dd2562..9471789b2a8 100644
--- a/tests/jobs/test_triggerer_job.py
+++ b/tests/jobs/test_triggerer_job.py
@@ -179,11 +179,13 @@ def test_trigger_lifecycle(spy_agency: SpyAgency, session):
     trigger = TimeDeltaTrigger(datetime.timedelta(days=7))
     dag_model, run, trigger_orm, task_instance = create_trigger_in_db(session, trigger)
     # Make a TriggererJobRunner and have it retrieve DB tasks
-    trigger_runner_supervisor = TriggerRunnerSupervisor.start(job=Job(), capacity=10)
+    trigger_runner_supervisor = TriggerRunnerSupervisor.start(job=Job(id=12345), capacity=10)
 
     try:
         # Spy on it so we can see what gets send, but also call the original.
         send_spy = spy_agency.spy_on(TriggerRunnerSupervisor._send, owner=TriggerRunnerSupervisor)
+
+        trigger_runner_supervisor._service_subprocess(0.1)
         trigger_runner_supervisor.load_triggers()
         # Make sure it turned up in TriggerRunner's queue
         assert trigger_runner_supervisor.running_triggers == {1}
@@ -431,7 +433,7 @@ def test_trigger_create_race_condition_18392(session, supervisor_builder, spy_ag
 
 
 @pytest.mark.execution_timeout(5)
-def test_trigger_runner_exception_stops_triggerer(session):
+def test_trigger_runner_exception_stops_triggerer():
     """
     Checks that if an exception occurs when creating triggers, that the triggerer
     process stops

Reply via email to