This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 04c2ab5be6 Add session that blows up when using internal API (#38563)
04c2ab5be6 is described below
commit 04c2ab5be669550e4c4d1d004ed1fd1461e58f7e
Author: Daniel Standish <[email protected]>
AuthorDate: Tue Apr 9 07:55:24 2024 -0700
Add session that blows up when using internal API (#38563)
This adds a TracebackSession, which raises an error whenever it is used; the
error includes the stack trace captured where the session object was created,
so the offending call site is easy to find. It also makes TracebackSession the
Session class used when the internal API config is enabled.
---
airflow/settings.py | 39 +++++++++++++++++++++++++++--
airflow/utils/session.py | 5 ++++
tests/conftest.py | 2 --
tests/core/test_settings.py | 61 +++++++++++++++++++++++++++++++++++++++++++++
4 files changed, 103 insertions(+), 4 deletions(-)
diff --git a/airflow/settings.py b/airflow/settings.py
index 7ead89f34e..fa859fdf42 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -23,6 +23,7 @@ import json
 import logging
 import os
 import sys
+import traceback
 import warnings
 from typing import TYPE_CHECKING, Any, Callable
 
@@ -207,7 +208,11 @@ def configure_vars():
 
 
 class SkipDBTestsSession:
-    """This fake session is used to skip DB tests when `_AIRFLOW_SKIP_DB_TESTS` is set."""
+    """
+    This fake session is used to skip DB tests when `_AIRFLOW_SKIP_DB_TESTS` is set.
+
+    :meta private:
+    """
 
     def __init__(self):
         raise AirflowInternalRuntimeError(
@@ -222,6 +227,30 @@ class SkipDBTestsSession:
         pass
 
 
+class TracebackSession:
+    """
+    Session that throws error when you try to use it.
+
+    Also stores stack at instantiation call site.
+
+    :meta private:
+    """
+
+    def __init__(self):
+        self.traceback = traceback.extract_stack()
+
+    def __getattr__(self, item):
+        raise RuntimeError(
+            "TracebackSession object was used but internal API is enabled. "
+            "You'll need to ensure you are making only RPC calls with this object. "
+            "The stack list below will show where the TracebackSession object was created."
+            + "\n".join(traceback.format_list(self.traceback))
+        )
+
+    def remove(*args, **kwargs):
+        pass
+
+
 def configure_orm(disable_connection_pool=False, pool_class=None):
     """Configure ORM using SQLAlchemy."""
     from airflow.utils.log.secrets_masker import mask_secret
@@ -242,7 +271,13 @@ def configure_orm(disable_connection_pool=False, pool_class=None):
 
     global Session
     global engine
-    if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
+    from airflow.api_internal.internal_api_call import InternalApiConfig
+
+    if InternalApiConfig.get_use_internal_api():
+        Session = TracebackSession
+        engine = None
+        return
+    elif os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
         # Skip DB initialization in unit tests, if DB tests are skipped
         Session = SkipDBTestsSession
         engine = None
diff --git a/airflow/utils/session.py b/airflow/utils/session.py
index b3b610d199..2268b5fb61 100644
--- a/airflow/utils/session.py
+++ b/airflow/utils/session.py
@@ -24,12 +24,17 @@ from typing import Callable, Generator, TypeVar, cast
 from sqlalchemy.orm import Session as SASession
 
 from airflow import settings
+from airflow.api_internal.internal_api_call import InternalApiConfig
+from airflow.settings import TracebackSession
 from airflow.typing_compat import ParamSpec
 
 
 @contextlib.contextmanager
 def create_session() -> Generator[SASession, None, None]:
     """Contextmanager that will create and teardown a session."""
+    if InternalApiConfig.get_use_internal_api():
+        yield TracebackSession()
+        return
     Session = getattr(settings, "Session", None)
     if Session is None:
         raise RuntimeError("Session must be set before!")
diff --git a/tests/conftest.py b/tests/conftest.py
index 05bc72668d..6d102e7268 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -89,8 +89,6 @@ if skip_db_tests:
     # Make sure sqlalchemy will not be usable for pure unit tests even if initialized
     os.environ["AIRFLOW__CORE__SQL_ALCHEMY_CONN"] = "bad_schema:///"
     os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_CONN"] = "bad_schema:///"
-    # Force database isolation mode for pure unit tests
-    os.environ["AIRFLOW__CORE__DATABASE_ACCESS_ISOLATION"] = "True"
     os.environ["_IN_UNIT_TESTS"] = "true"
     # Set it here to pass the flag to python-xdist spawned processes
     os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
diff --git a/tests/core/test_settings.py b/tests/core/test_settings.py
index 0cf38f5a44..abaf4a907b 100644
--- a/tests/core/test_settings.py
+++ b/tests/core/test_settings.py
@@ -26,7 +26,10 @@ from unittest.mock import MagicMock, call, patch
 
 import pytest
 
+from airflow.api_internal.internal_api_call import InternalApiConfig
 from airflow.exceptions import AirflowClusterPolicyViolation, AirflowConfigException
+from airflow.settings import _ENABLE_AIP_44, TracebackSession, configure_orm
+from airflow.utils.session import create_session
 from tests.test_utils.config import conf_vars
 
 SETTINGS_FILE_POLICY = """
@@ -264,3 +267,61 @@ class TestEngineArgs:
 
         engine_args = settings.prepare_engine_args()
         assert "encoding" not in engine_args
+
+
+@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
+@conf_vars(
+    {
+        ("core", "database_access_isolation"): "true",
+        ("core", "internal_api_url"): "http://localhost:8888",
+    }
+)
+def test_get_traceback_session_if_aip_44_enabled():
+    # ensure we take the database_access_isolation config
+    InternalApiConfig._init_values()
+    assert InternalApiConfig.get_use_internal_api() is True
+
+    # ensure that the Session object is TracebackSession
+    configure_orm()
+
+    from airflow.settings import Session
+
+    assert Session == TracebackSession
+
+    # no error to create
+    with create_session() as session:
+        assert isinstance(session, TracebackSession)
+
+        with pytest.raises(
+            RuntimeError,
+            match="TracebackSession object was used but internal API is enabled.",
+        ):
+            session.hi()
+
+
+@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
+@conf_vars(
+    {
+        ("core", "database_access_isolation"): "true",
+        ("core", "internal_api_url"): "http://localhost:8888",
+    }
+)
+@patch("airflow.utils.session.TracebackSession.__new__")
+def test_create_session_ctx_mgr_no_call_methods(mock_new):
+    m = MagicMock()
+    mock_new.return_value = m
+    # ensure we take the database_access_isolation config
+    InternalApiConfig._init_values()
+    assert InternalApiConfig.get_use_internal_api() is True
+
+    # ensure that the Session object is TracebackSession
+    configure_orm()
+
+    # no error to create
+    with create_session() as session:
+        assert isinstance(session, MagicMock)
+        assert session == m
+    method_calls = [x[0] for x in m.method_calls]
+    assert method_calls == []  # commit and close not called when using internal API
+
+    # assert mock_session_obj.call_args_list == []