This is an automated email from the ASF dual-hosted git repository.
kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c03460ed9dd common.ai: Park approval reviews in awaiting_input on
Airflow 3.3+ (#68489)
c03460ed9dd is described below
commit c03460ed9ddf342469dccc39f7094ddae6c72226
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Jun 16 13:01:25 2026 +0100
common.ai: Park approval reviews in awaiting_input on Airflow 3.3+ (#68489)
LLMApprovalMixin (require_approval=True on LLMOperator/AgentOperator) now
raises TaskAwaitingInput on Airflow 3.3+ so the task parks in the
first-class awaiting_input state -- no trigger or triggerer involved --
matching the standard provider's HITLOperator. On older cores it falls
back to deferring to HITLTrigger as before. The response deadline is
enforced by the scheduler's awaiting_input timeout sweep on 3.3+.
Because nothing upstream schema-validates params_input on the
awaiting_input path (HITLTrigger did on the legacy path),
execute_complete now enforces the string contract for reviewer-modified
output and raises HITLTriggerEventError for non-string values.
The AIRFLOW_V_3_3_PLUS flag this uses was added in
apache-airflow-providers-common-compat 1.15.0; the dependency line is
marked "# use next version" so the release manager bumps the floor at
release time.
---
providers/common/ai/pyproject.toml | 2 +-
.../airflow/providers/common/ai/mixins/approval.py | 39 ++++++++++++--
.../tests/unit/common/ai/mixins/test_approval.py | 62 +++++++++++++++++++++-
.../ai/tests/unit/common/ai/operators/test_llm.py | 55 ++++++++++++++-----
.../common/ai/operators/test_llm_file_analysis.py | 23 ++++----
.../tests/unit/common/ai/operators/test_llm_sql.py | 18 +++++--
6 files changed, 165 insertions(+), 34 deletions(-)
diff --git a/providers/common/ai/pyproject.toml
b/providers/common/ai/pyproject.toml
index db08fab7374..e569ae88f91 100644
--- a/providers/common/ai/pyproject.toml
+++ b/providers/common/ai/pyproject.toml
@@ -67,7 +67,7 @@ requires-python = ">=3.10"
# After you modify the dependencies, and rebuild your Breeze CI image with
``breeze ci-image build``
dependencies = [
"apache-airflow>=3.0.0",
- "apache-airflow-providers-common-compat>=1.14.1",
+ "apache-airflow-providers-common-compat>=1.14.1", # use next version
"apache-airflow-providers-standard>=1.12.1",
"pydantic-ai-slim>=1.99.0",
]
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py
b/providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py
index 07855340c4b..5ebd679efcd 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py
@@ -23,6 +23,13 @@ from typing import TYPE_CHECKING, Any, Protocol
from pydantic import BaseModel
+from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_3_PLUS
+
+if AIRFLOW_V_3_3_PLUS:
+ # On Airflow 3.3+ the review parks the task in the first-class
AWAITING_INPUT state instead
+ # of deferring to a trigger. On older cores this name is absent and
defer() is used.
+ from airflow.sdk.exceptions import TaskAwaitingInput
+
log = logging.getLogger(__name__)
if TYPE_CHECKING:
@@ -45,7 +52,8 @@ class LLMApprovalMixin:
When ``require_approval=True`` on the operator, the generated output is
presented to a human reviewer via the Airflow Human-in-the-Loop (HITL)
- interface. The task defers until the reviewer approves or rejects.
+ interface. The task waits (``awaiting_input`` on Airflow 3.3+, deferred on
+ older versions) until the reviewer approves or rejects.
If ``allow_modifications=True``, the reviewer can also edit the output
before approving. The (possibly modified) output is then returned as the
@@ -71,7 +79,11 @@ class LLMApprovalMixin:
body: str | None = None,
) -> None:
"""
- Write HITL detail, then defer to HITLTrigger for human review.
+ Write HITL detail, then pause the task for human review.
+
+ On Airflow 3.3+ the task parks in the ``awaiting_input`` state (no
trigger or triggerer
+ involved); on older versions it defers to :class:`HITLTrigger`. Either
way it resumes in
+ ``execute_complete`` once a response (or timeout default) arrives.
:param context: Airflow task context.
:param output: The generated output to present for review.
@@ -100,7 +112,6 @@ class LLMApprovalMixin:
output = str(output)
ti_id = context["task_instance"].id
- timeout_datetime = utcnow() + self.approval_timeout if
self.approval_timeout else None
if subject is None:
subject = f"Review output for task `{self.task_id}`"
@@ -128,6 +139,16 @@ class LLMApprovalMixin:
params=hitl_params,
)
+ if AIRFLOW_V_3_3_PLUS:
+ # New core (3.3+): park the task in AWAITING_INPUT -- no trigger,
no triggerer. The
+ # task is resumed by the Core API response handler or the
scheduler timeout sweep.
+ raise TaskAwaitingInput(
+ method_name="execute_complete",
+ kwargs={"generated_output": output},
+ timeout=self.approval_timeout,
+ )
+
+ # Fallback for cores < 3.3: defer the response check to HITLTrigger on
the triggerer.
self.defer(
trigger=HITLTrigger(
ti_id=ti_id,
@@ -135,7 +156,7 @@ class LLMApprovalMixin:
defaults=None,
params=hitl_params,
multiple=False,
- timeout_datetime=timeout_datetime,
+ timeout_datetime=utcnow() + self.approval_timeout if
self.approval_timeout else None,
),
method_name="execute_complete",
kwargs={"generated_output": output},
@@ -182,6 +203,16 @@ class LLMApprovalMixin:
# when allow_modifications=False, bypassing the read-only approval
flow.
if getattr(self, "allow_modifications", False) and params_input:
modified = params_input.get("output")
+ if modified is not None and not isinstance(modified, str):
+ # On the awaiting_input path nothing upstream schema-validates
params_input
+ # (HITLTrigger did on the legacy path), so enforce the string
contract here
+ # rather than returning a non-string as the task's output.
+ raise HITLTriggerEventError(
+ {
+ "error": f"Modified output must be a string, got
{type(modified).__name__}.",
+ "error_type": "validation",
+ }
+ )
if modified is not None and modified != generated_output:
log.info("output=%s modified by the reviewer=%s ", modified,
responded_by_user)
return modified
diff --git a/providers/common/ai/tests/unit/common/ai/mixins/test_approval.py
b/providers/common/ai/tests/unit/common/ai/mixins/test_approval.py
index 464dfe38986..54b675da723 100644
--- a/providers/common/ai/tests/unit/common/ai/mixins/test_approval.py
+++ b/providers/common/ai/tests/unit/common/ai/mixins/test_approval.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import pytest
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS,
AIRFLOW_V_3_3_PLUS
if not AIRFLOW_V_3_1_PLUS:
pytest.skip("Human in the loop is only compatible with Airflow >= 3.1.0",
allow_module_level=True)
@@ -34,9 +34,13 @@ from airflow.providers.common.ai.mixins.approval import (
)
from airflow.providers.standard.exceptions import HITLRejectException,
HITLTriggerEventError
+if AIRFLOW_V_3_3_PLUS:
+ from airflow.sdk.exceptions import TaskAwaitingInput
+
HITL_TRIGGER_PATH = "airflow.providers.standard.triggers.hitl.HITLTrigger"
UPSERT_HITL_PATH = "airflow.sdk.execution_time.hitl.upsert_hitl_detail"
UTCNOW_PATH = "airflow.sdk.timezone.utcnow"
+AWAIT_INPUT_FLAG_PATH =
"airflow.providers.common.ai.mixins.approval.AIRFLOW_V_3_3_PLUS"
class FakeOperator(LLMApprovalMixin):
@@ -76,6 +80,9 @@ def context():
return MagicMock(**{"__getitem__": lambda self, key: {"task_instance":
ti}[key]})
+# The legacy trigger path is taken on cores < 3.3; pin the flag so these tests
keep
+# exercising the defer() fallback when run against newer cores.
+@patch(AWAIT_INPUT_FLAG_PATH, False)
class TestDeferForApproval:
@patch(HITL_TRIGGER_PATH, autospec=True)
@patch(UPSERT_HITL_PATH)
@@ -253,6 +260,21 @@ class TestDeferForApproval:
assert result == "modified output"
+ def test_approved_with_non_string_modified_output_raises(self,
approval_op_with_modifications):
+ # On the awaiting_input path nothing upstream schema-validates
params_input
+ # (HITLTrigger did on the legacy path), so execute_complete must
enforce the
+ # string contract instead of returning a dict as the task's output.
+ event = {
+ "chosen_options": ["Approve"],
+ "responded_by_user": "editor",
+ "params_input": {"output": {"sneaky": "dict"}},
+ }
+
+ with pytest.raises(HITLTriggerEventError, match="must be a string"):
+ approval_op_with_modifications.execute_complete(
+ {}, generated_output="original output", event=event
+ )
+
def test_approved_with_unmodified_output(self,
approval_op_with_modifications):
event = {
"chosen_options": ["Approve"],
@@ -324,3 +346,41 @@ class TestDeferForApproval:
with pytest.raises(HITLRejectException, match="alice"):
approval_op.execute_complete({}, generated_output="output",
event=event)
+
+
+@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="awaiting_input path
requires Airflow 3.3+")
+class TestAwaitInputForApproval:
+ """On Airflow 3.3+ the review parks the task in AWAITING_INPUT instead of
deferring."""
+
+ @patch(UPSERT_HITL_PATH)
+ def test_parks_task_in_awaiting_input(self, mock_upsert, approval_op,
context):
+ with pytest.raises(TaskAwaitingInput) as exc_info:
+ approval_op.defer_for_approval(context, "some LLM output")
+
+ assert exc_info.value.method_name == "execute_complete"
+ assert exc_info.value.kwargs == {"generated_output": "some LLM output"}
+ assert exc_info.value.timeout is None
+ mock_upsert.assert_called_once()
+ assert mock_upsert.call_args[1]["options"] == ["Approve", "Reject"]
+ approval_op.defer.assert_not_called()
+
+ @patch(UPSERT_HITL_PATH)
+ def test_approval_timeout_carried_on_await(self, mock_upsert, context):
+ timeout = timedelta(hours=2)
+ op = FakeOperator(approval_timeout=timeout)
+
+ with pytest.raises(TaskAwaitingInput) as exc_info:
+ op.defer_for_approval(context, "output")
+
+ assert exc_info.value.timeout == timeout
+
+ @patch(UPSERT_HITL_PATH)
+ def test_pydantic_output_stringified_on_await(self, mock_upsert,
approval_op, context):
+ class Answer(BaseModel):
+ text: str
+ confidence: float
+
+ with pytest.raises(TaskAwaitingInput) as exc_info:
+ approval_op.defer_for_approval(context, Answer(text="Paris",
confidence=0.95))
+
+ assert exc_info.value.kwargs == {"generated_output":
'{"text":"Paris","confidence":0.95}'}
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
b/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
index d5ef8228d35..f9f3bf09099 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
@@ -29,13 +29,25 @@ from airflow.providers.common.ai.mixins.approval import (
)
from airflow.providers.common.ai.operators.llm import LLMOperator
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS,
AIRFLOW_V_3_3_PLUS
try:
from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as
_CORE_WALKER
except ImportError:
_CORE_WALKER = False
+from airflow.providers.common.compat.sdk import TaskDeferred
+
+if AIRFLOW_V_3_3_PLUS:
+ # On 3.3+ cores require_approval pauses the task in AWAITING_INPUT; older
cores defer
+ # to HITLTrigger. Both exceptions carry method_name/kwargs/timeout, so the
approval
+ # tests assert against whichever pause signal the running core uses.
+ from airflow.sdk.exceptions import TaskAwaitingInput as ApprovalPauseSignal
+else:
+ ApprovalPauseSignal = TaskDeferred # type: ignore[assignment, misc]
+
+AWAIT_INPUT_FLAG_PATH =
"airflow.providers.common.ai.mixins.approval.AIRFLOW_V_3_3_PLUS"
+
# Returning the Pydantic instance through XCom (rather than a dict) only
happens
# on cores that register declared ``output_type`` classes from the worker-side
# DAG walk. On older cores the operator dumps to a dict, so these tests skip.
@@ -187,8 +199,6 @@ class TestLLMOperatorApproval:
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
def test_execute_with_approval_defers(self, mock_hook_cls, mock_upsert,
mock_trigger_cls):
"""When require_approval=True, execute() defers instead of returning
output."""
- from airflow.providers.common.compat.sdk import TaskDeferred
-
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("LLM
response")
mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
@@ -201,20 +211,43 @@ class TestLLMOperatorApproval:
)
ctx = _make_context()
- with pytest.raises(TaskDeferred) as exc_info:
+ with pytest.raises(ApprovalPauseSignal) as exc_info:
op.execute(context=ctx)
assert exc_info.value.method_name == "execute_complete"
assert exc_info.value.kwargs["generated_output"] == "LLM response"
mock_upsert.assert_called_once()
+ @patch(AWAIT_INPUT_FLAG_PATH, False)
+ @patch("airflow.providers.standard.triggers.hitl.HITLTrigger",
autospec=True)
+ @patch("airflow.sdk.execution_time.hitl.upsert_hitl_detail")
+ @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
+ def test_execute_with_approval_defers_on_legacy_core(self, mock_hook_cls,
mock_upsert, mock_trigger_cls):
+ """On cores < 3.3 (flag pinned), execute() falls back to deferring to
HITLTrigger."""
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_agent.run_sync.return_value = _make_mock_run_result("LLM
response")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
+
+ op = LLMOperator(
+ task_id="legacy_approval_test",
+ prompt="Summarize this",
+ llm_conn_id="my_llm",
+ require_approval=True,
+ )
+
+ with pytest.raises(TaskDeferred) as exc_info:
+ op.execute(context=_make_context())
+
+ assert exc_info.value.method_name == "execute_complete"
+ assert exc_info.value.kwargs["generated_output"] == "LLM response"
+ mock_trigger_cls.assert_called_once()
+ mock_upsert.assert_called_once()
+
@patch("airflow.providers.standard.triggers.hitl.HITLTrigger",
autospec=True)
@patch("airflow.sdk.execution_time.hitl.upsert_hitl_detail")
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
def test_execute_with_approval_and_modifications(self, mock_hook_cls,
mock_upsert, mock_trigger_cls):
"""allow_modifications=True passes an editable 'output' param."""
- from airflow.providers.common.compat.sdk import TaskDeferred
-
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("draft
output")
mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
@@ -228,7 +261,7 @@ class TestLLMOperatorApproval:
)
ctx = _make_context()
- with pytest.raises(TaskDeferred):
+ with pytest.raises(ApprovalPauseSignal):
op.execute(context=ctx)
upsert_kwargs = mock_upsert.call_args[1]
@@ -239,8 +272,6 @@ class TestLLMOperatorApproval:
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
def test_execute_with_approval_and_timeout(self, mock_hook_cls,
mock_upsert, mock_trigger_cls):
"""approval_timeout is passed to the trigger."""
- from airflow.providers.common.compat.sdk import TaskDeferred
-
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("output")
mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
@@ -255,7 +286,7 @@ class TestLLMOperatorApproval:
)
ctx = _make_context()
- with pytest.raises(TaskDeferred) as exc_info:
+ with pytest.raises(ApprovalPauseSignal) as exc_info:
op.execute(context=ctx)
assert exc_info.value.timeout == timeout
@@ -265,8 +296,6 @@ class TestLLMOperatorApproval:
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
def test_execute_with_approval_structured_output(self, mock_hook_cls,
mock_upsert, mock_trigger_cls):
"""Structured (BaseModel) output is serialized before deferring."""
- from airflow.providers.common.compat.sdk import TaskDeferred
-
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value =
_make_mock_run_result(Summary(text="hello"))
mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
@@ -280,7 +309,7 @@ class TestLLMOperatorApproval:
)
ctx = _make_context()
- with pytest.raises(TaskDeferred) as exc_info:
+ with pytest.raises(ApprovalPauseSignal) as exc_info:
op.execute(context=ctx)
assert exc_info.value.kwargs["generated_output"] == '{"text":"hello"}'
diff --git
a/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py
b/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py
index 7c955a160b4..9e692b420f9 100644
---
a/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py
+++
b/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py
@@ -25,8 +25,17 @@ from pydantic import BaseModel
from airflow.providers.common.ai.operators.llm_file_analysis import
LLMFileAnalysisOperator
from airflow.providers.common.ai.utils.file_analysis import FileAnalysisRequest
+from airflow.providers.common.compat.sdk import TaskDeferred
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS,
AIRFLOW_V_3_3_PLUS
+
+if AIRFLOW_V_3_3_PLUS:
+ # On 3.3+ cores require_approval pauses the task in AWAITING_INPUT; older
cores defer to
+ # HITLTrigger. Both signals carry method_name/kwargs/timeout, so the
approval tests assert
+ # against whichever pause signal the running core uses.
+ from airflow.sdk.exceptions import TaskAwaitingInput as ApprovalPauseSignal
+else:
+ ApprovalPauseSignal = TaskDeferred # type: ignore[assignment, misc]
try:
from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as
_CORE_WALKER
@@ -208,8 +217,6 @@ class TestLLMFileAnalysisOperatorApproval:
def test_execute_with_approval_defers(
self, mock_build_request, mock_hook_cls, mock_upsert, mock_trigger_cls
):
- from airflow.providers.common.compat.sdk import TaskDeferred
-
mock_build_request.return_value = FileAnalysisRequest(
user_content="prepared prompt",
resolved_paths=["/tmp/app.log"],
@@ -228,7 +235,7 @@ class TestLLMFileAnalysisOperatorApproval:
)
ctx = _make_context()
- with pytest.raises(TaskDeferred) as exc_info:
+ with pytest.raises(ApprovalPauseSignal) as exc_info:
op.execute(context=ctx)
assert exc_info.value.method_name == "execute_complete"
@@ -244,8 +251,6 @@ class TestLLMFileAnalysisOperatorApproval:
def test_execute_with_approval_defers_structured_output_as_json(
self, mock_build_request, mock_hook_cls, mock_upsert, mock_trigger_cls
):
- from airflow.providers.common.compat.sdk import TaskDeferred
-
mock_build_request.return_value = FileAnalysisRequest(
user_content="prepared prompt",
resolved_paths=["/tmp/app.log"],
@@ -264,7 +269,7 @@ class TestLLMFileAnalysisOperatorApproval:
require_approval=True,
)
- with pytest.raises(TaskDeferred) as exc_info:
+ with pytest.raises(ApprovalPauseSignal) as exc_info:
op.execute(context=_make_context())
assert exc_info.value.kwargs["generated_output"] ==
'{"findings":["error spike"]}'
@@ -318,8 +323,6 @@ class TestLLMFileAnalysisOperatorApproval:
def test_execute_with_approval_timeout(
self, mock_build_request, mock_hook_cls, mock_upsert, mock_trigger_cls
):
- from airflow.providers.common.compat.sdk import TaskDeferred
-
mock_build_request.return_value = FileAnalysisRequest(
user_content="prepared prompt",
resolved_paths=["/tmp/app.log"],
@@ -339,7 +342,7 @@ class TestLLMFileAnalysisOperatorApproval:
approval_timeout=timeout,
)
- with pytest.raises(TaskDeferred) as exc_info:
+ with pytest.raises(ApprovalPauseSignal) as exc_info:
op.execute(context=_make_context())
assert exc_info.value.timeout == timeout
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
index a994ae3d1cd..1862971c953 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
@@ -30,7 +30,15 @@ from airflow.providers.common.ai.utils.sql_validation import
SQLSafetyError
from airflow.providers.common.compat.sdk import TaskDeferred
from airflow.providers.common.sql.config import DataSourceConfig
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS,
AIRFLOW_V_3_3_PLUS
+
+if AIRFLOW_V_3_3_PLUS:
+ # On 3.3+ cores require_approval pauses the task in AWAITING_INPUT; older
cores defer to
+ # HITLTrigger. Both signals carry method_name/kwargs/timeout, so the
approval tests assert
+ # against whichever pause signal the running core uses.
+ from airflow.sdk.exceptions import TaskAwaitingInput as ApprovalPauseSignal
+else:
+ ApprovalPauseSignal = TaskDeferred # type: ignore[assignment, misc]
def _make_mock_run_result(output):
@@ -475,7 +483,7 @@ class TestLLMSQLQueryOperatorApproval:
)
ctx = _make_context()
- with pytest.raises(TaskDeferred) as exc_info:
+ with pytest.raises(ApprovalPauseSignal) as exc_info:
op.execute(context=ctx)
assert exc_info.value.method_name == "execute_complete"
@@ -521,7 +529,7 @@ class TestLLMSQLQueryOperatorApproval:
)
ctx = _make_context()
- with pytest.raises(TaskDeferred):
+ with pytest.raises(ApprovalPauseSignal):
op.execute(context=ctx)
upsert_kwargs = mock_upsert.call_args[1]
@@ -545,7 +553,7 @@ class TestLLMSQLQueryOperatorApproval:
)
ctx = _make_context()
- with pytest.raises(TaskDeferred) as exc_info:
+ with pytest.raises(ApprovalPauseSignal) as exc_info:
op.execute(context=ctx)
assert exc_info.value.timeout == timeout
@@ -583,7 +591,7 @@ class TestLLMSQLQueryOperatorApproval:
)
ctx = _make_context()
- with pytest.raises(TaskDeferred) as exc_info:
+ with pytest.raises(ApprovalPauseSignal) as exc_info:
op.execute(context=ctx)
assert exc_info.value.kwargs["generated_output"] == "SELECT 1"