This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 621792df82b Add `LLMBranchOperator` and `@task.llm_branch` to
`common.ai` provider (#62740)
621792df82b is described below
commit 621792df82b69c20cc21338fa818a3c358ec3154
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Mar 3 00:16:59 2026 +0000
Add `LLMBranchOperator` and `@task.llm_branch` to `common.ai` provider
(#62740)
- Add type: ignore[misc] for dynamic Enum() construction (mypy
requires a literal second arg)
- Add explicit type annotation for branches variable to avoid
incompatible assignment error
- Match do_branch return type (str | Iterable[str] | None)
---
providers/common/ai/docs/index.rst | 1 +
providers/common/ai/docs/operators/llm_branch.rst | 97 ++++++++++++
providers/common/ai/provider.yaml | 4 +
providers/common/ai/pyproject.toml | 4 +
.../providers/common/ai/decorators/llm_branch.py | 135 +++++++++++++++++
.../common/ai/example_dags/example_llm_branch.py | 152 +++++++++++++++++++
.../providers/common/ai/get_provider_info.py | 6 +
.../providers/common/ai/operators/llm_branch.py | 94 ++++++++++++
.../unit/common/ai/decorators/test_llm_branch.py | 102 +++++++++++++
.../unit/common/ai/operators/test_llm_branch.py | 162 +++++++++++++++++++++
10 files changed, 757 insertions(+)
diff --git a/providers/common/ai/docs/index.rst
b/providers/common/ai/docs/index.rst
index fffcfc494ba..06c600a7805 100644
--- a/providers/common/ai/docs/index.rst
+++ b/providers/common/ai/docs/index.rst
@@ -122,6 +122,7 @@ Dependent package
==================================================================================================================
=================
`apache-airflow-providers-common-compat
<https://airflow.apache.org/docs/apache-airflow-providers-common-compat>`_
``common.compat``
`apache-airflow-providers-common-sql
<https://airflow.apache.org/docs/apache-airflow-providers-common-sql>`_
``common.sql``
+`apache-airflow-providers-standard
<https://airflow.apache.org/docs/apache-airflow-providers-standard>`_
``standard``
==================================================================================================================
=================
Downloading official packages
diff --git a/providers/common/ai/docs/operators/llm_branch.rst
b/providers/common/ai/docs/operators/llm_branch.rst
new file mode 100644
index 00000000000..9d1bc059a5e
--- /dev/null
+++ b/providers/common/ai/docs/operators/llm_branch.rst
@@ -0,0 +1,97 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+.. _howto/operator:llm_branch:
+
+``LLMBranchOperator``
+=====================
+
+Use :class:`~airflow.providers.common.ai.operators.llm_branch.LLMBranchOperator`
+for LLM-driven branching — where the LLM decides which downstream task(s) to
+execute.
+
+The operator discovers downstream tasks automatically from the DAG topology
+and presents them to the LLM as a constrained enum via pydantic-ai structured
+output. No text parsing or manual validation is needed.
+
+.. seealso::
+ :ref:`Connection configuration <howto/connection:pydantic_ai>`
+
+Basic Usage
+-----------
+
+Connect the operator to downstream tasks. The LLM chooses which branch to
+execute based on the prompt:
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_branch.py
+ :language: python
+ :start-after: [START howto_operator_llm_branch_basic]
+ :end-before: [END howto_operator_llm_branch_basic]
+
+Multiple Branches
+-----------------
+
+Set ``allow_multiple_branches=True`` to let the LLM select more than one
+downstream task. All selected branches run; unselected branches are skipped:
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_branch.py
+ :language: python
+ :start-after: [START howto_operator_llm_branch_multi]
+ :end-before: [END howto_operator_llm_branch_multi]
+
+TaskFlow Decorator
+------------------
+
+The ``@task.llm_branch`` decorator wraps ``LLMBranchOperator``. The function
+returns the prompt string; all other parameters are passed to the operator:
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_branch.py
+ :language: python
+ :start-after: [START howto_decorator_llm_branch]
+ :end-before: [END howto_decorator_llm_branch]
+
+With multiple branches:
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_branch.py
+ :language: python
+ :start-after: [START howto_decorator_llm_branch_multi]
+ :end-before: [END howto_decorator_llm_branch_multi]
+
+How It Works
+------------
+
+At execution time, the operator:
+
+1. Reads ``self.downstream_task_ids`` from the DAG topology.
+2. Creates a dynamic ``Enum`` with one member per downstream task ID.
+3. Passes that enum as ``output_type`` to ``pydantic-ai``, constraining the LLM to
+ valid task IDs only.
+4. Converts the LLM's structured output to task ID string(s) and calls
+ ``do_branch()`` to skip non-selected downstream tasks.
+
+Parameters
+----------
+
+- ``prompt``: The prompt to send to the LLM (operator) or the return value of the
+ decorated function (decorator).
+- ``llm_conn_id``: Airflow connection ID for the LLM provider.
+- ``model_id``: Model identifier (e.g. ``"openai:gpt-5"``). Overrides the connection's extra field.
+- ``system_prompt``: System-level instructions for the agent. Supports Jinja templating.
+- ``allow_multiple_branches``: When ``False`` (default) the LLM returns a single
+ task ID. When ``True`` the LLM may return one or more task IDs.
+- ``agent_params``: Additional keyword arguments passed to the pydantic-ai ``Agent``
+ constructor (e.g. ``retries``, ``model_settings``). Supports Jinja templating.
diff --git a/providers/common/ai/provider.yaml
b/providers/common/ai/provider.yaml
index 7ef0acd7c3c..7e51945470c 100644
--- a/providers/common/ai/provider.yaml
+++ b/providers/common/ai/provider.yaml
@@ -33,6 +33,7 @@ integrations:
external-doc-url:
https://airflow.apache.org/docs/apache-airflow-providers-common-ai/
how-to-guide:
- /docs/apache-airflow-providers-common-ai/operators/llm.rst
+ - /docs/apache-airflow-providers-common-ai/operators/llm_branch.rst
- /docs/apache-airflow-providers-common-ai/operators/llm_sql.rst
tags: [ai]
- integration-name: Pydantic AI
@@ -62,10 +63,13 @@ operators:
- integration-name: Common AI
python-modules:
- airflow.providers.common.ai.operators.llm
+ - airflow.providers.common.ai.operators.llm_branch
- airflow.providers.common.ai.operators.llm_sql
task-decorators:
- class-name: airflow.providers.common.ai.decorators.llm.llm_task
name: llm
+ - class-name:
airflow.providers.common.ai.decorators.llm_branch.llm_branch_task
+ name: llm_branch
- class-name: airflow.providers.common.ai.decorators.llm_sql.llm_sql_task
name: llm_sql
diff --git a/providers/common/ai/pyproject.toml
b/providers/common/ai/pyproject.toml
index b8726b0248b..1770af6848b 100644
--- a/providers/common/ai/pyproject.toml
+++ b/providers/common/ai/pyproject.toml
@@ -79,6 +79,9 @@ dependencies = [
"common.sql" = [
"apache-airflow-providers-common-sql"
]
+"standard" = [
+ "apache-airflow-providers-standard"
+]
[dependency-groups]
dev = [
@@ -87,6 +90,7 @@ dev = [
"apache-airflow-devel-common",
"apache-airflow-providers-common-compat",
"apache-airflow-providers-common-sql",
+ "apache-airflow-providers-standard",
# Additional devel dependencies (do not remove this line and add extra
development dependencies)
"sqlglot>=26.0.0",
"apache-airflow-providers-common-sql[datafusion]"
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_branch.py
b/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_branch.py
new file mode 100644
index 00000000000..2dc9194638a
--- /dev/null
+++
b/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_branch.py
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+TaskFlow decorator for LLM-driven branching.
+
+The user writes a function that **returns the prompt string**. The decorator
+discovers downstream tasks from the DAG topology and asks the LLM to choose
+which branch(es) to execute using pydantic-ai structured output.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Collection, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, ClassVar
+
+from airflow.providers.common.ai.operators.llm_branch import LLMBranchOperator
+from airflow.providers.common.compat.sdk import (
+ DecoratedOperator,
+ TaskDecorator,
+ context_merge,
+ task_decorator_factory,
+)
+from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION
+from airflow.utils.operator_helpers import determine_kwargs
+
+if TYPE_CHECKING:
+ from airflow.sdk import Context
+
+
+class _LLMBranchDecoratedOperator(DecoratedOperator, LLMBranchOperator):
+ """
+ Wraps a callable that returns a prompt for LLM-driven branching.
+
+ The user function is called at execution time to produce the prompt string.
+ All other parameters (``llm_conn_id``, ``system_prompt``,
``allow_multiple_branches``,
+ etc.) are passed through to
+
:class:`~airflow.providers.common.ai.operators.llm_branch.LLMBranchOperator`.
+
+ :param python_callable: A reference to a callable that returns the prompt
string.
+ :param op_args: Positional arguments for the callable.
+ :param op_kwargs: Keyword arguments for the callable.
+ """
+
+ template_fields: Sequence[str] = (
+ *DecoratedOperator.template_fields,
+ *LLMBranchOperator.template_fields,
+ )
+ template_fields_renderers: ClassVar[dict[str, str]] = {
+ **DecoratedOperator.template_fields_renderers,
+ }
+
+ custom_operator_name: str = "@task.llm_branch"
+
+ def __init__(
+ self,
+ *,
+ python_callable: Callable,
+ op_args: Collection[Any] | None = None,
+ op_kwargs: Mapping[str, Any] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ python_callable=python_callable,
+ op_args=op_args,
+ op_kwargs=op_kwargs,
+ prompt=SET_DURING_EXECUTION,
+ **kwargs,
+ )
+
+ def execute(self, context: Context) -> Any:
+ context_merge(context, self.op_kwargs)
+ kwargs = determine_kwargs(self.python_callable, self.op_args, context)
+
+ self.prompt = self.python_callable(*self.op_args, **kwargs)
+
+ if not isinstance(self.prompt, str) or not self.prompt.strip():
+ raise TypeError(
+ "The returned value from the @task.llm_branch callable must be
a non-empty string."
+ )
+
+ self.render_template_fields(context)
+ return LLMBranchOperator.execute(self, context)
+
+
+def llm_branch_task(
+ python_callable: Callable | None = None,
+ **kwargs,
+) -> TaskDecorator:
+ """
+ Wrap a function that returns a prompt into an LLM-driven branching task.
+
+ The function body constructs the prompt. The decorator discovers downstream
+ tasks from the DAG topology and asks the LLM to choose which branch(es)
+ to execute.
+
+ Usage::
+
+ @task.llm_branch(
+ llm_conn_id="openai_default",
+ system_prompt="Route support tickets to the right team.",
+ )
+ def route_ticket(message: str):
+ return f"Route this ticket: {message}"
+
+ With multiple branches::
+
+ @task.llm_branch(
+ llm_conn_id="openai_default",
+ system_prompt="Select all applicable categories.",
+ allow_multiple_branches=True,
+ )
+ def classify(text: str):
+ return f"Classify this text: {text}"
+
+ :param python_callable: Function to decorate.
+ """
+ return task_decorator_factory(
+ python_callable=python_callable,
+ decorated_operator_class=_LLMBranchDecoratedOperator,
+ **kwargs,
+ )
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_branch.py
b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_branch.py
new file mode 100644
index 00000000000..c76b68999e7
--- /dev/null
+++
b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_branch.py
@@ -0,0 +1,152 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Example DAGs demonstrating LLMBranchOperator and @task.llm_branch usage."""
+
+from __future__ import annotations
+
+from airflow.providers.common.ai.operators.llm_branch import LLMBranchOperator
+from airflow.providers.common.compat.sdk import dag, task
+
+
+# [START howto_operator_llm_branch_basic]
+@dag
+def example_llm_branch_operator():
+ route = LLMBranchOperator(
+ task_id="route_ticket",
+ prompt="User says: 'My password reset email never arrived.'",
+ llm_conn_id="pydantic_ai_default",
+ system_prompt="Route support tickets to the right team.",
+ )
+
+ @task
+ def handle_billing():
+ return "Handling billing issue"
+
+ @task
+ def handle_auth():
+ return "Handling auth issue"
+
+ @task
+ def handle_general():
+ return "Handling general issue"
+
+ route >> [handle_billing(), handle_auth(), handle_general()]
+
+
+# [END howto_operator_llm_branch_basic]
+
+example_llm_branch_operator()
+
+
+# [START howto_operator_llm_branch_multi]
+@dag
+def example_llm_branch_multi():
+ route = LLMBranchOperator(
+ task_id="classify",
+ prompt="This product is great but shipping was slow and the box was
damaged.",
+ llm_conn_id="pydantic_ai_default",
+ system_prompt="Select all applicable categories for this customer
review.",
+ allow_multiple_branches=True,
+ )
+
+ @task
+ def handle_positive():
+ return "Processing positive feedback"
+
+ @task
+ def handle_shipping():
+ return "Escalating shipping issue"
+
+ @task
+ def handle_packaging():
+ return "Escalating packaging issue"
+
+ route >> [handle_positive(), handle_shipping(), handle_packaging()]
+
+
+# [END howto_operator_llm_branch_multi]
+
+example_llm_branch_multi()
+
+
+# [START howto_decorator_llm_branch]
+@dag
+def example_llm_branch_decorator():
+ @task.llm_branch(
+ llm_conn_id="pydantic_ai_default",
+ system_prompt="Route support tickets to the right team.",
+ )
+ def route_ticket(message: str):
+ return f"Route this support ticket: {message}"
+
+ @task
+ def handle_billing():
+ return "Handling billing issue"
+
+ @task
+ def handle_auth():
+ return "Handling auth issue"
+
+ @task
+ def handle_general():
+ return "Handling general issue"
+
+ route_ticket("I was charged twice for my subscription.") >> [
+ handle_billing(),
+ handle_auth(),
+ handle_general(),
+ ]
+
+
+# [END howto_decorator_llm_branch]
+
+example_llm_branch_decorator()
+
+
+# [START howto_decorator_llm_branch_multi]
+@dag
+def example_llm_branch_decorator_multi():
+ @task.llm_branch(
+ llm_conn_id="pydantic_ai_default",
+ system_prompt="Select all applicable categories for this customer
review.",
+ allow_multiple_branches=True,
+ )
+ def classify_review(review: str):
+ return f"Classify this review: {review}"
+
+ @task
+ def handle_positive():
+ return "Processing positive feedback"
+
+ @task
+ def handle_shipping():
+ return "Escalating shipping issue"
+
+ @task
+ def handle_packaging():
+ return "Escalating packaging issue"
+
+ classify_review("Great product but shipping was slow.") >> [
+ handle_positive(),
+ handle_shipping(),
+ handle_packaging(),
+ ]
+
+
+# [END howto_decorator_llm_branch_multi]
+
+example_llm_branch_decorator_multi()
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
index 091f4b849ac..26d285e6a70 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
@@ -32,6 +32,7 @@ def get_provider_info():
"external-doc-url":
"https://airflow.apache.org/docs/apache-airflow-providers-common-ai/",
"how-to-guide": [
"/docs/apache-airflow-providers-common-ai/operators/llm.rst",
+
"/docs/apache-airflow-providers-common-ai/operators/llm_branch.rst",
"/docs/apache-airflow-providers-common-ai/operators/llm_sql.rst",
],
"tags": ["ai"],
@@ -67,12 +68,17 @@ def get_provider_info():
"integration-name": "Common AI",
"python-modules": [
"airflow.providers.common.ai.operators.llm",
+ "airflow.providers.common.ai.operators.llm_branch",
"airflow.providers.common.ai.operators.llm_sql",
],
}
],
"task-decorators": [
{"class-name":
"airflow.providers.common.ai.decorators.llm.llm_task", "name": "llm"},
+ {
+ "class-name":
"airflow.providers.common.ai.decorators.llm_branch.llm_branch_task",
+ "name": "llm_branch",
+ },
{"class-name":
"airflow.providers.common.ai.decorators.llm_sql.llm_sql_task", "name":
"llm_sql"},
],
}
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_branch.py
b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_branch.py
new file mode 100644
index 00000000000..b7f3028ec92
--- /dev/null
+++
b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_branch.py
@@ -0,0 +1,94 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""LLM-driven branching operator."""
+
+from __future__ import annotations
+
+from collections.abc import Iterable, Sequence
+from enum import Enum
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.common.ai.operators.llm import LLMOperator
+from airflow.providers.standard.operators.branch import BranchMixIn
+
+if TYPE_CHECKING:
+ from airflow.sdk import Context
+
+
+class LLMBranchOperator(LLMOperator, BranchMixIn):
+ """
+ Ask an LLM to choose which downstream task(s) to execute.
+
+ Downstream task IDs are discovered automatically from the DAG topology
+ and presented to the LLM as a constrained enum via pydantic-ai structured
+ output. No text parsing or manual validation is needed.
+
+ :param prompt: The prompt to send to the LLM.
+ :param llm_conn_id: Connection ID for the LLM provider.
+ :param model_id: Model identifier (e.g. ``"openai:gpt-5"``).
+ Overrides the model stored in the connection's extra field.
+ :param system_prompt: System-level instructions for the LLM agent.
+ :param allow_multiple_branches: When ``False`` (default) the LLM returns a
+ single task ID. When ``True`` the LLM may return one or more task IDs.
+ :param agent_params: Additional keyword arguments passed to the pydantic-ai
+ ``Agent`` constructor (e.g. ``retries``, ``model_settings``,
``tools``).
+ """
+
+ inherits_from_skipmixin = True
+
+ template_fields: Sequence[str] = LLMOperator.template_fields
+
+ def __init__(
+ self,
+ *,
+ allow_multiple_branches: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ kwargs.pop("output_type", None)
+ super().__init__(**kwargs)
+ self.allow_multiple_branches = allow_multiple_branches
+
+ def execute(self, context: Context) -> str | Iterable[str] | None:
+ if not self.downstream_task_ids:
+ raise ValueError(
+ f"{self.task_id!r} has no downstream tasks. "
+ "LLMBranchOperator requires at least one downstream task to
branch into."
+ )
+
+ downstream_tasks_enum = Enum( # type: ignore[misc]
+ "DownstreamTasks",
+ {task_id: task_id for task_id in self.downstream_task_ids},
+ )
+ output_type = list[downstream_tasks_enum] if
self.allow_multiple_branches else downstream_tasks_enum
+
+ agent = self.llm_hook.create_agent(
+ output_type=output_type,
+ instructions=self.system_prompt,
+ **self.agent_params,
+ )
+ result = agent.run_sync(self.prompt)
+ output = result.output
+
+ branches: str | list[str]
+ if isinstance(output, list):
+ branches = [item.value for item in output]
+ elif isinstance(output, Enum):
+ branches = output.value
+ else:
+ branches = str(output)
+
+ return self.do_branch(context, branches)
diff --git
a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
new file mode 100644
index 00000000000..66620426a3b
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from enum import Enum
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.common.ai.decorators.llm_branch import
_LLMBranchDecoratedOperator
+from airflow.providers.common.ai.operators.llm_branch import LLMBranchOperator
+
+
+class TestLLMBranchDecoratedOperator:
+ def test_custom_operator_name(self):
+ assert _LLMBranchDecoratedOperator.custom_operator_name ==
"@task.llm_branch"
+
+ @patch.object(LLMBranchOperator, "do_branch")
+ @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
+ def test_execute_calls_callable_and_branches(self, mock_hook_cls,
mock_do_branch):
+ """The callable's return value becomes the LLM prompt, LLM output goes
through do_branch."""
+ downstream_enum = Enum("DownstreamTasks", {"positive": "positive",
"negative": "negative"})
+
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_result = MagicMock(spec=["output"])
+ mock_result.output = downstream_enum.positive
+ mock_agent.run_sync.return_value = mock_result
+ mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_do_branch.return_value = "positive"
+
+ def my_prompt():
+ return "Route this review"
+
+ op = _LLMBranchDecoratedOperator(
+ task_id="test",
+ python_callable=my_prompt,
+ llm_conn_id="my_llm",
+ )
+ op.downstream_task_ids = {"positive", "negative"}
+
+ result = op.execute(context={})
+
+ assert result == "positive"
+ assert op.prompt == "Route this review"
+ mock_agent.run_sync.assert_called_once_with("Route this review")
+ mock_do_branch.assert_called_once()
+
+ @pytest.mark.parametrize(
+ "return_value",
+ [42, "", " ", None],
+ ids=["non-string", "empty", "whitespace", "none"],
+ )
+ def test_execute_raises_on_invalid_prompt(self, return_value):
+ """TypeError when the callable returns a non-string or blank string."""
+ op = _LLMBranchDecoratedOperator(
+ task_id="test",
+ python_callable=lambda: return_value,
+ llm_conn_id="my_llm",
+ )
+ with pytest.raises(TypeError, match="non-empty string"):
+ op.execute(context={})
+
+ @patch.object(LLMBranchOperator, "do_branch")
+ @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
+ def test_execute_merges_op_kwargs_into_callable(self, mock_hook_cls,
mock_do_branch):
+ """op_kwargs are resolved by the callable to build the prompt."""
+ downstream_enum = Enum("DownstreamTasks", {"task_a": "task_a"})
+
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_result = MagicMock(spec=["output"])
+ mock_result.output = downstream_enum.task_a
+ mock_agent.run_sync.return_value = mock_result
+ mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+ def my_prompt(ticket_type):
+ return f"Route this {ticket_type} ticket"
+
+ op = _LLMBranchDecoratedOperator(
+ task_id="test",
+ python_callable=my_prompt,
+ llm_conn_id="my_llm",
+ op_kwargs={"ticket_type": "billing"},
+ )
+ op.downstream_task_ids = {"task_a"}
+
+ op.execute(context={"task_instance": MagicMock()})
+
+ assert op.prompt == "Route this billing ticket"
diff --git
a/providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py
b/providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py
new file mode 100644
index 00000000000..d94fc552178
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py
@@ -0,0 +1,162 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from enum import Enum
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.common.ai.operators.llm import LLMOperator
+from airflow.providers.common.ai.operators.llm_branch import LLMBranchOperator
+
+
+class TestLLMBranchOperator:
+ def test_inherits_from_skipmixin_is_true(self):
+ assert LLMBranchOperator.inherits_from_skipmixin is True
+
+ def test_template_fields(self):
+ assert set(LLMBranchOperator.template_fields) ==
set(LLMOperator.template_fields)
+
+ def test_output_type_ignored(self):
+ """Passing output_type= doesn't break anything; it's silently
dropped."""
+ op = LLMBranchOperator(
+ task_id="test",
+ prompt="pick a branch",
+ llm_conn_id="my_llm",
+ output_type=int,
+ )
+ # output_type is overridden to str (the LLMOperator default) since
+ # the real output_type is built dynamically from downstream_task_ids
+ assert op.output_type is str
+
+ @patch.object(LLMBranchOperator, "do_branch")
+ @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
+ def test_execute_single_branch(self, mock_hook_cls, mock_do_branch):
+ """LLM returns a single enum member → do_branch receives a string."""
+ downstream_enum = Enum("DownstreamTasks", {"task_a": "task_a",
"task_b": "task_b"})
+
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_result = MagicMock(spec=["output"])
+ mock_result.output = downstream_enum.task_a
+ mock_agent.run_sync.return_value = mock_result
+ mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_do_branch.return_value = "task_a"
+
+ op = LLMBranchOperator(
+ task_id="test",
+ prompt="Pick a branch",
+ llm_conn_id="my_llm",
+ )
+ op.downstream_task_ids = {"task_a", "task_b"}
+
+ ctx = MagicMock()
+ result = op.execute(ctx)
+
+ assert result == "task_a"
+ mock_do_branch.assert_called_once_with(ctx, "task_a")
+ mock_agent.run_sync.assert_called_once_with("Pick a branch")
+
+ @patch.object(LLMBranchOperator, "do_branch")
+ @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
+ def test_execute_multi_branch(self, mock_hook_cls, mock_do_branch):
+ """allow_multiple_branches=True → LLM returns list of enums →
do_branch receives list."""
+ downstream_enum = Enum(
+ "DownstreamTasks", {"task_a": "task_a", "task_b": "task_b",
"task_c": "task_c"}
+ )
+
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_result = MagicMock(spec=["output"])
+ mock_result.output = [downstream_enum.task_a, downstream_enum.task_c]
+ mock_agent.run_sync.return_value = mock_result
+ mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_do_branch.return_value = ["task_a", "task_c"]
+
+ op = LLMBranchOperator(
+ task_id="test",
+ prompt="Pick branches",
+ llm_conn_id="my_llm",
+ allow_multiple_branches=True,
+ )
+ op.downstream_task_ids = {"task_a", "task_b", "task_c"}
+
+ ctx = MagicMock()
+ result = op.execute(ctx)
+
+ assert result == ["task_a", "task_c"]
+ mock_do_branch.assert_called_once_with(ctx, ["task_a", "task_c"])
+
+ @patch.object(LLMBranchOperator, "do_branch")
+ @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
+ def test_system_prompt_forwarded(self, mock_hook_cls, mock_do_branch):
+ """system_prompt is passed to create_agent(instructions=...)."""
+ downstream_enum = Enum("DownstreamTasks", {"task_a": "task_a"})
+
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_result = MagicMock(spec=["output"])
+ mock_result.output = downstream_enum.task_a
+ mock_agent.run_sync.return_value = mock_result
+ mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+ op = LLMBranchOperator(
+ task_id="test",
+ prompt="Pick",
+ llm_conn_id="my_llm",
+ system_prompt="Route tickets to the right team.",
+ )
+ op.downstream_task_ids = {"task_a"}
+
+ op.execute(MagicMock())
+
+ call_kwargs = mock_hook_cls.return_value.create_agent.call_args
+ assert call_kwargs.kwargs["instructions"] == "Route tickets to the
right team."
+
+ @patch.object(LLMBranchOperator, "do_branch")
+ @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
+ def test_downstream_task_ids_used_for_enum(self, mock_hook_cls,
mock_do_branch):
+ """The dynamic enum is built from self.downstream_task_ids."""
+ downstream_enum = Enum(
+ "DownstreamTasks", {"billing": "billing", "auth": "auth",
"general": "general"}
+ )
+
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_result = MagicMock(spec=["output"])
+ mock_result.output = downstream_enum.billing
+ mock_agent.run_sync.return_value = mock_result
+ mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+ op = LLMBranchOperator(
+ task_id="test",
+ prompt="Pick",
+ llm_conn_id="my_llm",
+ )
+ op.downstream_task_ids = {"billing", "auth", "general"}
+
+ op.execute(MagicMock())
+
+ output_type =
mock_hook_cls.return_value.create_agent.call_args.kwargs["output_type"]
+ assert {m.value for m in output_type} == {"billing", "auth", "general"}
+
+ def test_execute_raises_on_no_downstream_tasks(self):
+ """ValueError when the operator has no downstream tasks."""
+ op = LLMBranchOperator(
+ task_id="test",
+ prompt="Pick",
+ llm_conn_id="my_llm",
+ )
+ with pytest.raises(ValueError, match="no downstream tasks"):
+ op.execute(MagicMock())