This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4e70337baa Improve detection of the airflow code vs. test code in db isolation (#41127)
4e70337baa is described below

commit 4e70337baa96476a8b6f13e136d6c8a9a4f8f3a1
Author: Jarek Potiuk <[email protected]>
AuthorDate: Tue Jul 30 20:30:25 2024 +0200

    Improve detection of the airflow code vs. test code in db isolation (#41127)
    
    The current code only checked where the session was created
    in tests - to verify that the airflow code is not using the session.
    However, sometimes the session is created in the tests and passed
    to "airflow" code (say, when running a task) - the session can be
    passed directly, and what matters is not where the session is
    created but where it is used. This PR implements checking of
    that - and handles the edge cases better.
    
    The error message is also improved in this case - we show the exact
    place where the session was used in airflow code, on top of the
    whole stacktrace, so you should be able to easily see what
    the problem is.
    
    Related: #41067
---
 airflow/settings.py | 37 ++++++++++++++++++++++++++-----------
 1 file changed, 26 insertions(+), 11 deletions(-)

diff --git a/airflow/settings.py b/airflow/settings.py
index eb4053f50e..e1915ab807 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -293,6 +293,7 @@ class TracebackSession:
 
 AIRFLOW_PATH = os.path.dirname(os.path.dirname(__file__))
 AIRFLOW_TESTS_PATH = os.path.join(AIRFLOW_PATH, "tests")
+AIRFLOW_SETTINGS_PATH = os.path.join(AIRFLOW_PATH, "airflow", "settings.py")
 
 
 class TracebackSessionForTests:
@@ -309,26 +310,31 @@ class TracebackSessionForTests:
     db_session_class = None
 
     def __init__(self):
-        self.traceback = traceback.extract_stack()
         self.current_db_session = TracebackSessionForTests.db_session_class()
+        self.created_traceback = traceback.extract_stack()
 
     def __getattr__(self, item):
-        if self.is_called_from_test_code():
+        test_code, frame_summary = self.is_called_from_test_code()
+        if test_code:
             return getattr(self.current_db_session, item)
         raise RuntimeError(
             "TracebackSessionForTests object was used but internal API is 
enabled. "
-            "Only test code is allowed to use this object. "
+            "Only test code is allowed to use this object.\n"
+            f"Called from:\n    {frame_summary.filename}: 
{frame_summary.lineno}{frame_summary.colno}\n"
+            f"     {frame_summary.line}\n\n"
             "You'll need to ensure you are making only RPC calls with this 
object. "
-            "The stack list below will show where the TracebackSession object 
was created."
-            + "\n".join(traceback.format_list(self.traceback))
+            "The stack list below will show where the TracebackSession object 
was called:\n"
+            + "".join(traceback.format_list(self.traceback))
+            + "\n\nThe stack list below will show where the TracebackSession 
object was created:\n"
+            + "".join(traceback.format_list(self.created_traceback))
         )
 
     def remove(*args, **kwargs):
         pass
 
-    def is_called_from_test_code(self) -> bool:
+    def is_called_from_test_code(self) -> tuple[bool, traceback.FrameSummary | 
None]:
         """
-        Check if the object was created from test code.
+        Check if the traceback session was used from the test code.
 
         This is done by checking if the first "airflow" filename in the 
traceback
         is "airflow/tests" or "regular airflow".
@@ -336,12 +342,21 @@ class TracebackSessionForTests:
         :meta: private
         :return: True if the object was created from test code, False 
otherwise.
         """
-        for tb in self.traceback:
+        self.traceback = traceback.extract_stack()
+        if any(filename.endswith("conftest.py") for filename, _, _, _ in 
self.traceback):
+            return True, None
+        for tb in self.traceback[::-1]:
+            # Skip first two settings.py file (will be always here - because 
we call it from here
+            if tb.filename == AIRFLOW_SETTINGS_PATH:
+                continue
             if tb.filename.startswith(AIRFLOW_PATH):
-                # if this is the also "test" code, we are good, otherwise we 
are in Airflow code
-                return tb.filename.startswith(AIRFLOW_TESTS_PATH)
+                if tb.filename.startswith(AIRFLOW_TESTS_PATH):
+                    return True, None
+                else:
+                    return False, tb
         # if it is from elsewhere.... Why???? We should return False in order 
to crash to find out
-        return False
+        # The traceback line will be always 3rd (two bottom ones are Airflow)
+        return False, self.traceback[-2]
 
 
 def _is_sqlite_db_path_relative(sqla_conn_str: str) -> bool:

Reply via email to