kaxil commented on code in PR #63081:
URL: https://github.com/apache/airflow/pull/63081#discussion_r2908379205


##########
providers/common/ai/src/airflow/providers/common/ai/mixins/hitl_review.py:
##########
@@ -0,0 +1,262 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+import time
+from datetime import timedelta
+from typing import TYPE_CHECKING, Any, Protocol
+
+from pydantic import BaseModel
+
+from airflow.providers.common.ai.exceptions import HITLMaxIterationsError
+from airflow.providers.common.ai.utils.hitl_review import (
+    XCOM_AGENT_OUTPUT_PREFIX,
+    XCOM_AGENT_SESSION,
+    XCOM_HUMAN_ACTION,
+    AgentSessionData,
+    HumanActionData,
+    SessionStatus,
+)
+
+log = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from airflow.sdk import Context
+
+
+class HITLReviewProtocol(Protocol):
+    """Attributes that the host operator must provide."""
+
+    enable_hitl_review: bool
+    max_hitl_iterations: int
+    hitl_timeout: timedelta | None
+    hitl_poll_interval: float
+    prompt: str
+    task_id: str
+    log: Any
+
+
+class HITLReviewMixin:
+    """
+    Mixin that drives an iterative HITL review loop inside ``execute()``.
+
+    After the operator generates its first output, the mixin:
+
+    1. Pushes session metadata and the first agent output to XCom.
+    2. Polls the human action XCom (``airflow_hitl_review_human_action``) at 
``hitl_poll_interval`` seconds.
+    3. When a human sets action to ``changes_requested`` (via the plugin API),
+       calls :meth:`regenerate_with_feedback` and pushes the new agent output.
+    4. When a human sets action to ``approved``, returns the output.
+    5. When a human sets action to ``rejected``, raises a `HITLRejectException`
+
+    The loop stops after ``hitl_timeout``.
+
+    All agent outputs and human feedback are persisted as iteration-keyed
+    XCom entries (``airflow_hitl_review_agent_output_1``, 
``airflow_hitl_review_human_feedback_1``, etc.)
+    for full auditability.
+
+    Operators using this mixin must set:
+
+    - ``enable_hitl_review`` (``bool``)
+    - ``hitl_timeout`` (``timedelta | None``)
+    - ``hitl_poll_interval`` (``float``, seconds)
+    - ``prompt`` (``str``)
+
+    And must implement: meth:`regenerate_with_feedback`.
+    """
+
+    def run_hitl_review(
+        self: HITLReviewProtocol,
+        context: Context,
+        output: Any,
+        *,
+        message_history: Any = None,
+    ) -> str:
+        """
+        Execute the full HITL review loop.
+
+        :param context: Airflow task context.
+        :param output: Initial LLM output (str or BaseModel).
+        :param message_history: Provider-specific conversation state (e.g.
+            pydantic-ai ``list[ModelMessage]``).  Passed to
+            :meth:`regenerate_with_feedback` on each iteration.
+        :returns: The final approved output as a string.
+        :raises HITLMaxIterationsError: When max iterations reached without 
approval.
+        :raises HITLRejectException: When the reviewer rejects the output.
+        :raises HITLTimeoutError: When hitl_timeout elapses with no response.
+        """
+        output_str = self._to_string(output)  # type: ignore[attr-defined]
+        ti = context["task_instance"]
+
+        session = AgentSessionData(
+            status=SessionStatus.PENDING_REVIEW,
+            iteration=1,
+            max_iterations=self.max_hitl_iterations,
+            prompt=self.prompt,
+            current_output=output_str,
+        )
+
+        ti.xcom_push(key=XCOM_AGENT_SESSION, 
value=session.model_dump(mode="json"))
+        ti.xcom_push(key=f"{XCOM_AGENT_OUTPUT_PREFIX}1", value=output_str)
+
+        self.log.info(
+            "Feedback session created for %s/%s/%s (poll every %ds).",
+            ti.dag_id,
+            ti.run_id,
+            ti.task_id,
+            self.hitl_poll_interval,
+        )
+
+        deadline = time.monotonic() + self.hitl_timeout.total_seconds() if 
self.hitl_timeout else None
+
+        return self._poll_loop(  # type: ignore[attr-defined]
+            ti=ti,
+            session=session,
+            message_history=message_history,
+            deadline=deadline,
+        )
+
+    def _poll_loop(
+        self: HITLReviewProtocol,
+        *,
+        ti: Any,
+        session: AgentSessionData,
+        message_history: Any,
+        deadline: float | None,
+    ) -> str:
+        """
+        Block until the session reaches a terminal state.
+
+        This loops until the human approves, rejects, or the timeout/max 
iterations is reached.
+        """
+        from airflow.providers.standard.exceptions import (
+            HITLRejectException,
+            HITLTimeoutError,
+        )
+
+        last_seen_iteration = 0
+
+        while True:
+            if deadline is not None and time.monotonic() > deadline:
+                _session_timeout = AgentSessionData(
+                    status=SessionStatus.TIMEOUT_EXCEEDED,
+                    iteration=session.iteration,
+                    max_iterations=session.max_iterations,
+                    prompt=session.prompt,
+                    current_output=session.current_output,
+                )
+                ti.xcom_push(key=XCOM_AGENT_SESSION, 
value=_session_timeout.model_dump(mode="json"))
+                raise HITLTimeoutError("Task exceeded timeout.")
+
+            time.sleep(self.hitl_poll_interval)
+            try:
+                action_raw = ti.xcom_pull(
+                    key=XCOM_HUMAN_ACTION, task_ids=ti.task_id, 
map_indexes=ti.map_index
+                )
+            except Exception:
+                self.log.warning("Failed to pull XCom", exc_info=True)
+                continue
+
+            if action_raw is None:
+                # Human action may take some time to propagate; it must be 
performed in the UI,
+                # after which the plugin updates XCom with this 
XCOM_HUMAN_ACTION. Until then,
+                # continue looping.
+                continue
+
+            try:
+                if isinstance(action_raw, str):
+                    action = HumanActionData.model_validate_json(action_raw)
+                else:
+                    action = HumanActionData.model_validate(action_raw)
+            except Exception:
+                self.log.warning("Malformed human action XCom: %r", action_raw)
+                continue
+
+            if action.iteration <= last_seen_iteration:
+                continue
+
+            last_seen_iteration = action.iteration
+
+            if action.action == "approve":
+                self.log.info("Output approved at iteration %d.", 
session.iteration)
+                return session.current_output
+
+            if action.action == "reject":
+                raise HITLRejectException(f"Output rejected at iteration 
{session.iteration}.")
+
+            if action.action == "changes_requested":
+                feedback_text = action.feedback or ""
+                if not feedback_text:
+                    self.log.info("Empty feedback with 'Request Changes' — 
treating as approve.")
+                    return session.current_output
+
+                self.log.info(
+                    "Feedback received (iteration %d): %s",
+                    session.iteration,
+                    feedback_text,
+                )
+
+                new_output, message_history = self.regenerate_with_feedback(  
# type: ignore[attr-defined]
+                    feedback=feedback_text,

Review Comment:
   The max-iterations check runs after `regenerate_with_feedback` (line 209). 
When the limit is exceeded, the LLM call has already happened and the output is 
discarded.
   
   I see that `current_output=new_output` on line 221 preserves the result in 
XCom before raising -- if that's intentional (so the UI can show the last 
output even after failure), add a comment explaining it. Otherwise, move the 
check before the regeneration:
   ```python
   if session.iteration >= self.max_hitl_iterations:
       # Already at the limit -- don't burn another LLM call
       ...
       raise HITLMaxIterationsError(...)
   
   new_output, message_history = self.regenerate_with_feedback(...)
   ```



##########
providers/common/ai/src/airflow/providers/common/ai/plugins/hitl_review.py:
##########
@@ -0,0 +1,572 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Annotated, Any
+from urllib.parse import urlparse
+
+from fastapi import Depends, FastAPI, HTTPException, Query
+from fastapi.responses import HTMLResponse
+from fastapi.staticfiles import StaticFiles
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from airflow.api_fastapi.auth.managers.models.resource_details import 
DagAccessEntity
+from airflow.api_fastapi.core_api.security import requires_access_dag
+from airflow.configuration import conf
+from airflow.models.taskinstance import TaskInstance as TI
+from airflow.models.xcom import XComModel
+from airflow.plugins_manager import AirflowPlugin
+from airflow.providers.common.ai.utils.hitl_review import (
+    XCOM_AGENT_OUTPUT_PREFIX,
+    XCOM_AGENT_SESSION,
+    XCOM_HUMAN_ACTION,
+    XCOM_HUMAN_FEEDBACK_PREFIX,
+    AgentSessionData,
+    HITLReviewResponse,
+    HumanActionData,
+    HumanFeedbackRequest,
+    SessionStatus,
+)
+from airflow.utils.session import create_session
+from airflow.utils.state import TaskInstanceState
+
+log = logging.getLogger(__name__)
+
+_PLUGIN_PREFIX = "/hitl-review"
+
+
+def _get_base_url_path(path: str) -> str:
+    """Construct URL path with webserver base_url prefix for non-root 
deployments."""
+    base_url = conf.get("api", "base_url", fallback="/")
+    if base_url.startswith(("http://", "https://")):
+        base_path = urlparse(base_url).path
+    else:
+        base_path = base_url
+    base_path = base_path.rstrip("/")
+    return base_path + path
+
+
+def _get_chat_html() -> str:
+    base_prefix = _get_base_url_path(_PLUGIN_PREFIX)
+    static_prefix = _get_base_url_path(f"{_PLUGIN_PREFIX}/static")
+    return _CHAT_HTML_SHELL.replace("__BASE_PREFIX__", base_prefix).replace(
+        "__STATIC_PREFIX__", static_prefix
+    )
+
+
+def _get_session():
+    with create_session(scoped=False) as session:
+        yield session
+
+
+SessionDep = Annotated[Session, Depends(_get_session)]
+
+
+def _get_map_index(q: str = Query("-1", alias="map_index")) -> int:
+    """Parse map_index query; use -1 when placeholder unreplaced (e.g. 
``{MAP_INDEX}``) or invalid."""
+    try:
+        return int(q)
+    except (ValueError, TypeError):
+        return -1
+
+
+MapIndexDep = Annotated[int, Depends(_get_map_index)]
+
+
+def _read_xcom(session: Session, *, dag_id: str, run_id: str, task_id: str, 
map_index: int = -1, key: str):
+    """Read a single XCom value from the database."""
+    row = session.scalars(
+        XComModel.get_many(
+            run_id=run_id,
+            key=key,
+            dag_ids=dag_id,
+            task_ids=task_id,
+            map_indexes=map_index,
+            limit=1,
+        )
+    ).first()
+    if row is None:
+        return None
+    return XComModel.deserialize_value(row)
+
+
+def _read_xcom_by_prefix(
+    session: Session, *, dag_id: str, run_id: str, task_id: str, map_index: 
int = -1, prefix: str
+) -> dict[int, Any]:
+    """Read all iteration-keyed XCom entries matching *prefix* (e.g. 
``airflow_hitl_review_agent_output_``)."""
+    query = select(XComModel.key, XComModel.value).where(
+        XComModel.dag_id == dag_id,
+        XComModel.run_id == run_id,
+        XComModel.task_id == task_id,
+        XComModel.map_index == map_index,
+        XComModel.key.like(f"{prefix}%"),
+    )
+    result: dict[int, Any] = {}
+    for key, value in session.execute(query).all():
+        suffix = key[len(prefix) :]
+        if suffix.isdigit():
+            row = SimpleNamespace(value=value)

Review Comment:
   nit: `SimpleNamespace(value=value)` is a clever workaround to reuse 
`deserialize_value`, but it's fragile if that method ever reads attributes 
beyond `.value`. A short comment explaining why would help future readers:
   ```python
   # deserialize_value expects an object with a .value attribute;
   # wrap the raw column value so we can reuse the standard deserialization path.
   row = SimpleNamespace(value=value)
   ```



##########
airflow-core/src/airflow/ui/src/pages/Iframe.tsx:
##########
@@ -38,7 +38,7 @@ export const Iframe = ({
       src = src.replaceAll("{DAG_ID}", dagId);
     }
     if (runId !== undefined) {
-      src = src.replaceAll("{RUN_ID}", runId);
+      src = src.replaceAll("{RUN_ID}", encodeURIComponent(runId));

Review Comment:
   Good fix for `runId` -- the unencoded `runId` substitution was the root cause of the `+`-to-space mangling.
   
   Should `dagId` and `taskId` also get `encodeURIComponent`? They're validated 
identifiers today so probably safe, but encoding all substitutions would be 
more defensive. `mapIndex` is an integer so it's fine as-is.



##########
providers/common/ai/src/airflow/providers/common/ai/mixins/hitl_review.py:
##########
@@ -0,0 +1,262 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+import time
+from datetime import timedelta
+from typing import TYPE_CHECKING, Any, Protocol
+
+from pydantic import BaseModel
+
+from airflow.providers.common.ai.exceptions import HITLMaxIterationsError
+from airflow.providers.common.ai.utils.hitl_review import (
+    XCOM_AGENT_OUTPUT_PREFIX,
+    XCOM_AGENT_SESSION,
+    XCOM_HUMAN_ACTION,
+    AgentSessionData,
+    HumanActionData,
+    SessionStatus,
+)
+
+log = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from airflow.sdk import Context
+
+
+class HITLReviewProtocol(Protocol):
+    """Attributes that the host operator must provide."""
+
+    enable_hitl_review: bool
+    max_hitl_iterations: int
+    hitl_timeout: timedelta | None
+    hitl_poll_interval: float
+    prompt: str
+    task_id: str
+    log: Any
+
+
+class HITLReviewMixin:
+    """
+    Mixin that drives an iterative HITL review loop inside ``execute()``.
+
+    After the operator generates its first output, the mixin:
+
+    1. Pushes session metadata and the first agent output to XCom.
+    2. Polls the human action XCom (``airflow_hitl_review_human_action``) at 
``hitl_poll_interval`` seconds.
+    3. When a human sets action to ``changes_requested`` (via the plugin API),
+       calls :meth:`regenerate_with_feedback` and pushes the new agent output.
+    4. When a human sets action to ``approved``, returns the output.
+    5. When a human sets action to ``rejected``, raises a `HITLRejectException`
+
+    The loop stops after ``hitl_timeout``.
+
+    All agent outputs and human feedback are persisted as iteration-keyed
+    XCom entries (``airflow_hitl_review_agent_output_1``, 
``airflow_hitl_review_human_feedback_1``, etc.)
+    for full auditability.
+
+    Operators using this mixin must set:
+
+    - ``enable_hitl_review`` (``bool``)
+    - ``hitl_timeout`` (``timedelta | None``)
+    - ``hitl_poll_interval`` (``float``, seconds)
+    - ``prompt`` (``str``)
+
+    And must implement: meth:`regenerate_with_feedback`.
+    """
+
+    def run_hitl_review(
+        self: HITLReviewProtocol,
+        context: Context,
+        output: Any,
+        *,
+        message_history: Any = None,
+    ) -> str:
+        """
+        Execute the full HITL review loop.
+
+        :param context: Airflow task context.
+        :param output: Initial LLM output (str or BaseModel).
+        :param message_history: Provider-specific conversation state (e.g.
+            pydantic-ai ``list[ModelMessage]``).  Passed to
+            :meth:`regenerate_with_feedback` on each iteration.
+        :returns: The final approved output as a string.
+        :raises HITLMaxIterationsError: When max iterations reached without 
approval.
+        :raises HITLRejectException: When the reviewer rejects the output.
+        :raises HITLTimeoutError: When hitl_timeout elapses with no response.
+        """
+        output_str = self._to_string(output)  # type: ignore[attr-defined]
+        ti = context["task_instance"]
+
+        session = AgentSessionData(
+            status=SessionStatus.PENDING_REVIEW,
+            iteration=1,
+            max_iterations=self.max_hitl_iterations,
+            prompt=self.prompt,
+            current_output=output_str,
+        )
+
+        ti.xcom_push(key=XCOM_AGENT_SESSION, 
value=session.model_dump(mode="json"))
+        ti.xcom_push(key=f"{XCOM_AGENT_OUTPUT_PREFIX}1", value=output_str)
+
+        self.log.info(
+            "Feedback session created for %s/%s/%s (poll every %ds).",
+            ti.dag_id,
+            ti.run_id,
+            ti.task_id,
+            self.hitl_poll_interval,
+        )
+
+        deadline = time.monotonic() + self.hitl_timeout.total_seconds() if 
self.hitl_timeout else None
+
+        return self._poll_loop(  # type: ignore[attr-defined]
+            ti=ti,
+            session=session,
+            message_history=message_history,
+            deadline=deadline,
+        )
+
+    def _poll_loop(
+        self: HITLReviewProtocol,
+        *,
+        ti: Any,
+        session: AgentSessionData,
+        message_history: Any,
+        deadline: float | None,
+    ) -> str:
+        """
+        Block until the session reaches a terminal state.
+
+        This loops until the human approves, rejects, or the timeout/max 
iterations is reached.
+        """
+        from airflow.providers.standard.exceptions import (
+            HITLRejectException,
+            HITLTimeoutError,
+        )
+
+        last_seen_iteration = 0
+
+        while True:
+            if deadline is not None and time.monotonic() > deadline:
+                _session_timeout = AgentSessionData(
+                    status=SessionStatus.TIMEOUT_EXCEEDED,
+                    iteration=session.iteration,
+                    max_iterations=session.max_iterations,
+                    prompt=session.prompt,
+                    current_output=session.current_output,

Review Comment:
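   nit: `time.sleep` is at the top of the loop, so the first XCom check is 
always delayed by `hitl_poll_interval` (default 10s). Moving the sleep to the 
bottom (after the XCom check) would pick up actions that arrive between session 
creation and the first poll without any artificial delay. Careful, though: the 
loop body exits early via `continue` (XCom pull failure, no action yet, 
malformed action, already-seen iteration), so a sleep placed at the bottom 
would be skipped on exactly those paths and the loop would busy-wait. Keeping 
the sleep at the top and skipping it only on the first pass avoids both 
problems.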
   
   Not a big deal in practice since a human reviewer takes much longer than 
10s, but it'd help in automated testing scenarios.



##########
providers/common/ai/src/airflow/providers/common/ai/utils/hitl_review.py:
##########
@@ -0,0 +1,170 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Shared data models, exceptions, and XCom key constants for HITL Review.
+
+Used by both the API-server-side plugin (``plugins.hitl_review``) and the
+worker-side operator mixin (``mixins.hitl_review``).  Depends only on
+``pydantic`` and the standard library.
+
+**Storage**: all session state is persisted as XCom entries on the running
+task instance.  See the *XCom key constants* below for the key naming scheme.
+"""
+
+from __future__ import annotations
+
+from datetime import datetime, timezone
+from enum import Enum
+from typing import Any, Literal
+
+from pydantic import BaseModel, Field
+
+HumanActionType = Literal["approve", "reject", "changes_requested"]
+
+"""
+These xcom keys are reserved for agentic operator with HITL feedback loop.
+"""
+

Review Comment:
   `role: str` accepts any string. The TS side (`feedback.ts`) types this as 
`"assistant" | "human"`. For consistency and to catch bad data early:
   ```python
   role: Literal["assistant", "human"]
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to