This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 853571c1970 fix: allow deadline callbacks within the same dag module
(#66702)
853571c1970 is described below
commit 853571c19701a3a2727530d11009160fcdc5849d
Author: Sebastian Daum <[email protected]>
AuthorDate: Sun Jun 7 21:48:41 2026 +0200
fix: allow deadline callbacks within the same dag module (#66702)
---
.../src/airflow/executors/base_executor.py | 1 +
airflow-core/src/airflow/utils/file.py | 3 +-
.../tests/unit/executors/test_base_executor.py | 114 ++++++++++++++-
.../tests/unit/executors/test_local_executor.py | 6 +-
.../src/airflow_shared/module_loading/__init__.py | 4 +
.../src/airflow_shared/module_loading/dag_file.py | 23 +++
.../tests/module_loading/test_dag_file.py | 26 ++++
.../sdk/execution_time/callback_supervisor.py | 40 +++++-
.../execution_time/test_callback_supervisor.py | 157 ++++++++++++++++++++-
9 files changed, 355 insertions(+), 19 deletions(-)
diff --git a/airflow-core/src/airflow/executors/base_executor.py
b/airflow-core/src/airflow/executors/base_executor.py
index f708638999e..31c307b163f 100644
--- a/airflow-core/src/airflow/executors/base_executor.py
+++ b/airflow-core/src/airflow/executors/base_executor.py
@@ -717,6 +717,7 @@ class BaseExecutor(LoggingMixin):
id=workload.callback.id,
callback_path=workload.callback.data.get("path", ""),
callback_kwargs=workload.callback.data.get("kwargs", {}),
+ dag_rel_path=workload.dag_rel_path,
log_path=workload.log_path,
bundle_info=workload.bundle_info,
token=workload.token,
diff --git a/airflow-core/src/airflow/utils/file.py
b/airflow-core/src/airflow/utils/file.py
index c614cfff0ad..feeaa5239c3 100644
--- a/airflow-core/src/airflow/utils/file.py
+++ b/airflow-core/src/airflow/utils/file.py
@@ -28,12 +28,11 @@ from io import TextIOWrapper
from pathlib import Path
from typing import overload
+from airflow._shared.module_loading import MODIFIED_DAG_MODULE_NAME
from airflow.configuration import conf
log = logging.getLogger(__name__)
-MODIFIED_DAG_MODULE_NAME = "unusual_prefix_{path_hash}_{module_name}"
-
ZIP_REGEX = re.compile(rf"((.*\.zip){re.escape(os.sep)})?(.*)")
diff --git a/airflow-core/tests/unit/executors/test_base_executor.py
b/airflow-core/tests/unit/executors/test_base_executor.py
index 899eb8f584c..9c45c5056ca 100644
--- a/airflow-core/tests/unit/executors/test_base_executor.py
+++ b/airflow-core/tests/unit/executors/test_base_executor.py
@@ -19,6 +19,8 @@ from __future__ import annotations
import logging
from datetime import timedelta
+from pathlib import Path
+from textwrap import dedent
from unittest import mock
from uuid import UUID, uuid4
@@ -725,20 +727,118 @@ class TestCallbackSupport:
class TestExecuteCallbackWorkload:
@pytest.mark.parametrize(
- ("path", "kwargs", "expect_success", "error_contains"),
+ ("path", "kwargs", "dag_rel_path", "bundle_path", "expect_success",
"error_contains"),
[
- pytest.param("builtins.dict", {"a": 1, "b": 2, "c": 3}, True,
None, id="function_success"),
- pytest.param("", {}, False, "Callback path not found",
id="missing_path"),
- pytest.param("nonexistent.module.function", {}, False,
"ModuleNotFoundError", id="import_error"),
- pytest.param("builtins.len", {}, False, "TypeError",
id="execution_error"),
+ pytest.param(
+ "builtins.dict",
+ {"a": 1, "b": 2, "c": 3},
+ Path("test.py"),
+ Path("bundle/path"),
+ True,
+ None,
+ id="function_success",
+ ),
+ pytest.param(
+ "",
+ {},
+ Path("test.py"),
+ Path("bundle/path"),
+ False,
+ "Callback path not found",
+ id="missing_path",
+ ),
+ pytest.param(
+ "nonexistent.module.function",
+ {},
+ Path("test.py"),
+ Path("bundle/path"),
+ False,
+ "ModuleNotFoundError",
+ id="import_error",
+ ),
+ pytest.param(
+ "builtins.len",
+ {},
+ Path("test.py"),
+ Path("bundle/path"),
+ False,
+ "TypeError",
+ id="execution_error",
+ ),
+ pytest.param(
+
"unusual_prefix_fad099f9df8ac798a50aac7381aab95ad4008e79_test_dag.success_message",
+ {},
+ Path("test.py"),
+ Path("bundle/path"),
+ False,
+ "FileNotFoundError",
+ id="dag_import_error",
+ ),
],
)
- def test_execute_callback(self, path, kwargs, expect_success,
error_contains):
+ def test_execute_callback(self, path, kwargs, dag_rel_path, bundle_path,
expect_success, error_contains):
log = structlog.get_logger()
- success, error = execute_callback(path, kwargs, log)
+ success, error = execute_callback(
+ callback_path=path,
+ callback_kwargs=kwargs,
+ dag_rel_path=dag_rel_path,
+ bundle_path=bundle_path,
+ log=log,
+ )
assert success is expect_success
if error_contains:
assert error_contains in error
else:
assert error is None
+
+ def test_execute_callback_unusual_prefix_success(self, tmp_path):
+ """Test successful execution of callback with same Dag module path."""
+ dag_file = tmp_path / "test_dag.py"
+ dag_content = dedent('''
+ def test_callback(**kwargs):
+ """Test callback function."""
+ return "success"
+ ''')
+ dag_file.write_text(dag_content)
+
+ callback_path = "unusual_prefix_abc123_test_dag.test_callback"
+ callback_kwargs = {"param1": "value1", "context": {"dag_id": "test"}}
+ dag_rel_path = Path("test_dag.py")
+ bundle_path = tmp_path
+ log = structlog.get_logger()
+
+ success, error = execute_callback(
+ callback_path=callback_path,
+ callback_kwargs=callback_kwargs,
+ dag_rel_path=dag_rel_path,
+ bundle_path=bundle_path,
+ log=log,
+ )
+
+ assert success is True
+ assert error is None
+
+ @pytest.mark.parametrize(
+ ("dag_rel_path", "bundle_path", "expected_error"),
+ [
+ pytest.param(None, Path("bundle/path"), "Dag relative path not
found", id="missing_dag_path"),
+ pytest.param(Path("test.py"), None, "Bundle path not found",
id="missing_bundle_path"),
+ ],
+ )
+ def test_execute_callback_unusual_prefix_missing_paths(self, dag_rel_path,
bundle_path, expected_error):
+ """Test same Dag module callback with missing required paths."""
+ callback_path = "unusual_prefix_abc123_test_dag.test_callback"
+ callback_kwargs = {"param1": "value1"}
+ log = structlog.get_logger()
+
+ success, error = execute_callback(
+ callback_path=callback_path,
+ callback_kwargs=callback_kwargs,
+ dag_rel_path=dag_rel_path,
+ bundle_path=bundle_path,
+ log=log,
+ )
+
+ assert success is False
+ assert expected_error in error
diff --git a/airflow-core/tests/unit/executors/test_local_executor.py
b/airflow-core/tests/unit/executors/test_local_executor.py
index 70aaf2b5c56..c6c0aac3b50 100644
--- a/airflow-core/tests/unit/executors/test_local_executor.py
+++ b/airflow-core/tests/unit/executors/test_local_executor.py
@@ -20,6 +20,7 @@ from __future__ import annotations
import gc
import multiprocessing
import os
+from pathlib import Path
from unittest import mock
import pytest
@@ -451,7 +452,7 @@ class TestLocalExecutorCallbackSupport:
)
callback_workload = workloads.ExecuteCallback(
callback=callback_data,
- dag_rel_path="test.py",
+ dag_rel_path=Path("test.py"),
bundle_info=BundleInfo(name="test_bundle", version="1.0"),
token="test_token",
log_path="test.log",
@@ -463,6 +464,7 @@ class TestLocalExecutorCallbackSupport:
id=self.CALLBACK_UUID,
callback_path="test.module.my_callback",
callback_kwargs={"arg1": "val1"},
+ dag_rel_path=Path("test.py"),
log_path="test.log",
bundle_info=BundleInfo(name="test_bundle", version="1.0"),
token=TestLocalExecutorCallbackSupport.TEST_TOKEN,
@@ -481,7 +483,7 @@ class TestLocalExecutorCallbackSupport:
)
callback_workload = workloads.ExecuteCallback(
callback=callback_data,
- dag_rel_path="test.py",
+ dag_rel_path=Path("test.py"),
bundle_info=BundleInfo(name="test_bundle", version="1.0"),
token="test_token",
log_path="test.log",
diff --git
a/shared/module_loading/src/airflow_shared/module_loading/__init__.py
b/shared/module_loading/src/airflow_shared/module_loading/__init__.py
index 1dea2c431e5..238506ae52d 100644
--- a/shared/module_loading/src/airflow_shared/module_loading/__init__.py
+++ b/shared/module_loading/src/airflow_shared/module_loading/__init__.py
@@ -27,6 +27,10 @@ from collections.abc import Callable, Iterator
from importlib import import_module
from typing import TYPE_CHECKING
+from .dag_file import (
+ MODIFIED_DAG_MODULE_NAME as MODIFIED_DAG_MODULE_NAME,
+ UNUSUAL_MODULE_PREFIX as UNUSUAL_MODULE_PREFIX,
+)
from .file_discovery import (
find_path_from_directory as find_path_from_directory,
)
diff --git
a/shared/module_loading/src/airflow_shared/module_loading/dag_file.py
b/shared/module_loading/src/airflow_shared/module_loading/dag_file.py
new file mode 100644
index 00000000000..d4fdc80c737
--- /dev/null
+++ b/shared/module_loading/src/airflow_shared/module_loading/dag_file.py
@@ -0,0 +1,23 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Naming constants for Dag-file modules (template: ``unusual_prefix_{path_hash}_{module_name}``)."""
+
+from __future__ import annotations
+
+UNUSUAL_MODULE_PREFIX = "unusual_prefix_"
+MODIFIED_DAG_MODULE_NAME =
f"{UNUSUAL_MODULE_PREFIX}{{path_hash}}_{{module_name}}"
diff --git a/shared/module_loading/tests/module_loading/test_dag_file.py
b/shared/module_loading/tests/module_loading/test_dag_file.py
new file mode 100644
index 00000000000..c7319384a93
--- /dev/null
+++ b/shared/module_loading/tests/module_loading/test_dag_file.py
@@ -0,0 +1,26 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow_shared.module_loading import MODIFIED_DAG_MODULE_NAME,
UNUSUAL_MODULE_PREFIX
+
+
+def test_constants() -> None:
+ """Test that the constants are as expected."""
+ assert UNUSUAL_MODULE_PREFIX == "unusual_prefix_"
+ assert MODIFIED_DAG_MODULE_NAME ==
"unusual_prefix_{path_hash}_{module_name}"
diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py
index 7679c2328f9..090218ae10e 100644
--- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py
@@ -18,10 +18,13 @@
from __future__ import annotations
+import os
import signal
import sys
import time
from importlib import import_module
+from importlib.util import module_from_spec, spec_from_file_location
+from pathlib import Path
from typing import TYPE_CHECKING, Annotated, BinaryIO, ClassVar, Protocol
from uuid import UUID
@@ -29,7 +32,7 @@ import attrs
import structlog
from pydantic import Field, TypeAdapter
-from airflow.sdk._shared.module_loading import accepts_context, accepts_keyword_args
+from airflow.sdk._shared.module_loading import UNUSUAL_MODULE_PREFIX, accepts_context, accepts_keyword_args
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import (
ErrorResponse,
@@ -83,6 +86,8 @@ CallbackToSupervisor = Annotated[
def execute_callback(
callback_path: str,
callback_kwargs: dict,
+ dag_rel_path: os.PathLike[str] | None,
+ bundle_path: os.PathLike[str] | None,
log,
) -> tuple[bool, str | None]:
"""
@@ -101,6 +106,8 @@ def execute_callback(
:param callback_path: Dot-separated import path to the callback function or class.
:param callback_kwargs: Keyword arguments to pass to the callback.
+ :param dag_rel_path: Path of the Dag file, relative to the bundle root. Only used for callbacks defined in the Dag file itself.
+ :param bundle_path: Root directory of the Dag bundle; ``dag_rel_path`` is resolved against it.
:param log: Logger instance for recording execution.
:return: Tuple of (success: bool, error_message: str | None)
"""
@@ -111,7 +118,31 @@ def execute_callback(
# Import the callback callable
# Expected format: "module.path.to.function_or_class"
module_path, function_name = callback_path.rsplit(".", 1)
- module = import_module(module_path)
+ # If the callback is defined within the Dag module, the module path is modified during DAG serialization.
+ # Attempt to import it using the path of the Dag file.
+ if module_path.startswith(UNUSUAL_MODULE_PREFIX):
+ if not dag_rel_path:
+ return False, "Dag relative path not found."
+ if not bundle_path:
+ return False, "Bundle path not found."
+ abs_path = Path(bundle_path) / Path(dag_rel_path)
+ spec = spec_from_file_location(module_path, abs_path)
+ if spec is None:
+ return False, f"Could not create module spec for {module_path}"
+ if spec.loader is None:
+ return False, f"Module spec has no loader for {module_path}"
+ module = module_from_spec(spec)
+ # Register the module before executing it (importlib's recipe for loading a source
+ # file directly), so code that looks its own module up in sys.modules works.
+ sys.modules[module_path] = module
+ try:
+ spec.loader.exec_module(module)
+ except BaseException:
+ # Do not leave a half-initialised module behind for later lookups.
+ sys.modules.pop(module_path, None)
+ raise
+ else:
+ module = import_module(module_path)
callback_callable = getattr(module, function_name)
log.debug("Executing callback", callback_path=callback_path, callback_kwargs=callback_kwargs)
@@ -175,6 +206,7 @@ class CallbackSubprocess(WatchedSubprocess):
id: str,
callback_path: str,
callback_kwargs: dict,
+ dag_rel_path: os.PathLike[str] | None = None,
bundle_info: _BundleInfoLike | None = None,
client: Client,
logger: FilteringBoundLogger | None = None,
@@ -191,6 +223,7 @@ class CallbackSubprocess(WatchedSubprocess):
_log = structlog.get_logger(logger_name="callback_runner")
task_runner.SUPERVISOR_COMMS = CommsDecoder[ToTask, CallbackToSupervisor](log=_log)
+ bundle_path = None
# If bundle info is provided, initialize the bundle and ensure its path is importable.
# This is needed for user-defined callbacks that live inside a DAG bundle rather than
@@ -216,7 +249,13 @@ class CallbackSubprocess(WatchedSubprocess):
exc_info=True,
)
- success, error_msg = execute_callback(callback_path, callback_kwargs, _log)
+ success, error_msg = execute_callback(
+ callback_path=callback_path,
+ callback_kwargs=callback_kwargs,
+ dag_rel_path=dag_rel_path,
+ bundle_path=bundle_path,
+ log=_log,
+ )
if not success:
_log.error("Callback failed", error=error_msg)
sys.exit(1)
@@ -349,6 +388,7 @@ def supervise_callback(
id: str,
callback_path: str,
callback_kwargs: dict,
+ dag_rel_path: os.PathLike[str] | None = None,
log_path: str | None = None,
bundle_info: _BundleInfoLike | None = None,
token: str = "",
@@ -361,6 +401,7 @@ def supervise_callback(
:param id: Unique identifier for this callback execution.
:param callback_path: Dot-separated import path to the callback function or class.
:param callback_kwargs: Keyword arguments to pass to the callback.
+ :param dag_rel_path: Path of the Dag file, relative to the bundle root. Only needed for callbacks defined in the Dag file itself.
:param log_path: Path to write logs, if required.
:param bundle_info: When provided, the bundle's path is added to sys.path so callbacks in Dag Bundles are importable.
:param token: Authentication token for the API client.
@@ -387,6 +428,7 @@ def supervise_callback(
id=id,
callback_path=callback_path,
callback_kwargs=callback_kwargs,
+ dag_rel_path=dag_rel_path,
bundle_info=bundle_info,
client=client,
logger=logger,
diff --git a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py
b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py
index b0f82cd37da..92249d2346b 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py
@@ -21,16 +21,18 @@ from __future__ import annotations
import signal
import socket
+import uuid
from dataclasses import dataclass
from operator import attrgetter
from typing import Any
-from unittest.mock import patch
+from unittest.mock import ANY, Mock, patch
import pytest
import structlog
-from airflow.sdk.execution_time.callback_supervisor import CallbackSubprocess,
execute_callback
+from airflow.sdk.execution_time.callback_supervisor import CallbackSubprocess,
Path, execute_callback
from airflow.sdk.execution_time.comms import (
+ BundleInfo,
ConnectionResult,
GetConnection,
GetVariable,
@@ -69,11 +71,13 @@ class CallableClass:
class TestExecuteCallback:
@pytest.mark.parametrize(
- ("path", "kwargs", "expect_success", "error_contains"),
+ ("path", "kwargs", "dag_rel_path", "bundle_path", "expect_success",
"error_contains"),
[
pytest.param(
f"{__name__}.callback_no_args",
{},
+ Path("test.py"),
+ Path("bundle/path"),
True,
None,
id="successful_no_args",
@@ -81,6 +85,8 @@ class TestExecuteCallback:
pytest.param(
f"{__name__}.callback_with_kwargs",
{"arg1": "hello", "arg2": "world"},
+ Path("test.py"),
+ Path("bundle/path"),
True,
None,
id="successful_with_kwargs",
@@ -88,6 +94,8 @@ class TestExecuteCallback:
pytest.param(
f"{__name__}.CallableClass",
{"msg": "alert"},
+ Path("test.py"),
+ Path("bundle/path"),
True,
None,
id="callable_class_pattern",
@@ -95,6 +103,8 @@ class TestExecuteCallback:
pytest.param(
"",
{},
+ Path("test.py"),
+ Path("bundle/path"),
False,
"Callback path not found",
id="empty_path",
@@ -102,6 +112,8 @@ class TestExecuteCallback:
pytest.param(
"nonexistent.module.function",
{},
+ Path("test.py"),
+ Path("bundle/path"),
False,
"ModuleNotFoundError",
id="import_error",
@@ -109,6 +121,8 @@ class TestExecuteCallback:
pytest.param(
f"{__name__}.callback_that_raises",
{},
+ Path("test.py"),
+ Path("bundle/path"),
False,
"ValueError",
id="execution_error",
@@ -116,15 +130,23 @@ class TestExecuteCallback:
pytest.param(
f"{__name__}.nonexistent_function_xyz",
{},
+ Path("test.py"),
+ Path("bundle/path"),
False,
"AttributeError",
id="attribute_error",
),
],
)
- def test_execute_callback(self, path, kwargs, expect_success,
error_contains):
+ def test_execute_callback(self, path, kwargs, dag_rel_path, bundle_path,
expect_success, error_contains):
log = structlog.get_logger()
- success, error = execute_callback(path, kwargs, log)
+ success, error = execute_callback(
+ callback_path=path,
+ callback_kwargs=kwargs,
+ dag_rel_path=dag_rel_path,
+ bundle_path=bundle_path,
+ log=log,
+ )
assert success is expect_success
if error_contains:
@@ -305,3 +327,128 @@ class TestCallbackExecutionTimeout:
proc._monitor_subprocess()
mock_kill.assert_called_once_with(proc, signal.SIGTERM,
escalation_delay=5.0, force=True)
+
+
+class TestCallbackSubprocessStart:
+ """Verify that CallbackSubprocess.start() properly initializes and
executes the callback target."""
+
+ @pytest.fixture
+ def mock_client(self):
+ """Mock HTTP client."""
+ return Mock()
+
+ @pytest.fixture
+ def base_start_kwargs(self, mock_client):
+ """Base kwargs."""
+ return {
+ "id": str(uuid.uuid4()),
+ "callback_path": "my_module.my_callback_function",
+ "callback_kwargs": {"param1": "value1", "param2": 1},
+ "dag_rel_path": Path("dags/my_test_dag.py"),
+ "bundle_info": None,
+ "client": mock_client,
+ }
+
+ @pytest.fixture(autouse=True)
+ def base_mocks_setup(self, mock_supervisor_comms):
+ """Base mocks for all tests in this class."""
+ with (
+ patch("airflow.sdk.execution_time.comms.CommsDecoder") as
mock_comms,
+
patch("airflow.sdk.execution_time.callback_supervisor.WatchedSubprocess.start")
as mock_super,
+
patch("airflow.sdk.execution_time.callback_supervisor.execute_callback") as
mock_execute,
+ ):
+ mock_execute.return_value = (True, None)
+
+ self.mock_comms_decoder = mock_comms
+ self.mock_super_start = mock_super
+ self.mock_execute_callback = mock_execute
+ yield
+
+ @pytest.fixture
+ def mock_bundle_setup(self):
+ """Setup bundle-related mocks."""
+ with patch("airflow.dag_processing.bundles.manager.DagBundlesManager")
as mock_manager_class:
+ mock_bundle = Mock()
+ bundle_path = Path("/path/to/bundle")
+ mock_bundle.path = bundle_path
+ mock_bundle.name = "test-bundle"
+
+ mock_bundle_manager = Mock()
+ mock_manager_class.return_value = mock_bundle_manager
+ mock_bundle_manager.get_bundle.return_value = mock_bundle
+
+ yield {
+ "manager_class": mock_manager_class,
+ "manager": mock_bundle_manager,
+ "bundle": mock_bundle,
+ "bundle_path": bundle_path,
+ }
+
+ def test_execute_callback_receives_correct_parameters(self,
base_start_kwargs):
+ """Test that execute_callback receives the correct parameters."""
+ CallbackSubprocess.start(**base_start_kwargs)
+ self.mock_super_start.call_args.kwargs["target"]()
+
+ self.mock_super_start.assert_called_with(
+ id=uuid.UUID(base_start_kwargs["id"]),
client=base_start_kwargs["client"], target=ANY, logger=None
+ )
+
+ self.mock_execute_callback.assert_called_with(
+ bundle_path=None,
+ callback_kwargs=base_start_kwargs["callback_kwargs"],
+ callback_path=base_start_kwargs["callback_path"],
+ dag_rel_path=base_start_kwargs["dag_rel_path"],
+ log=ANY,
+ )
+
+ def test_execute_callback_with_bundle_info_should_pass_correct_parameters(
+ self, base_start_kwargs, mock_bundle_setup
+ ):
+ """Test that execute_callback receives the correct parameters when
bundle_info is provided."""
+ bundle_info = BundleInfo(name="test-bundle", version="1.0")
+ adjusted_kwargs = {**base_start_kwargs, "bundle_info": bundle_info}
+
+ CallbackSubprocess.start(**adjusted_kwargs)
+ self.mock_super_start.call_args.kwargs["target"]()
+
+ self.mock_super_start.assert_called_with(
+ id=uuid.UUID(adjusted_kwargs["id"]),
client=adjusted_kwargs["client"], target=ANY, logger=None
+ )
+
+ mock_bundle_setup["manager"].get_bundle.assert_called_once_with(
+ name=bundle_info.name,
+ version=bundle_info.version,
+ )
+ mock_bundle_setup["bundle"].initialize.assert_called_once()
+
+ self.mock_execute_callback.assert_called_with(
+ bundle_path=str(mock_bundle_setup["bundle_path"]),
+ callback_kwargs=adjusted_kwargs["callback_kwargs"],
+ callback_path=adjusted_kwargs["callback_path"],
+ dag_rel_path=adjusted_kwargs["dag_rel_path"],
+ log=ANY,
+ )
+
+ def test_callback_supervisor_with_bundle_info_should_adjust_sys_path(
+ self, base_start_kwargs, mock_bundle_setup
+ ):
+ """Test that bundle_path is added to sys.path when bundle path is
provided."""
+ with patch("sys.path", new_callable=list) as mock_sys_path:
+ bundle_info = BundleInfo(name="test-bundle", version="1.0")
+ adjusted_kwargs = {**base_start_kwargs, "bundle_info": bundle_info}
+
+ CallbackSubprocess.start(**adjusted_kwargs)
+ self.mock_super_start.call_args.kwargs["target"]()
+
+ assert str(mock_bundle_setup["bundle_path"]) in mock_sys_path
+
+ def test_callback_supervisor_should_exit_on_error(self, base_start_kwargs):
+ """Test that callback supervisor exits if execute_callback returns an
error."""
+ self.mock_execute_callback.return_value = (False, "Some error
occurred")
+
+ CallbackSubprocess.start(**base_start_kwargs)
+
+ with pytest.raises(SystemExit) as exc_info:
+ self.mock_super_start.call_args.kwargs["target"]()
+
+ assert exc_info.value.code == 1