This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 04c2ab5be6 Add session that blows up when using internal API (#38563)
04c2ab5be6 is described below

commit 04c2ab5be669550e4c4d1d004ed1fd1461e58f7e
Author: Daniel Standish <[email protected]>
AuthorDate: Tue Apr 9 07:55:24 2024 -0700

    Add session that blows up when using internal API (#38563)
    
    Here I add a TracebackSession, which is designed to raise an error the first
    time it is used and to report a traceback pointing to the call site where it
    was created. I also ensure that this is the Session class used when the
    internal API config is enabled.
---
 airflow/settings.py         | 39 +++++++++++++++++++++++++++--
 airflow/utils/session.py    |  5 ++++
 tests/conftest.py           |  2 --
 tests/core/test_settings.py | 61 +++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 103 insertions(+), 4 deletions(-)

diff --git a/airflow/settings.py b/airflow/settings.py
index 7ead89f34e..fa859fdf42 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -23,6 +23,7 @@ import json
 import logging
 import os
 import sys
+import traceback
 import warnings
 from typing import TYPE_CHECKING, Any, Callable
 
@@ -207,7 +208,11 @@ def configure_vars():
 
 
 class SkipDBTestsSession:
-    """This fake session is used to skip DB tests when `_AIRFLOW_SKIP_DB_TESTS` is set."""
+    """
+    This fake session is used to skip DB tests when `_AIRFLOW_SKIP_DB_TESTS` is set.
+
+    :meta private:
+    """
 
     def __init__(self):
         raise AirflowInternalRuntimeError(
@@ -222,6 +227,30 @@ class SkipDBTestsSession:
         pass
 
 
+class TracebackSession:
+    """
+    Session that throws error when you try to use it.
+
+    Also stores stack at instantiation call site.
+
+    :meta private:
+    """
+
+    def __init__(self):
+        self.traceback = traceback.extract_stack()
+
+    def __getattr__(self, item):
+        raise RuntimeError(
+            "TracebackSession object was used but internal API is enabled. "
+            "You'll need to ensure you are making only RPC calls with this object. "
+            "The stack list below will show where the TracebackSession object was created."
+            + "\n".join(traceback.format_list(self.traceback))
+        )
+
+    def remove(*args, **kwargs):
+        pass
+
+
 def configure_orm(disable_connection_pool=False, pool_class=None):
     """Configure ORM using SQLAlchemy."""
     from airflow.utils.log.secrets_masker import mask_secret
@@ -242,7 +271,13 @@ def configure_orm(disable_connection_pool=False, pool_class=None):
 
     global Session
     global engine
-    if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
+    from airflow.api_internal.internal_api_call import InternalApiConfig
+
+    if InternalApiConfig.get_use_internal_api():
+        Session = TracebackSession
+        engine = None
+        return
+    elif os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
         # Skip DB initialization in unit tests, if DB tests are skipped
         Session = SkipDBTestsSession
         engine = None
diff --git a/airflow/utils/session.py b/airflow/utils/session.py
index b3b610d199..2268b5fb61 100644
--- a/airflow/utils/session.py
+++ b/airflow/utils/session.py
@@ -24,12 +24,17 @@ from typing import Callable, Generator, TypeVar, cast
 from sqlalchemy.orm import Session as SASession
 
 from airflow import settings
+from airflow.api_internal.internal_api_call import InternalApiConfig
+from airflow.settings import TracebackSession
 from airflow.typing_compat import ParamSpec
 
 
 @contextlib.contextmanager
 def create_session() -> Generator[SASession, None, None]:
     """Contextmanager that will create and teardown a session."""
+    if InternalApiConfig.get_use_internal_api():
+        yield TracebackSession()
+        return
     Session = getattr(settings, "Session", None)
     if Session is None:
         raise RuntimeError("Session must be set before!")
diff --git a/tests/conftest.py b/tests/conftest.py
index 05bc72668d..6d102e7268 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -89,8 +89,6 @@ if skip_db_tests:
     # Make sure sqlalchemy will not be usable for pure unit tests even if initialized
     os.environ["AIRFLOW__CORE__SQL_ALCHEMY_CONN"] = "bad_schema:///"
     os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_CONN"] = "bad_schema:///"
-    # Force database isolation mode for pure unit tests
-    os.environ["AIRFLOW__CORE__DATABASE_ACCESS_ISOLATION"] = "True"
     os.environ["_IN_UNIT_TESTS"] = "true"
     # Set it here to pass the flag to python-xdist spawned processes
     os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
diff --git a/tests/core/test_settings.py b/tests/core/test_settings.py
index 0cf38f5a44..abaf4a907b 100644
--- a/tests/core/test_settings.py
+++ b/tests/core/test_settings.py
@@ -26,7 +26,10 @@ from unittest.mock import MagicMock, call, patch
 
 import pytest
 
+from airflow.api_internal.internal_api_call import InternalApiConfig
 from airflow.exceptions import AirflowClusterPolicyViolation, AirflowConfigException
+from airflow.settings import _ENABLE_AIP_44, TracebackSession, configure_orm
+from airflow.utils.session import create_session
 from tests.test_utils.config import conf_vars
 
 SETTINGS_FILE_POLICY = """
@@ -264,3 +267,61 @@ class TestEngineArgs:
         engine_args = settings.prepare_engine_args()
 
         assert "encoding" not in engine_args
+
+
+@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
+@conf_vars(
+    {
+        ("core", "database_access_isolation"): "true",
+        ("core", "internal_api_url"): "http://localhost:8888",
+    }
+)
+def test_get_traceback_session_if_aip_44_enabled():
+    # ensure we take the database_access_isolation config
+    InternalApiConfig._init_values()
+    assert InternalApiConfig.get_use_internal_api() is True
+
+    # ensure that the Session object is TracebackSession
+    configure_orm()
+
+    from airflow.settings import Session
+
+    assert Session == TracebackSession
+
+    # no error to create
+    with create_session() as session:
+        assert isinstance(session, TracebackSession)
+
+        with pytest.raises(
+            RuntimeError,
+            match="TracebackSession object was used but internal API is enabled.",
+        ):
+            session.hi()
+
+
+@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
+@conf_vars(
+    {
+        ("core", "database_access_isolation"): "true",
+        ("core", "internal_api_url"): "http://localhost:8888",
+    }
+)
+@patch("airflow.utils.session.TracebackSession.__new__")
+def test_create_session_ctx_mgr_no_call_methods(mock_new):
+    m = MagicMock()
+    mock_new.return_value = m
+    # ensure we take the database_access_isolation config
+    InternalApiConfig._init_values()
+    assert InternalApiConfig.get_use_internal_api() is True
+
+    # ensure that the Session object is TracebackSession
+    configure_orm()
+
+    # no error to create
+    with create_session() as session:
+        assert isinstance(session, MagicMock)
+        assert session == m
+    method_calls = [x[0] for x in m.method_calls]
+    assert method_calls == []  # commit and close not called when using internal API
+
+    # assert mock_session_obj.call_args_list == []

Reply via email to