This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 581e2e42e9 Change AirflowTaskTimeout to inherit BaseException (#35653)
581e2e42e9 is described below

commit 581e2e42e947fc8f23ecccb89fbabccec9e8e26b
Author: HTErik <[email protected]>
AuthorDate: Wed Feb 21 18:43:17 2024 +0100

    Change AirflowTaskTimeout to inherit BaseException (#35653)
    
    Code that normally catches Exception should not implicitly ignore
    interrupts from AirflowTaskTimeout.
    
    Fixes #35644 #35474
---
 airflow/exceptions.py                               |  5 ++++-
 airflow/models/taskinstance.py                      | 16 ++++++++--------
 .../celery/executors/celery_executor_utils.py       |  6 +++---
 airflow/utils/context.pyi                           |  2 +-
 newsfragments/35653.significant.rst                 | 21 +++++++++++++++++++++
 tests/core/test_core.py                             |  9 ++++++++-
 .../providers/microsoft/azure/hooks/test_synapse.py |  3 ++-
 7 files changed, 47 insertions(+), 15 deletions(-)

diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index f747640c77..f2fae6e8d4 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -79,7 +79,10 @@ class InvalidStatsNameException(AirflowException):
     """Raise when name of the stats is invalid."""
 
 
-class AirflowTaskTimeout(AirflowException):
+# Important to inherit BaseException instead of AirflowException->Exception, 
since this Exception is used
+# to explicitly interrupt ongoing task. Code that does normal error-handling 
should not treat
+# such interrupt as an error that can be handled normally. (Compare with 
KeyboardInterrupt)
+class AirflowTaskTimeout(BaseException):
     """Raise when the task execution times-out."""
 
 
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 5d026a0667..cf5c97922e 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -812,7 +812,7 @@ def _is_eligible_to_retry(*, task_instance: TaskInstance | 
TaskInstancePydantic)
 def _handle_failure(
     *,
     task_instance: TaskInstance | TaskInstancePydantic,
-    error: None | str | Exception | KeyboardInterrupt,
+    error: None | str | BaseException,
     session: Session,
     test_mode: bool | None = None,
     context: Context | None = None,
@@ -2411,7 +2411,7 @@ class TaskInstance(Base, LoggingMixin):
                 self.handle_failure(e, test_mode, context, force_fail=True, 
session=session)
                 session.commit()
                 raise
-            except AirflowException as e:
+            except (AirflowTaskTimeout, AirflowException) as e:
                 if not test_mode:
                     self.refresh_from_db(lock_for_update=True, session=session)
                 # for case when task is marked as success/failed externally
@@ -2426,10 +2426,6 @@ class TaskInstance(Base, LoggingMixin):
                     self.handle_failure(e, test_mode, context, session=session)
                     session.commit()
                     raise
-            except (Exception, KeyboardInterrupt) as e:
-                self.handle_failure(e, test_mode, context, session=session)
-                session.commit()
-                raise
             except SystemExit as e:
                 # We have already handled SystemExit with success codes (0 and 
None) in the `_execute_task`.
                 # Therefore, here we must handle only error codes.
@@ -2437,6 +2433,10 @@ class TaskInstance(Base, LoggingMixin):
                 self.handle_failure(msg, test_mode, context, session=session)
                 session.commit()
                 raise Exception(msg)
+            except BaseException as e:
+                self.handle_failure(e, test_mode, context, session=session)
+                session.commit()
+                raise
             finally:
                 
Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", 
tags=self.stats_tags)
                 # Same metric with tagging
@@ -2743,7 +2743,7 @@ class TaskInstance(Base, LoggingMixin):
     def fetch_handle_failure_context(
         cls,
         ti: TaskInstance | TaskInstancePydantic,
-        error: None | str | Exception | KeyboardInterrupt,
+        error: None | str | BaseException,
         test_mode: bool | None = None,
         context: Context | None = None,
         force_fail: bool = False,
@@ -2838,7 +2838,7 @@ class TaskInstance(Base, LoggingMixin):
     @provide_session
     def handle_failure(
         self,
-        error: None | str | Exception | KeyboardInterrupt,
+        error: None | str | BaseException,
         test_mode: bool | None = None,
         context: Context | None = None,
         force_fail: bool = False,
diff --git a/airflow/providers/celery/executors/celery_executor_utils.py 
b/airflow/providers/celery/executors/celery_executor_utils.py
index 292bbc0c70..bd1725e6d3 100644
--- a/airflow/providers/celery/executors/celery_executor_utils.py
+++ b/airflow/providers/celery/executors/celery_executor_utils.py
@@ -41,7 +41,7 @@ from sqlalchemy import select
 
 import airflow.settings as settings
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning, AirflowTaskTimeout
 from airflow.executors.base_executor import BaseExecutor
 from airflow.stats import Stats
 from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
@@ -198,7 +198,7 @@ class ExceptionWithTraceback:
     :param exception_traceback: The stacktrace to wrap
     """
 
-    def __init__(self, exception: Exception, exception_traceback: str):
+    def __init__(self, exception: BaseException, exception_traceback: str):
         self.exception = exception
         self.traceback = exception_traceback
 
@@ -211,7 +211,7 @@ def send_task_to_executor(
     try:
         with timeout(seconds=OPERATION_TIMEOUT):
             result = task_to_run.apply_async(args=[command], queue=queue)
-    except Exception as e:
+    except (Exception, AirflowTaskTimeout) as e:
         exception_traceback = f"Celery Task ID: 
{key}\n{traceback.format_exc()}"
         result = ExceptionWithTraceback(e, exception_traceback)
 
diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi
index 256823dd0b..9fecccfb1d 100644
--- a/airflow/utils/context.pyi
+++ b/airflow/utils/context.pyi
@@ -65,7 +65,7 @@ class Context(TypedDict, total=False):
     data_interval_start: DateTime
     ds: str
     ds_nodash: str
-    exception: KeyboardInterrupt | Exception | str | None
+    exception: BaseException | str | None
     execution_date: DateTime
     expanded_ti_count: int | None
     inlets: list
diff --git a/newsfragments/35653.significant.rst 
b/newsfragments/35653.significant.rst
new file mode 100644
index 0000000000..ea93c83343
--- /dev/null
+++ b/newsfragments/35653.significant.rst
@@ -0,0 +1,21 @@
+``AirflowTaskTimeout`` is no longer ``except``ed by default through 
``Exception``
+
+The ``AirflowTaskTimeout`` is now inheriting ``BaseException`` instead of
+``AirflowException``->``Exception``.
+See https://docs.python.org/3/library/exceptions.html#exception-hierarchy
+
+This prevents code catching ``Exception`` from accidentally
+catching ``AirflowTaskTimeout`` and continuing to run.
+``AirflowTaskTimeout`` is an explicit intent to cancel the task, and should 
not
+be caught in attempts to handle the error and return some default value.
+
+Catching ``AirflowTaskTimeout`` is still possible by explicitly ``except``ing
+``AirflowTaskTimeout`` or ``BaseException``.
+This is discouraged, as it may allow the code to continue running even after
+such cancellation requests.
+Code that previously depended on performing strict cleanup in every situation
+after catching ``Exception`` is advised to use ``finally`` blocks or
+context managers, to perform only the cleanup and then automatically
+re-raise the exception.
+See similar considerations about catching ``KeyboardInterrupt`` in
+https://docs.python.org/3/library/exceptions.html#KeyboardInterrupt
diff --git a/tests/core/test_core.py b/tests/core/test_core.py
index 5f37cb2db0..c687a352bd 100644
--- a/tests/core/test_core.py
+++ b/tests/core/test_core.py
@@ -71,11 +71,18 @@ class TestCore:
             op.dry_run()
 
     def test_timeout(self, dag_maker):
+        def sleep_and_catch_other_exceptions():
+            try:
+                sleep(5)
+                # Catching Exception should NOT catch AirflowTaskTimeout
+            except Exception:
+                pass
+
         with dag_maker():
             op = PythonOperator(
                 task_id="test_timeout",
                 execution_timeout=timedelta(seconds=1),
-                python_callable=lambda: sleep(5),
+                python_callable=sleep_and_catch_other_exceptions,
             )
         dag_maker.create_dagrun()
         with pytest.raises(AirflowTaskTimeout):
diff --git a/tests/providers/microsoft/azure/hooks/test_synapse.py 
b/tests/providers/microsoft/azure/hooks/test_synapse.py
index d66268798d..9b116cd054 100644
--- a/tests/providers/microsoft/azure/hooks/test_synapse.py
+++ b/tests/providers/microsoft/azure/hooks/test_synapse.py
@@ -21,6 +21,7 @@ from unittest.mock import MagicMock, patch
 import pytest
 from azure.synapse.spark import SparkClient
 
+from airflow.exceptions import AirflowTaskTimeout
 from airflow.models.connection import Connection
 from airflow.providers.microsoft.azure.hooks.synapse import AzureSynapseHook, 
AzureSynapseSparkBatchRunStatus
 
@@ -172,7 +173,7 @@ def test_wait_for_job_run_status(hook, job_run_status, 
expected_status, expected
         if expected_output != "timeout":
             assert hook.wait_for_job_run_status(**config) == expected_output
         else:
-            with pytest.raises(Exception):
+            with pytest.raises(AirflowTaskTimeout):
                 hook.wait_for_job_run_status(**config)
 
 

Reply via email to