This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 745898a17e5 Fail fast for non-serializable retry_args in deferrable operators and triggers (#64960)
745898a17e5 is described below

commit 745898a17e5faadeb929b5cdee8deba0c0cea819
Author: kosiew <[email protected]>
AuthorDate: Tue Jun 2 19:57:54 2026 +0800

    Fail fast for non-serializable retry_args in deferrable operators and triggers (#64960)
    
    * Add validation for non-serializable retry_args
    
    Implement a shared validation guard to reject
    non-serializable databricks_retry_args before
    deferrable Databricks tasks cross the trigger boundary.
    Enforce this check for deferrable operators and SQL
    sensor in databricks.py. Add regression tests to cover
    failure modes for both in test_databricks.py.
    
    * Refactor validation to utility module and enhance tests
    
    Move validation logic to retry.py for better cohesion. Enforce
    validation in both trigger constructors within databricks.py.
    Add direct trigger regression tests in test_databricks.py and
    update sensor test setup to maintain deferrable branch coverage.
    
    * Add parameterized validation tests for Tenacity shapes
    
    Enhance operators, sensors, and triggers tests to cover two
    unsupported Tenacity shapes. Tests are now parameterized for
    {"wait": wait_incrementing(...)} and {"stop":
    stop_after_attempt(...)} scenarios.
    
    * Refactor retry argument tests and reduce duplication
    
    Extract shared invalid retry-arg test data and pytest.raises
    assertion into _retry_test_utils.py. Remove duplicated
    UNSUPPORTED_RETRY_ARGS definitions from operator, sensor, and
    trigger test files. Simplify setup in operator and sensor
    negative tests with local helpers for the running deferrable
    path. Combine two trigger-construction negative tests into
    one shared parametrized test in test_databricks.py.
    
    * Tighten API and update retry tests
    
    Require owner explicitly in retry.py's private helper.
    Define an UNSUPPORTED_RETRY_ARGS constant in
    _retry_test_utils.py and update operator, sensor, and
    trigger tests to parametrize directly from it in
    test_databricks.py.
    
    * Update retry logic and Databricks tests
    
    Refactor retry.py to catch ValueErrors and clarify
    retry_args/databricks_retry_args messages. Adjust
    validation in databricks.py to use owner=caller. Update
    tests in operators, sensors, and triggers for
    Databricks. Fix test-helper import to follow repo style.
    
    * Refactor retry.py to use stdlib JSON serialization
    
    Replace SDK serde import with stdlib JSON serialization
    in retry.py. Update validation call to use json.dumps()
    instead of serde_serialize() to improve simplicity and
    reduce dependencies.
    
    * Add unit test for validate_deferrable_databricks_retry_args
    
    Implement tests for the retry validation function in the
    Databricks provider. Handle cases for `None` and valid
    JSON-serializable primitive retry configurations, while
    ensuring unsupported Tenacity retry arguments are rejected.
    
    * Add dev/databricks_retry_args_repro.py
    
    * rm dev/databricks_retry_args_repro.py
    
    * trigger ci
    
    * feat(tests): refactor retry test utilities and improve assertions
    
    - Removed unnecessary retry test utility file.
    - Moved retry test constants into a more appropriate location.
    - Inlined retry error assertions in sensor and trigger tests for clarity.
    - Added explicit assertions for success validation tests to enhance reliability.
    - Added comments to trigger constructor for better understanding of serialization-boundary fail-fast validation.
    
    * chore(databricks): update retry utility to use airflow.sdk.serde.serialize
    
    - Modified retry.py to utilize airflow.sdk.serde.serialize
    - Retained wrapping of AttributeError, RecursionError, TypeError, and ValueError
    - Updated error message to indicate "Airflow-serializable"
    - Enhanced test_retry.py with datetime coverage for serde-supported non-JSON values
    - Kept Tenacity rejection tests unchanged
    
    * trigger ci
    
    * fix: replace datetime.UTC with datetime.timezone.utc in test_retry.py
    
    * feat(databricks): add compat import fallback for serialization module in retry.py
    
    * feat(databricks): enhance retry utility with improved serialization and cleanup
    
    - Removed duplicate fallback import names for clarity.
    - Added `get_serde_serialize()` function utilizing `import_module(...)`.
    - Updated validator to call `get_serde_serialize()(retry_args)`.
---
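Illustrative note (not part of the commit): the fail-fast idea described in
the commit message can be sketched with the standard library alone. This is a
minimal sketch under stated assumptions, not the provider's implementation.
The names validate_retry_args_sketch and FakeWaitStrategy are hypothetical,
and json.dumps stands in for Airflow's serde, which the final revision uses
and which also accepts some non-JSON values such as datetime.

```python
from __future__ import annotations

import json
from collections.abc import Mapping
from typing import Any


def validate_retry_args_sketch(retry_args: Mapping[str, Any] | None, *, owner: str) -> None:
    """Hypothetical stand-in: reject retry args that json.dumps cannot encode."""
    if retry_args is None:
        return
    try:
        # Assumption: stdlib JSON approximates the trigger serialization boundary.
        json.dumps(retry_args)
    except (RecursionError, TypeError, ValueError) as err:
        raise ValueError(
            f"{owner} does not support non-serializable retry_args when deferrable=True."
        ) from err


class FakeWaitStrategy:
    """Hypothetical stand-in for a Tenacity strategy object such as wait_incrementing(...)."""


# Plain primitives pass validation silently.
validate_retry_args_sketch(None, owner="demo")
validate_retry_args_sketch({"retry_limit": 3, "retry_delay": 10}, owner="demo")

# An arbitrary object fails before any trigger state would be stored.
try:
    validate_retry_args_sketch({"wait": FakeWaitStrategy()}, owner="demo")
except ValueError as err:
    rejected_message = str(err)
```

Chaining with "from err" keeps the underlying serializer error available in
the traceback while the raised message names the owner and the fix.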
 .../providers/databricks/triggers/databricks.py    |  7 +++
 .../airflow/providers/databricks/utils/retry.py    | 43 ++++++++++++++++
 .../unit/databricks/operators/test_databricks.py   | 28 ++++++++++
 .../unit/databricks/sensors/test_databricks.py     | 29 +++++++++++
 .../unit/databricks/triggers/test_databricks.py    | 39 +++++++++++++-
 .../tests/unit/databricks/utils/test_retry.py      | 60 ++++++++++++++++++++++
 6 files changed, 204 insertions(+), 2 deletions(-)

diff --git a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py
index 25cade7fc80..2bb626b9114 100644
--- a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py
+++ b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py
@@ -23,6 +23,7 @@ from typing import Any
 
 from airflow.providers.databricks.hooks.databricks import DatabricksHook
 from airflow.providers.databricks.utils.databricks import extract_failed_task_errors_async
+from airflow.providers.databricks.utils.retry import validate_deferrable_databricks_retry_args
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
@@ -55,6 +56,9 @@ class DatabricksExecutionTrigger(BaseTrigger):
         caller: str = "DatabricksExecutionTrigger",
     ) -> None:
         super().__init__()
+        # Trigger kwargs cross Airflow's serialization boundary, so fail before storing invalid
+        # trigger state or surfacing a generic serializer error without Databricks-specific guidance.
+        validate_deferrable_databricks_retry_args(retry_args, owner=caller)
         self.run_id = run_id
         self.databricks_conn_id = databricks_conn_id
         self.polling_period_seconds = polling_period_seconds
@@ -151,6 +155,9 @@ class DatabricksSQLStatementExecutionTrigger(BaseTrigger):
         caller: str = "DatabricksSQLStatementExecutionTrigger",
     ) -> None:
         super().__init__()
+        # Trigger kwargs cross Airflow's serialization boundary, so fail before storing invalid
+        # trigger state or surfacing a generic serializer error without Databricks-specific guidance.
+        validate_deferrable_databricks_retry_args(retry_args, owner=caller)
         self.statement_id = statement_id
         self.databricks_conn_id = databricks_conn_id
         self.end_time = end_time
diff --git a/providers/databricks/src/airflow/providers/databricks/utils/retry.py b/providers/databricks/src/airflow/providers/databricks/utils/retry.py
new file mode 100644
index 00000000000..508f4b71628
--- /dev/null
+++ b/providers/databricks/src/airflow/providers/databricks/utils/retry.py
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from collections.abc import Callable, Mapping
+from importlib import import_module
+from typing import Any
+
+
+def get_serde_serialize() -> Callable[[Any], Any]:
+    try:
+        return import_module("airflow.sdk.serde").serialize
+    except ImportError:
+        return import_module("airflow.serialization.serde").serialize
+
+
+def validate_deferrable_databricks_retry_args(retry_args: Mapping[str, Any] | None, *, owner: str) -> None:
+    """Validate retry args that need to cross the trigger serialization boundary."""
+    if retry_args is None:
+        return
+
+    try:
+        get_serde_serialize()(retry_args)
+    except (AttributeError, RecursionError, TypeError, ValueError) as err:
+        raise ValueError(
+            f"{owner} does not support non-serializable retry_args/databricks_retry_args "
+            "when deferrable=True. "
+            "Use Airflow-serializable values, remove callable retry strategies, or disable deferrable mode."
+        ) from err
diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks.py b/providers/databricks/tests/unit/databricks/operators/test_databricks.py
index 986cf51f1cf..a1b7b4f11b3 100644
--- a/providers/databricks/tests/unit/databricks/operators/test_databricks.py
+++ b/providers/databricks/tests/unit/databricks/operators/test_databricks.py
@@ -24,6 +24,7 @@ from unittest import mock
 from unittest.mock import MagicMock, call
 
 import pytest
+from tenacity import stop_after_attempt, wait_incrementing
 
 # Do not run the tests when FAB / Flask is not installed
 pytest.importorskip("flask_session")
@@ -95,6 +96,13 @@ TAGS = {
     "cost-center": "engineering",
     "team": "jobs",
 }
+INVALID_RETRY_ARGS_PATTERN = (
+    "does not support non-serializable retry_args/databricks_retry_args when deferrable=True"
+)
+UNSUPPORTED_RETRY_ARGS = [
+    pytest.param({"wait": wait_incrementing(start=1, increment=1, max=3)}, id="wait_incrementing"),
+    pytest.param({"stop": stop_after_attempt(3)}, id="stop_after_attempt"),
+]
 TASKS = [
     {
         "task_key": "Sessionize",
@@ -666,6 +674,11 @@ class TestDatabricksCreateJobsOperator:
 
 
 class TestDatabricksSubmitRunOperator:
+    @staticmethod
+    def _configure_running_deferrable_hook(db_mock):
+        db_mock.submit_run.return_value = RUN_ID
+        db_mock.get_run = make_run_with_state_mock("RUNNING", "RUNNING")
+
     def test_init_with_notebook_task_named_parameters(self):
         """
         Test the initializer with the named parameters.
@@ -1089,6 +1102,21 @@ class TestDatabricksSubmitRunOperator:
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
         assert op.run_id == RUN_ID
 
+    @pytest.mark.parametrize("retry_args", UNSUPPORTED_RETRY_ARGS)
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_execute_task_deferred_rejects_non_serializable_retry_args(self, db_mock_class, retry_args):
+        op = DatabricksSubmitRunOperator(
+            deferrable=True,
+            task_id=TASK_ID,
+            json={"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK},
+            databricks_retry_args=retry_args,
+        )
+        db_mock = db_mock_class.return_value
+        self._configure_running_deferrable_hook(db_mock)
+
+        with pytest.raises(ValueError, match=INVALID_RETRY_ARGS_PATTERN):
+            op.execute(None)
+
     def test_execute_complete_success(self):
         """
         Test `execute_complete` function in case the Trigger has returned a successful completion event.
diff --git a/providers/databricks/tests/unit/databricks/sensors/test_databricks.py b/providers/databricks/tests/unit/databricks/sensors/test_databricks.py
index 615364f89ed..8f94274e7a4 100644
--- a/providers/databricks/tests/unit/databricks/sensors/test_databricks.py
+++ b/providers/databricks/tests/unit/databricks/sensors/test_databricks.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 from unittest import mock
 
 import pytest
+from tenacity import stop_after_attempt, wait_incrementing
 
 from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred
 from airflow.providers.databricks.hooks.databricks import SQLStatementState
@@ -31,6 +32,13 @@ STATEMENT = "select * from test.test;"
 STATEMENT_ID = "statement_id"
 TASK_ID = "task_id"
 WAREHOUSE_ID = "warehouse_id"
+INVALID_RETRY_ARGS_PATTERN = (
+    "does not support non-serializable retry_args/databricks_retry_args when deferrable=True"
+)
+UNSUPPORTED_RETRY_ARGS = [
+    pytest.param({"wait": wait_incrementing(start=1, increment=1, max=3)}, id="wait_incrementing"),
+    pytest.param({"stop": stop_after_attempt(3)}, id="stop_after_attempt"),
+]
 
 
 class TestDatabricksSQLStatementsSensor:
@@ -39,6 +47,11 @@ class TestDatabricksSQLStatementsSensor:
     from the DatabricksSQLStatementOperator, meaning that much of the testing logic is also reused.
     """
 
+    @staticmethod
+    def _configure_running_deferrable_hook(db_mock):
+        db_mock.post_sql_statement.return_value = STATEMENT_ID
+        db_mock.get_sql_statement_state.return_value = SQLStatementState("RUNNING")
+
     def test_init_statement(self):
         """Test initialization for traditional use-case (statement)."""
         op = DatabricksSQLStatementsSensor(task_id=TASK_ID, statement=STATEMENT, warehouse_id=WAREHOUSE_ID)
@@ -167,6 +180,22 @@ class TestDatabricksSQLStatementsSensor:
         assert isinstance(exc.value.trigger, DatabricksSQLStatementExecutionTrigger)
         assert exc.value.method_name == "execute_complete"
 
+    @pytest.mark.parametrize("retry_args", UNSUPPORTED_RETRY_ARGS)
+    @mock.patch("airflow.providers.databricks.sensors.databricks.DatabricksHook")
+    def test_execute_task_deferred_rejects_non_serializable_retry_args(self, db_mock_class, retry_args):
+        op = DatabricksSQLStatementsSensor(
+            task_id=TASK_ID,
+            statement=STATEMENT,
+            warehouse_id=WAREHOUSE_ID,
+            deferrable=True,
+            databricks_retry_args=retry_args,
+        )
+        db_mock = db_mock_class.return_value
+        self._configure_running_deferrable_hook(db_mock)
+
+        with pytest.raises(ValueError, match=INVALID_RETRY_ARGS_PATTERN):
+            op.execute(None)
+
     def test_execute_complete_success(self):
         """
         Test the execute_complete function in case the Trigger has returned a successful completion event.
diff --git a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py
index 903173774b7..8854eb03fb5 100644
--- a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py
+++ b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py
@@ -17,10 +17,10 @@
 # under the License.
 from __future__ import annotations
 
-import time
 from unittest import mock
 
 import pytest
+from tenacity import stop_after_attempt, wait_incrementing
 
 from airflow.models import Connection
 from airflow.providers.databricks.hooks.databricks import RunState, SQLStatementState
@@ -42,6 +42,7 @@ RETRY_DELAY = 10
 RETRY_LIMIT = 3
 RUN_ID = 1
 STATEMENT_ID = "statement_id"
+STATEMENT_END_TIME = 9999999999.0
 TASK_RUN_ID1 = 11
 TASK_RUN_ID1_KEY = "first_task"
 TASK_RUN_ID2 = 22
@@ -53,6 +54,13 @@ RUN_PAGE_URL = "https://XX.cloud.databricks.com/#jobs/1/runs/1"
 CALLER = "DatabricksSubmitRunOperator"
 ERROR_MESSAGE = "error message from databricks API"
 GET_RUN_OUTPUT_RESPONSE = {"metadata": {}, "error": ERROR_MESSAGE, "notebook_output": {}}
+INVALID_RETRY_ARGS_PATTERN = (
+    "does not support non-serializable retry_args/databricks_retry_args when deferrable=True"
+)
+UNSUPPORTED_RETRY_ARGS = [
+    pytest.param({"wait": wait_incrementing(start=1, increment=1, max=3)}, id="wait_incrementing"),
+    pytest.param({"stop": stop_after_attempt(3)}, id="stop_after_attempt"),
+]
 
 RUN_LIFE_CYCLE_STATES = ["PENDING", "RUNNING", "TERMINATING", "TERMINATED", "SKIPPED", "INTERNAL_ERROR"]
 
@@ -119,6 +127,33 @@ GET_RUN_RESPONSE_TERMINATED_WITH_FAILED = {
     ],
 }
 
+TRIGGER_INIT_CASES = [
+    pytest.param(
+        DatabricksExecutionTrigger,
+        {
+            "run_id": RUN_ID,
+            "databricks_conn_id": DEFAULT_CONN_ID,
+        },
+        id="execution_trigger",
+    ),
+    pytest.param(
+        DatabricksSQLStatementExecutionTrigger,
+        {
+            "statement_id": STATEMENT_ID,
+            "databricks_conn_id": DEFAULT_CONN_ID,
+            "end_time": 1234567890.0,
+        },
+        id="sql_statement_trigger",
+    ),
+]
+
+
+@pytest.mark.parametrize("retry_args", UNSUPPORTED_RETRY_ARGS)
+@pytest.mark.parametrize(("trigger_cls", "trigger_kwargs"), TRIGGER_INIT_CASES)
+def test_trigger_init_rejects_non_serializable_retry_args(trigger_cls, trigger_kwargs, retry_args):
+    with pytest.raises(ValueError, match=INVALID_RETRY_ARGS_PATTERN):
+        trigger_cls(**trigger_kwargs, retry_args=retry_args)
+
 
 class TestDatabricksExecutionTrigger:
     @pytest.fixture(autouse=True)
@@ -281,7 +316,7 @@ class TestDatabricksExecutionTrigger:
 class TestDatabricksSQLStatementExecutionTrigger:
     @pytest.fixture(autouse=True)
     def setup_connections(self, create_connection_without_db):
-        self.end_time = time.time() + 60
+        self.end_time = STATEMENT_END_TIME
         create_connection_without_db(
             Connection(
                 conn_id=DEFAULT_CONN_ID,
diff --git a/providers/databricks/tests/unit/databricks/utils/test_retry.py b/providers/databricks/tests/unit/databricks/utils/test_retry.py
new file mode 100644
index 00000000000..7ff46a71f7d
--- /dev/null
+++ b/providers/databricks/tests/unit/databricks/utils/test_retry.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import datetime
+
+import pytest
+from tenacity import stop_after_attempt, wait_incrementing
+
+from airflow.providers.databricks.utils.retry import validate_deferrable_databricks_retry_args
+
+INVALID_RETRY_ARGS_PATTERN = (
+    "does not support non-serializable retry_args/databricks_retry_args when deferrable=True"
+)
+UNSUPPORTED_RETRY_ARGS = [
+    pytest.param({"wait": wait_incrementing(start=1, increment=1, max=3)}, id="wait_incrementing"),
+    pytest.param({"stop": stop_after_attempt(3)}, id="stop_after_attempt"),
+]
+
+
+def test_validate_deferrable_databricks_retry_args_accepts_none():
+    assert validate_deferrable_databricks_retry_args(None, owner="test-owner") is None
+
+
+@pytest.mark.parametrize(
+    "retry_args",
+    [
+        {},
+        {"retry_limit": 3, "retry_delay": 10},
+        {"retry_limit": 3, "retry_delay": 10.5, "retry_enabled": True, "retry_codes": ["429", "500"]},
+    ],
+)
+def test_validate_deferrable_databricks_retry_args_accepts_serde_serializable_values(retry_args):
+    assert validate_deferrable_databricks_retry_args(retry_args, owner="test-owner") is None
+
+
+def test_validate_deferrable_databricks_retry_args_accepts_airflow_serde_serializable_values():
+    retry_args = {"deadline": datetime.datetime(2026, 5, 29, 12, 30, tzinfo=datetime.timezone.utc)}
+
+    assert validate_deferrable_databricks_retry_args(retry_args, owner="test-owner") is None
+
+
+@pytest.mark.parametrize("retry_args", UNSUPPORTED_RETRY_ARGS)
+def test_validate_deferrable_databricks_retry_args_rejects_non_serializable_values(retry_args):
+    with pytest.raises(ValueError, match=INVALID_RETRY_ARGS_PATTERN):
+        validate_deferrable_databricks_retry_args(retry_args, owner="test-owner")
