This is an automated email from the ASF dual-hosted git repository.
kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 325f3774ba6 Accept Sequence[UserContent] in common.ai TaskFlow
decorators (#67389)
325f3774ba6 is described below
commit 325f3774ba6dcde31809555a12050b048d79fbc6
Author: Kaxil Naik <[email protected]>
AuthorDate: Sun May 24 21:33:33 2026 +0100
Accept Sequence[UserContent] in common.ai TaskFlow decorators (#67389)
* Accept Sequence[UserContent] prompts in common.ai TaskFlow decorators
@task.agent, @task.llm, @task.llm_branch, @task.llm_schema_compare and
@task.llm_sql decorators now accept a Sequence of pydantic-ai UserContent
items (ImageUrl, AudioUrl, DocumentUrl, etc.) in addition to str, mirroring
Agent.run_sync's input contract. This enables vision, audio, and document
inputs to pydantic-ai agents directly through the TaskFlow decorator path.
Sequence prompts fail loudly before any LLM call when combined with
enable_hitl_review=True (agent) or require_approval=True (llm, llm_sql) --
the HITL session model and approval review body both assume str prompts.
Both are tracked as follow-ups on the AIP-99 board.
* Drop manually-authored 0.4.0 changelog block
The provider changelog is regenerated by the release manager from git log
at wave time; manually authoring a versioned block pre-empts that and
duplicates the auto-extraction from the commit title. The HITL/approval
limitations are already documented in the operator docs (agent.rst,
llm.rst) where they belong.
* Add 'stringify' to global spelling_wordlist
The verb form of 'stringified' is used in the new validate_prompt /
reject_sequence_with_unsupported_feature docstring; only the past-tense
forms were in the wordlist. Sphinx spellcheck failed on the docstring
during build-docs.
---
docs/spelling_wordlist.txt | 1 +
providers/common/ai/docs/operators/agent.rst | 28 +++++
providers/common/ai/docs/operators/llm.rst | 12 ++
providers/common/ai/docs/operators/llm_branch.rst | 4 +
.../ai/docs/operators/llm_schema_compare.rst | 4 +
providers/common/ai/docs/operators/llm_sql.rst | 4 +
.../providers/common/ai/decorators/agent.py | 13 ++-
.../airflow/providers/common/ai/decorators/llm.py | 13 ++-
.../providers/common/ai/decorators/llm_branch.py | 6 +-
.../common/ai/decorators/llm_file_analysis.py | 5 +
.../common/ai/decorators/llm_schema_compare.py | 6 +-
.../providers/common/ai/decorators/llm_sql.py | 13 ++-
.../airflow/providers/common/ai/mixins/approval.py | 9 ++
.../airflow/providers/common/ai/operators/agent.py | 8 ++
.../airflow/providers/common/ai/operators/llm.py | 8 ++
.../providers/common/ai/operators/llm_sql.py | 8 ++
.../providers/common/ai/utils/validation.py | 91 +++++++++++++++
.../tests/unit/common/ai/decorators/test_agent.py | 50 +++++++-
.../ai/tests/unit/common/ai/decorators/test_llm.py | 45 ++++++-
.../unit/common/ai/decorators/test_llm_branch.py | 37 +++++-
.../ai/decorators/test_llm_schema_compare.py | 35 +++++-
.../unit/common/ai/decorators/test_llm_sql.py | 45 ++++++-
.../tests/unit/common/ai/mixins/test_approval.py | 5 +
.../tests/unit/common/ai/operators/test_agent.py | 27 +++++
.../ai/tests/unit/common/ai/operators/test_llm.py | 27 +++++
.../tests/unit/common/ai/operators/test_llm_sql.py | 26 +++++
.../tests/unit/common/ai/utils/test_validation.py | 129 +++++++++++++++++++++
27 files changed, 625 insertions(+), 34 deletions(-)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index f98d6edde30..b0fcaca2bfd 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1557,6 +1557,7 @@ Streamable
strftime
Stringified
stringified
+stringify
Struct
STS
subchart
diff --git a/providers/common/ai/docs/operators/agent.rst
b/providers/common/ai/docs/operators/agent.rst
index fc5efa27780..d58a276caef 100644
--- a/providers/common/ai/docs/operators/agent.rst
+++ b/providers/common/ai/docs/operators/agent.rst
@@ -87,6 +87,34 @@ the prompt string; all other parameters are passed to the
operator.
:end-before: [END howto_decorator_agent]
+.. _howto/operator:agent-multimodal:
+
+Multimodal prompts
+^^^^^^^^^^^^^^^^^^
+
+The decorated callable may also return a ``Sequence[UserContent]`` -- for
+example, a list mixing strings with ``ImageUrl``, ``BinaryContent``, or other
+pydantic-ai user-content types -- to send vision, audio, or document inputs
+to the model. This mirrors the input types accepted by pydantic-ai's
+``Agent.run_sync``.
+
+.. code-block:: python
+
+ from pydantic_ai.messages import ImageUrl
+
+
+ @task.agent(llm_conn_id="pydanticai_default", system_prompt="You are an image analyst.")
+ def analyze_review(image_url: str):
+ return ["Describe what you see:", ImageUrl(url=image_url)]
+
+.. note::
+
+ Combining a non-string prompt with ``enable_hitl_review=True`` is not
+ currently supported -- the HITL session model stores the prompt as a
+ string, so a ``Sequence`` prompt raises ``TypeError`` before the agent runs.
+ Widening HITL review to multimodal prompts is tracked as a follow-up.
+
+
Structured Output
-----------------
diff --git a/providers/common/ai/docs/operators/llm.rst
b/providers/common/ai/docs/operators/llm.rst
index a69a592eff8..1d1a2482710 100644
--- a/providers/common/ai/docs/operators/llm.rst
+++ b/providers/common/ai/docs/operators/llm.rst
@@ -114,6 +114,18 @@ With structured output:
:start-after: [START howto_decorator_llm_structured]
:end-before: [END howto_decorator_llm_structured]
+Multimodal prompts
+^^^^^^^^^^^^^^^^^^
+
+``@task.llm`` accepts the same prompt shape as ``@task.agent`` -- the callable
+may return either a ``str`` or a non-empty ``Sequence[UserContent]`` (e.g.,
+``["Describe this:", ImageUrl(url="...")]``) for vision, audio, or document
+inputs. See :ref:`@task.agent multimodal prompts <howto/operator:agent-multimodal>` for
+the full example. ``require_approval=True`` is not currently supported with a
+``Sequence`` prompt -- the approval session model expects a string -- so the
+task raises ``TypeError`` before any LLM call; widening that path is tracked as a follow-up.
+
+
Classification with ``Literal``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/providers/common/ai/docs/operators/llm_branch.rst
b/providers/common/ai/docs/operators/llm_branch.rst
index 94e1ce16fc8..3c9e6490823 100644
--- a/providers/common/ai/docs/operators/llm_branch.rst
+++ b/providers/common/ai/docs/operators/llm_branch.rst
@@ -64,6 +64,10 @@ returns the prompt string; all other parameters are passed
to the operator:
:start-after: [START howto_decorator_llm_branch]
:end-before: [END howto_decorator_llm_branch]
+The callable may also return a non-empty ``Sequence[UserContent]`` for
+multimodal inputs -- see
+:ref:`@task.agent multimodal prompts <howto/operator:agent-multimodal>`.
+
With multiple branches:
.. exampleinclude::
/../../ai/src/airflow/providers/common/ai/example_dags/example_llm_branch.py
diff --git a/providers/common/ai/docs/operators/llm_schema_compare.rst
b/providers/common/ai/docs/operators/llm_schema_compare.rst
index ad548c7fc6e..d5f014adf0c 100644
--- a/providers/common/ai/docs/operators/llm_schema_compare.rst
+++ b/providers/common/ai/docs/operators/llm_schema_compare.rst
@@ -110,6 +110,10 @@ structured output:
:start-after: [START howto_decorator_llm_schema_compare]
:end-before: [END howto_decorator_llm_schema_compare]
+The callable may also return a non-empty ``Sequence[UserContent]`` for
+multimodal inputs -- see
+:ref:`@task.agent multimodal prompts <howto/operator:agent-multimodal>`.
+
Conditional ETL Based on Schema Compatibility
----------------------------------------------
diff --git a/providers/common/ai/docs/operators/llm_sql.rst
b/providers/common/ai/docs/operators/llm_sql.rst
index cdcc6c4a276..807fe5197e8 100644
--- a/providers/common/ai/docs/operators/llm_sql.rst
+++ b/providers/common/ai/docs/operators/llm_sql.rst
@@ -110,6 +110,10 @@ and safety validation:
:start-after: [START howto_decorator_llm_sql]
:end-before: [END howto_decorator_llm_sql]
+The callable may also return a non-empty ``Sequence[UserContent]`` for
+multimodal inputs -- see
+:ref:`@task.agent multimodal prompts <howto/operator:agent-multimodal>`.
+
Dynamic Task Mapping
--------------------
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/decorators/agent.py
b/providers/common/ai/src/airflow/providers/common/ai/decorators/agent.py
index 379c8cb65fa..25272d06009 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/decorators/agent.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/decorators/agent.py
@@ -28,6 +28,10 @@ from collections.abc import Callable, Collection, Mapping,
Sequence
from typing import TYPE_CHECKING, Any, ClassVar
from airflow.providers.common.ai.operators.agent import AgentOperator
+from airflow.providers.common.ai.utils.validation import (
+ reject_sequence_with_unsupported_feature,
+ validate_prompt,
+)
from airflow.providers.common.compat.sdk import (
DecoratedOperator,
TaskDecorator,
@@ -86,8 +90,13 @@ class _AgentDecoratedOperator(DecoratedOperator,
AgentOperator):
self.prompt = self.python_callable(*self.op_args, **kwargs)
- if not isinstance(self.prompt, str) or not self.prompt.strip():
- raise TypeError("The returned value from the @task.agent callable
must be a non-empty string.")
+ validate_prompt(self.prompt, decorator_name="@task.agent")
+ reject_sequence_with_unsupported_feature(
+ self.prompt,
+ decorator_name="@task.agent",
+ feature_name="enable_hitl_review",
+ feature_enabled=self.enable_hitl_review,
+ )
self.render_template_fields(context)
return AgentOperator.execute(self, context)
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/decorators/llm.py
b/providers/common/ai/src/airflow/providers/common/ai/decorators/llm.py
index f21bcd343c0..f2db8628c1e 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/decorators/llm.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/decorators/llm.py
@@ -29,6 +29,10 @@ from collections.abc import Callable, Collection, Mapping,
Sequence
from typing import TYPE_CHECKING, Any, ClassVar
from airflow.providers.common.ai.operators.llm import LLMOperator
+from airflow.providers.common.ai.utils.validation import (
+ reject_sequence_with_unsupported_feature,
+ validate_prompt,
+)
from airflow.providers.common.compat.sdk import (
DecoratedOperator,
TaskDecorator,
@@ -87,8 +91,13 @@ class _LLMDecoratedOperator(DecoratedOperator, LLMOperator):
self.prompt = self.python_callable(*self.op_args, **kwargs)
- if not isinstance(self.prompt, str) or not self.prompt.strip():
- raise TypeError("The returned value from the @task.llm callable
must be a non-empty string.")
+ validate_prompt(self.prompt, decorator_name="@task.llm")
+ reject_sequence_with_unsupported_feature(
+ self.prompt,
+ decorator_name="@task.llm",
+ feature_name="require_approval",
+ feature_enabled=self.require_approval,
+ )
self.render_template_fields(context)
return LLMOperator.execute(self, context)
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_branch.py
b/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_branch.py
index 2dc9194638a..91b7d7640d1 100644
---
a/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_branch.py
+++
b/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_branch.py
@@ -28,6 +28,7 @@ from collections.abc import Callable, Collection, Mapping,
Sequence
from typing import TYPE_CHECKING, Any, ClassVar
from airflow.providers.common.ai.operators.llm_branch import LLMBranchOperator
+from airflow.providers.common.ai.utils.validation import validate_prompt
from airflow.providers.common.compat.sdk import (
DecoratedOperator,
TaskDecorator,
@@ -87,10 +88,7 @@ class _LLMBranchDecoratedOperator(DecoratedOperator,
LLMBranchOperator):
self.prompt = self.python_callable(*self.op_args, **kwargs)
- if not isinstance(self.prompt, str) or not self.prompt.strip():
- raise TypeError(
- "The returned value from the @task.llm_branch callable must be
a non-empty string."
- )
+ validate_prompt(self.prompt, decorator_name="@task.llm_branch")
self.render_template_fields(context)
return LLMBranchOperator.execute(self, context)
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_file_analysis.py
b/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_file_analysis.py
index c9451b3fbee..1ad569d8105 100644
---
a/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_file_analysis.py
+++
b/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_file_analysis.py
@@ -69,6 +69,11 @@ class _LLMFileAnalysisDecoratedOperator(DecoratedOperator,
LLMFileAnalysisOperat
kwargs = determine_kwargs(self.python_callable, self.op_args, context)
self.prompt = self.python_callable(*self.op_args, **kwargs)
+ # The string-only check is intentional here: the operator builds
+ # request.user_content from prompt + files (see
LLMFileAnalysisOperator.execute),
+ # so multimodal inputs are supplied via the `files` parameter, not the
prompt
+ # itself. The other common.ai decorators accept Sequence[UserContent]
because
+ # they pass self.prompt straight to Agent.run_sync.
if not isinstance(self.prompt, str) or not self.prompt.strip():
raise TypeError(
"The returned value from the @task.llm_file_analysis callable
must be a non-empty string."
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_schema_compare.py
b/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_schema_compare.py
index b4538d552e9..2e3a0e48ac9 100644
---
a/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_schema_compare.py
+++
b/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_schema_compare.py
@@ -28,6 +28,7 @@ from collections.abc import Callable, Collection, Mapping,
Sequence
from typing import TYPE_CHECKING, Any, ClassVar
from airflow.providers.common.ai.operators.llm_schema_compare import
LLMSchemaCompareOperator
+from airflow.providers.common.ai.utils.validation import validate_prompt
from airflow.providers.common.compat.sdk import (
DecoratedOperator,
TaskDecorator,
@@ -87,10 +88,7 @@ class _LLMSchemaCompareDecoratedOperator(DecoratedOperator,
LLMSchemaCompareOper
self.prompt = self.python_callable(*self.op_args, **kwargs)
- if not isinstance(self.prompt, str) or not self.prompt.strip():
- raise TypeError(
- "The returned value from the @task.llm_schema_compare callable
must be a non-empty string."
- )
+ validate_prompt(self.prompt, decorator_name="@task.llm_schema_compare")
self.render_template_fields(context)
return LLMSchemaCompareOperator.execute(self, context)
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_sql.py
b/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_sql.py
index d0ebb1a9bb0..5dca4e3d92f 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_sql.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_sql.py
@@ -28,6 +28,10 @@ from collections.abc import Callable, Collection, Mapping,
Sequence
from typing import TYPE_CHECKING, Any, ClassVar
from airflow.providers.common.ai.operators.llm_sql import LLMSQLQueryOperator
+from airflow.providers.common.ai.utils.validation import (
+ reject_sequence_with_unsupported_feature,
+ validate_prompt,
+)
from airflow.providers.common.compat.sdk import (
DecoratedOperator,
TaskDecorator,
@@ -86,8 +90,13 @@ class _LLMSQLDecoratedOperator(DecoratedOperator,
LLMSQLQueryOperator):
self.prompt = self.python_callable(*self.op_args, **kwargs)
- if not isinstance(self.prompt, str) or not self.prompt.strip():
- raise TypeError("The returned value from the @task.llm_sql
callable must be a non-empty string.")
+ validate_prompt(self.prompt, decorator_name="@task.llm_sql")
+ reject_sequence_with_unsupported_feature(
+ self.prompt,
+ decorator_name="@task.llm_sql",
+ feature_name="require_approval",
+ feature_enabled=self.require_approval,
+ )
self.render_template_fields(context)
# Call LLMSQLQueryOperator.execute directly, not super().execute(),
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py
b/providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py
index df5929eedc0..07855340c4b 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py
@@ -84,6 +84,15 @@ class LLMApprovalMixin:
from airflow.sdk.execution_time.hitl import upsert_hitl_detail
from airflow.sdk.timezone import utcnow
+ if not isinstance(self.prompt, str):
+ raise TypeError(
+ "require_approval=True is not supported with a non-string
prompt. "
+ "The approval review body renders the prompt as text; passing
a "
+ "Sequence[UserContent] would expose object reprs (and any
embedded "
+ "bytes) in the human review UI. Return a str prompt, or
disable "
+ "require_approval."
+ )
+
if isinstance(output, BaseModel):
output = output.model_dump_json()
if not isinstance(output, str):
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
index 2660bb408a1..541882c241a 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
@@ -225,6 +225,14 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
return [CachingToolset(wrapped=ts, storage=storage, counter=counter)
for ts in toolsets]
def execute(self, context: Context) -> Any:
+ if self.enable_hitl_review and not isinstance(self.prompt, str):
+ raise TypeError(
+ f"{type(self).__name__}: enable_hitl_review=True is not
supported "
+ f"with a non-string prompt (got {type(self.prompt).__name__}).
"
+ f"The HITL session model requires a string prompt. Return a
str "
+ f"prompt, or disable enable_hitl_review."
+ )
+
self._durable_storage = None
self._durable_counter = None
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
b/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
index 95eb6b1442c..4baf834044f 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
@@ -123,6 +123,14 @@ class LLMOperator(BaseOperator, LLMApprovalMixin):
return PydanticAIHook.get_hook(self.llm_conn_id,
hook_params=hook_params)
def execute(self, context: Context) -> Any:
+ if self.require_approval and not isinstance(self.prompt, str):
+ raise TypeError(
+ f"{type(self).__name__}: require_approval=True is not
supported "
+ f"with a non-string prompt (got {type(self.prompt).__name__}).
"
+ f"The approval review body renders the prompt as text. Return
a "
+ f"str prompt, or disable require_approval."
+ )
+
agent: Agent[None, Any] = self.llm_hook.create_agent(
output_type=self.output_type, instructions=self.system_prompt,
**self.agent_params
)
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
index 344b27d122d..370c819fa83 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
@@ -140,6 +140,14 @@ class LLMSQLQueryOperator(LLMOperator):
return hook
def execute(self, context: Context) -> str:
+ if self.require_approval and not isinstance(self.prompt, str):
+ raise TypeError(
+ f"{type(self).__name__}: require_approval=True is not
supported "
+ f"with a non-string prompt (got {type(self.prompt).__name__}).
"
+ f"The approval review body renders the prompt as text. Return
a "
+ f"str prompt, or disable require_approval."
+ )
+
schema_info = self._get_schema_context()
full_system_prompt = self._build_system_prompt(schema_info)
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/utils/validation.py
b/providers/common/ai/src/airflow/providers/common/ai/utils/validation.py
new file mode 100644
index 00000000000..2c06f31b81b
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/utils/validation.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Validation helpers for common.ai decorators."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import Any
+
+
+def validate_prompt(value: Any, *, decorator_name: str) -> None:
+ """
+ Validate the prompt returned by a decorator's python_callable.
+
+ Accepted (mirrors pydantic-ai's ``Agent.run_sync`` user_prompt):
+ - non-empty, non-whitespace ``str``
+ - non-empty ``Sequence`` (other than ``str``/``bytes``/``bytearray``)
+ of pydantic-ai ``UserContent`` items; item-level validation is
+ delegated to pydantic-ai at ``Agent.run_sync`` time.
+
+ Raises ``TypeError`` with an actionable message on any other shape.
+ """
+ if isinstance(value, str):
+ if not value.strip():
+ raise TypeError(
+ f"The returned value from the {decorator_name} callable must
be "
+ f"a non-empty string or a non-empty Sequence[UserContent]."
+ )
+ return
+ if isinstance(value, (bytes, bytearray)):
+ raise TypeError(
+ f"The returned value from the {decorator_name} callable must be "
+ f"str or Sequence[UserContent], not {type(value).__name__}."
+ )
+ if isinstance(value, Sequence):
+ if len(value) == 0:
+ raise TypeError(
+ f"The returned value from the {decorator_name} callable must
be "
+ f"a non-empty string or a non-empty Sequence[UserContent]."
+ )
+ for index, item in enumerate(value):
+ if isinstance(item, (bytes, bytearray)):
+ raise TypeError(
+ f"{decorator_name}: Sequence prompt item at index {index}
is "
+ f"{type(item).__name__}; raw bytes are not a valid
UserContent "
+ f"member. Wrap bytes in pydantic-ai's BinaryContent or
upload "
+ f"to object storage and pass an
ImageUrl/AudioUrl/DocumentUrl."
+ )
+ return
+ raise TypeError(
+ f"The returned value from the {decorator_name} callable must be "
+ f"str or Sequence[UserContent], got {type(value).__name__}."
+ )
+
+
+def reject_sequence_with_unsupported_feature(
+ value: Any,
+ *,
+ decorator_name: str,
+ feature_name: str,
+ feature_enabled: bool,
+) -> None:
+ """
+ Preflight check raised before the agent runs.
+
+ Raises ``TypeError`` when *value* is a non-string Sequence and
+ *feature_enabled* is True. Used to fail fast on combinations
+ (e.g., ``enable_hitl_review=True`` + Sequence prompt) that would
+ otherwise fail later -- after the LLM call -- when the downstream
+ HITL/approval consumer tries to stringify the prompt.
+ """
+ if feature_enabled and not isinstance(value, str):
+ raise TypeError(
+ f"{decorator_name}: Sequence[UserContent] prompts are not
supported "
+ f"with {feature_name}=True. Return a str prompt, or disable "
+ f"{feature_name}."
+ )
diff --git a/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
b/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
index cfe62a38f27..eb6f3fd4312 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
@@ -20,6 +20,7 @@ from unittest.mock import MagicMock, patch
import pytest
from pydantic import BaseModel
+from pydantic_ai.messages import ImageUrl
from airflow.providers.common.ai.decorators.agent import
_AgentDecoratedOperator
from airflow.providers.common.ai.toolsets.logging import LoggingToolset
@@ -60,19 +61,60 @@ class TestAgentDecoratedOperator:
@pytest.mark.parametrize(
"return_value",
- [42, "", " ", None],
- ids=["non-string", "empty", "whitespace", "none"],
+ [42, "", " ", None, b"bytes", bytearray(b"x"), [], ()],
+ ids=["non-string", "empty", "whitespace", "none", "bytes",
"bytearray", "empty-list", "empty-tuple"],
)
def test_execute_raises_on_invalid_prompt(self, return_value):
- """TypeError when the callable returns a non-string or blank string."""
+ """TypeError when the callable returns an unsupported prompt shape."""
op = _AgentDecoratedOperator(
task_id="test",
python_callable=lambda: return_value,
llm_conn_id="my_llm",
)
- with pytest.raises(TypeError, match="non-empty string"):
+ with pytest.raises(TypeError, match="must be"):
op.execute(context={})
+ @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook",
autospec=True)
+ def test_execute_accepts_sequence_prompt(self, mock_hook_cls):
+ """A non-empty Sequence[UserContent] return value is forwarded to
run_sync as-is."""
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_agent.run_sync.return_value = _make_mock_run_result("ok")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
+
+ image = ImageUrl(url="https://example.com/x.png")
+ prompt = ["Describe this:", image]
+
+ def my_prompt():
+ return prompt
+
+ op = _AgentDecoratedOperator(task_id="test",
python_callable=my_prompt, llm_conn_id="my_llm")
+ op.execute(context={})
+
+ assert op.prompt == prompt
+ mock_agent.run_sync.assert_called_once_with(prompt, usage_limits=None)
+
+ @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook",
autospec=True)
+ def test_sequence_prompt_with_hitl_review_raises_before_run_sync(self,
mock_hook_cls):
+ """Sequence prompt + enable_hitl_review=True fails before the agent
runs."""
+ from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
+
+ if not AIRFLOW_V_3_1_PLUS:
+ pytest.skip("enable_hitl_review requires Airflow >= 3.1.0")
+
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
+
+ op = _AgentDecoratedOperator(
+ task_id="test",
+ python_callable=lambda: ["x",
ImageUrl(url="https://example.com/x.png")],
+ llm_conn_id="my_llm",
+ enable_hitl_review=True,
+ )
+ with pytest.raises(TypeError, match="enable_hitl_review=True"):
+ op.execute(context={})
+
+ mock_agent.run_sync.assert_not_called()
+
@patch("airflow.providers.common.ai.operators.agent.PydanticAIHook",
autospec=True)
def test_execute_merges_op_kwargs_into_callable(self, mock_hook_cls):
"""op_kwargs are resolved by the callable to build the prompt."""
diff --git a/providers/common/ai/tests/unit/common/ai/decorators/test_llm.py
b/providers/common/ai/tests/unit/common/ai/decorators/test_llm.py
index 05768bd7452..67ea067d160 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_llm.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_llm.py
@@ -19,6 +19,7 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
+from pydantic_ai.messages import ImageUrl
from airflow.providers.common.ai.decorators.llm import _LLMDecoratedOperator
@@ -58,19 +59,55 @@ class TestLLMDecoratedOperator:
@pytest.mark.parametrize(
"return_value",
- [42, "", " ", None],
- ids=["non-string", "empty", "whitespace", "none"],
+ [42, "", " ", None, b"bytes", bytearray(b"x"), [], ()],
+ ids=["non-string", "empty", "whitespace", "none", "bytes",
"bytearray", "empty-list", "empty-tuple"],
)
def test_execute_raises_on_invalid_prompt(self, return_value):
- """TypeError when the callable returns a non-string or blank string."""
+ """TypeError when the callable returns an unsupported prompt shape."""
op = _LLMDecoratedOperator(
task_id="test",
python_callable=lambda: return_value,
llm_conn_id="my_llm",
)
- with pytest.raises(TypeError, match="non-empty string"):
+ with pytest.raises(TypeError, match="must be"):
op.execute(context={})
+ @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
+ def test_execute_accepts_sequence_prompt(self, mock_hook_cls):
+ """A non-empty Sequence[UserContent] return value is forwarded to
run_sync as-is."""
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_agent.run_sync.return_value = _make_mock_run_result("ok")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
+
+ image = ImageUrl(url="https://example.com/x.png")
+ prompt = ["Describe this:", image]
+
+ def my_prompt():
+ return prompt
+
+ op = _LLMDecoratedOperator(task_id="test", python_callable=my_prompt,
llm_conn_id="my_llm")
+ op.execute(context={})
+
+ assert op.prompt == prompt
+ mock_agent.run_sync.assert_called_once_with(prompt, usage_limits=None)
+
+ @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
+ def
test_sequence_prompt_with_require_approval_raises_before_run_sync(self,
mock_hook_cls):
+ """Sequence prompt + require_approval=True fails before the agent
runs."""
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
+
+ op = _LLMDecoratedOperator(
+ task_id="test",
+ python_callable=lambda: ["x",
ImageUrl(url="https://example.com/x.png")],
+ llm_conn_id="my_llm",
+ require_approval=True,
+ )
+ with pytest.raises(TypeError, match="require_approval=True"):
+ op.execute(context={})
+
+ mock_agent.run_sync.assert_not_called()
+
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
def test_execute_merges_op_kwargs_into_callable(self, mock_hook_cls):
"""op_kwargs are resolved by the callable to build the prompt."""
diff --git
a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
index c00ed3935e3..023af790d36 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
@@ -20,6 +20,7 @@ from enum import Enum
from unittest.mock import MagicMock, patch
import pytest
+from pydantic_ai.messages import ImageUrl
from airflow.providers.common.ai.decorators.llm_branch import
_LLMBranchDecoratedOperator
from airflow.providers.common.ai.operators.llm_branch import LLMBranchOperator
@@ -71,19 +72,47 @@ class TestLLMBranchDecoratedOperator:
@pytest.mark.parametrize(
"return_value",
- [42, "", " ", None],
- ids=["non-string", "empty", "whitespace", "none"],
+ [42, "", " ", None, b"bytes", bytearray(b"x"), [], ()],
+ ids=["non-string", "empty", "whitespace", "none", "bytes",
"bytearray", "empty-list", "empty-tuple"],
)
def test_execute_raises_on_invalid_prompt(self, return_value):
- """TypeError when the callable returns a non-string or blank string."""
+ """TypeError when the callable returns an unsupported prompt shape."""
op = _LLMBranchDecoratedOperator(
task_id="test",
python_callable=lambda: return_value,
llm_conn_id="my_llm",
)
- with pytest.raises(TypeError, match="non-empty string"):
+ with pytest.raises(TypeError, match="must be"):
op.execute(context={})
+ @patch.object(LLMBranchOperator, "do_branch")
+ @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
+ def test_execute_accepts_sequence_prompt(self, mock_hook_cls,
mock_do_branch):
+ """A non-empty Sequence[UserContent] return value is forwarded to
run_sync as-is."""
+ downstream_enum = Enum("DownstreamTasks", {"positive": "positive"})
+
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_agent.run_sync.return_value =
_make_mock_run_result(downstream_enum.positive)
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
+ mock_do_branch.return_value = "positive"
+
+ image = ImageUrl(url="https://example.com/x.png")
+ prompt = ["Route based on this image:", image]
+
+ def my_prompt():
+ return prompt
+
+ op = _LLMBranchDecoratedOperator(
+ task_id="test",
+ python_callable=my_prompt,
+ llm_conn_id="my_llm",
+ )
+ op.downstream_task_ids = {"positive"}
+ op.execute(context={})
+
+ assert op.prompt == prompt
+ mock_agent.run_sync.assert_called_once_with(prompt, usage_limits=None)
+
@patch.object(LLMBranchOperator, "do_branch")
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
def test_execute_merges_op_kwargs_into_callable(self, mock_hook_cls, mock_do_branch):
diff --git a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_schema_compare.py b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_schema_compare.py
index c271c94be00..df3c3f571c2 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_schema_compare.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_schema_compare.py
@@ -19,6 +19,7 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
+from pydantic_ai.messages import ImageUrl
from airflow.providers.common.ai.decorators.llm_schema_compare import _LLMSchemaCompareDecoratedOperator
from airflow.providers.common.ai.operators.llm_schema_compare import (
@@ -82,11 +83,11 @@ class TestLLMSchemaCompareDecoratedOperator:
@pytest.mark.parametrize(
"return_value",
- [42, "", " ", None],
- ids=["non-string", "empty", "whitespace", "none"],
+ [42, "", " ", None, b"bytes", bytearray(b"x"), [], ()],
+ ids=["non-string", "empty", "whitespace", "none", "bytes", "bytearray", "empty-list", "empty-tuple"],
)
def test_execute_raises_on_invalid_prompt(self, return_value):
- """TypeError when the callable returns a non-string or blank string."""
+ """TypeError when the callable returns an unsupported prompt shape."""
op = _LLMSchemaCompareDecoratedOperator(
task_id="test",
python_callable=lambda: return_value,
@@ -94,9 +95,35 @@ class TestLLMSchemaCompareDecoratedOperator:
db_conn_ids=["postgres_default", "snowflake_default"],
table_names=["test_table"],
)
- with pytest.raises(TypeError, match="non-empty string"):
+ with pytest.raises(TypeError, match="must be"):
op.execute(context={})
+ @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
+ @patch.object(LLMSchemaCompareOperator, "_build_schema_context", return_value="mocked schema")
+ def test_execute_accepts_sequence_prompt(self, mock_build_ctx, mock_hook_cls):
+ """A non-empty Sequence[UserContent] return value is forwarded to run_sync as-is."""
+ mock_agent = _make_mock_agent(_make_compare_result())
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
+
+ image = ImageUrl(url="https://example.com/x.png")
+ prompt = ["Compare these schemas:", image]
+
+ def my_prompt_fn():
+ return prompt
+
+ op = _LLMSchemaCompareDecoratedOperator(
+ task_id="test",
+ python_callable=my_prompt_fn,
+ llm_conn_id="llm_conn",
+ db_conn_ids=["postgres_default", "snowflake_default"],
+ table_names=["test_table"],
+ )
+ op.execute(context={})
+
+ assert op.prompt == prompt
+ forwarded_prompt = mock_agent.run_sync.call_args[0][0]
+ assert forwarded_prompt == prompt
+
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
@patch.object(LLMSchemaCompareOperator, "_build_schema_context", return_value="mocked schema")
def test_execute_merges_op_kwargs_into_callable(self, mock_build_ctx, mock_hook_cls):
diff --git a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py
index 244db8d3150..5b2e4b6e6e3 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py
@@ -19,6 +19,7 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
+from pydantic_ai.messages import ImageUrl
from airflow.providers.common.ai.decorators.llm_sql import _LLMSQLDecoratedOperator
@@ -58,19 +59,55 @@ class TestLLMSQLDecoratedOperator:
@pytest.mark.parametrize(
"return_value",
- [42, "", " ", None],
- ids=["non-string", "empty", "whitespace", "none"],
+ [42, "", " ", None, b"bytes", bytearray(b"x"), [], ()],
+ ids=["non-string", "empty", "whitespace", "none", "bytes", "bytearray", "empty-list", "empty-tuple"],
)
def test_execute_raises_on_invalid_prompt(self, return_value):
- """TypeError when the callable returns a non-string or blank string."""
+ """TypeError when the callable returns an unsupported prompt shape."""
op = _LLMSQLDecoratedOperator(
task_id="test",
python_callable=lambda: return_value,
llm_conn_id="my_llm",
)
- with pytest.raises(TypeError, match="non-empty string"):
+ with pytest.raises(TypeError, match="must be"):
op.execute(context={})
+ @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
+ def test_execute_accepts_sequence_prompt(self, mock_hook_cls):
+ """A non-empty Sequence[UserContent] return value is forwarded to run_sync as-is."""
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_agent.run_sync.return_value = _make_mock_run_result("SELECT 1")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
+
+ image = ImageUrl(url="https://example.com/x.png")
+ prompt = ["Write SQL for this diagram:", image]
+
+ def my_prompt_fn():
+ return prompt
+
+ op = _LLMSQLDecoratedOperator(task_id="test", python_callable=my_prompt_fn, llm_conn_id="my_llm")
+ op.execute(context={})
+
+ assert op.prompt == prompt
+ mock_agent.run_sync.assert_called_once_with(prompt, usage_limits=None)
+
+ @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
+ def test_sequence_prompt_with_require_approval_raises_before_run_sync(self, mock_hook_cls):
+ """Sequence prompt + require_approval=True fails before the agent runs."""
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
+
+ op = _LLMSQLDecoratedOperator(
+ task_id="test",
+ python_callable=lambda: ["x", ImageUrl(url="https://example.com/x.png")],
+ llm_conn_id="my_llm",
+ require_approval=True,
+ )
+ with pytest.raises(TypeError, match="require_approval=True"):
+ op.execute(context={})
+
+ mock_agent.run_sync.assert_not_called()
+
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
def test_execute_merges_op_kwargs_into_callable(self, mock_hook_cls):
"""op_kwargs are resolved by the callable to build the prompt."""
diff --git a/providers/common/ai/tests/unit/common/ai/mixins/test_approval.py b/providers/common/ai/tests/unit/common/ai/mixins/test_approval.py
index 08355c3b922..464dfe38986 100644
--- a/providers/common/ai/tests/unit/common/ai/mixins/test_approval.py
+++ b/providers/common/ai/tests/unit/common/ai/mixins/test_approval.py
@@ -142,6 +142,11 @@ class TestDeferForApproval:
call_kwargs = mock_upsert.call_args[1]
assert call_kwargs["params"] == {}
+ def test_raises_on_non_string_prompt(self, context):
+ op = FakeOperator(prompt=["Describe this:", object()]) # type: ignore[arg-type]
+ with pytest.raises(TypeError, match="non-string prompt"):
+ op.defer_for_approval(context, "output")
+
@patch(UTCNOW_PATH)
@patch(HITL_TRIGGER_PATH, autospec=True)
@patch(UPSERT_HITL_PATH)
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
index c817e12bd9b..5651f6c6393 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
@@ -503,3 +503,30 @@ class TestAgentOperatorDurable:
# run_sync called directly, no override
mock_agent.run_sync.assert_called_once_with("test", usage_limits=None)
+
+
+@pytest.mark.skipif(
+ not AIRFLOW_V_3_1_PLUS, reason="Human in the loop is only compatible with Airflow >= 3.1.0"
+)
+class TestAgentOperatorMultimodalPromptGuard:
+ """AgentOperator.execute raises before agent.run_sync when enable_hitl_review=True
+ and self.prompt is not a string -- covering direct construction and the native
+ template rendering escape (where a string template renders to a Sequence)."""
+
+ @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+ def test_execute_rejects_sequence_prompt_with_hitl_review(self, mock_hook_cls):
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
+
+ op = AgentOperator(
+ task_id="t",
+ prompt="placeholder",
+ llm_conn_id="c",
+ enable_hitl_review=True,
+ )
+ op.prompt = ["x", object()] # simulate post-template-render value
+
+ with pytest.raises(TypeError, match="enable_hitl_review=True"):
+ op.execute(context=MagicMock())
+
+ mock_agent.run_sync.assert_not_called()
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm.py b/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
index 165a2c30b7a..076b86250dd 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
@@ -297,3 +297,30 @@ class TestLLMOperatorApproval:
result = op.execute_complete({}, generated_output="original", event=event)
assert result == "edited"
+
+
+@pytest.mark.skipif(
+ not AIRFLOW_V_3_1_PLUS, reason="Human in the loop is only compatible with Airflow >= 3.1.0"
+)
+class TestLLMOperatorMultimodalPromptGuard:
+ """LLMOperator.execute raises before agent.run_sync when require_approval is True
+ and self.prompt is not a string -- covering direct-operator construction and the
+ native template rendering escape (where a string template renders to a Sequence)."""
+
+ @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
+ def test_execute_rejects_sequence_prompt_with_require_approval(self, mock_hook_cls):
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
+
+ op = LLMOperator(
+ task_id="t",
+ prompt="placeholder",
+ llm_conn_id="c",
+ require_approval=True,
+ )
+ op.prompt = ["x", object()] # simulate post-template-render value
+
+ with pytest.raises(TypeError, match="require_approval=True"):
+ op.execute(context=_make_context())
+
+ mock_agent.run_sync.assert_not_called()
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
index 96719e2662e..e8e31c6f5de 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
@@ -655,3 +655,29 @@ class TestLLMSQLQueryOperatorApproval:
result = op.execute_complete({}, generated_output="SELECT 1", event=event)
assert result == "SELECT 1"
+
+
+@pytest.mark.skipif(
+ not AIRFLOW_V_3_1_PLUS, reason="Human in the loop is only compatible with Airflow >= 3.1.0"
+)
+class TestLLMSQLQueryOperatorMultimodalPromptGuard:
+ """LLMSQLQueryOperator.execute raises before agent.run_sync when require_approval=True
+ and self.prompt is not a string."""
+
+ @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
+ def test_execute_rejects_sequence_prompt_with_require_approval(self, mock_hook_cls):
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
+
+ op = LLMSQLQueryOperator(
+ task_id="t",
+ prompt="placeholder",
+ llm_conn_id="c",
+ require_approval=True,
+ )
+ op.prompt = ["x", object()]
+
+ with pytest.raises(TypeError, match="require_approval=True"):
+ op.execute(context=_make_context())
+
+ mock_agent.run_sync.assert_not_called()
diff --git a/providers/common/ai/tests/unit/common/ai/utils/test_validation.py b/providers/common/ai/tests/unit/common/ai/utils/test_validation.py
new file mode 100644
index 00000000000..2a927415ec1
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/utils/test_validation.py
@@ -0,0 +1,129 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.common.ai.utils.validation import (
+ reject_sequence_with_unsupported_feature,
+ validate_prompt,
+)
+
+
+class TestValidatePrompt:
+ @pytest.mark.parametrize(
+ "value",
+ ["hello", " hello "],
+ ids=["plain", "with-surrounding-whitespace"],
+ )
+ def test_accepts_non_empty_strings(self, value):
+ validate_prompt(value, decorator_name="@task.agent")
+
+ @pytest.mark.parametrize(
+ "value",
+ ["", " ", "\n\t"],
+ ids=["empty", "spaces", "newlines-tabs"],
+ )
+ def test_rejects_blank_strings(self, value):
+ with pytest.raises(TypeError, match="non-empty string or a non-empty Sequence"):
+ validate_prompt(value, decorator_name="@task.agent")
+
+ @pytest.mark.parametrize(
+ "value",
+ [
+ ["text", object()],
+ ("text", object()),
+ [object()],
+ ],
+ ids=["list-multi", "tuple-multi", "list-single"],
+ )
+ def test_accepts_non_empty_sequence(self, value):
+ validate_prompt(value, decorator_name="@task.agent")
+
+ @pytest.mark.parametrize(
+ "value",
+ [[], ()],
+ ids=["empty-list", "empty-tuple"],
+ )
+ def test_rejects_empty_sequence(self, value):
+ with pytest.raises(TypeError, match="non-empty string or a non-empty Sequence"):
+ validate_prompt(value, decorator_name="@task.agent")
+
+ @pytest.mark.parametrize(
+ ("value", "match"),
+ [
+ (None, "got NoneType"),
+ (42, "got int"),
+ ({"key": "value"}, "got dict"),
+ ],
+ ids=["none", "int", "dict"],
+ )
+ def test_rejects_unsupported_scalar(self, value, match):
+ with pytest.raises(TypeError, match=match):
+ validate_prompt(value, decorator_name="@task.agent")
+
+ @pytest.mark.parametrize(
+ ("value", "name"),
+ [
+ (b"bytes", "bytes"),
+ (bytearray(b"x"), "bytearray"),
+ ],
+ ids=["bytes", "bytearray"],
+ )
+ def test_rejects_bytes_like(self, value, name):
+ with pytest.raises(TypeError, match=f"not {name}"):
+ validate_prompt(value, decorator_name="@task.agent")
+
+ def test_decorator_name_appears_in_error(self):
+ with pytest.raises(TypeError, match=r"@task\.llm_sql"):
+ validate_prompt(42, decorator_name="@task.llm_sql")
+
+ @pytest.mark.parametrize(
+ "value",
+ [[b"x"], ["ok", bytearray(b"y")]],
+ ids=["bytes-item", "bytearray-mixed"],
+ )
+ def test_rejects_bytes_like_in_sequence(self, value):
+ with pytest.raises(TypeError, match="raw bytes are not a valid UserContent"):
+ validate_prompt(value, decorator_name="@task.agent")
+
+
+class TestRejectSequenceWithUnsupportedFeature:
+ def test_noop_when_feature_disabled(self):
+ reject_sequence_with_unsupported_feature(
+ ["x", object()],
+ decorator_name="@task.agent",
+ feature_name="enable_hitl_review",
+ feature_enabled=False,
+ )
+
+ def test_noop_when_value_is_string(self):
+ reject_sequence_with_unsupported_feature(
+ "hello",
+ decorator_name="@task.agent",
+ feature_name="enable_hitl_review",
+ feature_enabled=True,
+ )
+
+ def test_raises_for_sequence_with_feature_enabled(self):
+ with pytest.raises(TypeError, match="enable_hitl_review=True"):
+ reject_sequence_with_unsupported_feature(
+ ["x", object()],
+ decorator_name="@task.agent",
+ feature_name="enable_hitl_review",
+ feature_enabled=True,
+ )