gopidesupavan commented on code in PR #63081: URL: https://github.com/apache/airflow/pull/63081#discussion_r2901552178
########## providers/common/ai/src/airflow/providers/common/ai/mixins/hitl_review.py: ########## @@ -0,0 +1,227 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +import time +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Protocol + +from pydantic import BaseModel + +from airflow.providers.common.ai.utils.hitl_review import ( + XCOM_AGENT_OUTPUT_PREFIX, + XCOM_AGENT_SESSION, + XCOM_HUMAN_ACTION, + AgentSessionData, + HumanActionData, + SessionStatus, +) + +log = logging.getLogger(__name__) + +if TYPE_CHECKING: + from airflow.sdk import Context + + +class HITLReviewProtocol(Protocol): + """Attributes that the host operator must provide.""" + + enable_hitl_review: bool + hitl_timeout: timedelta | None + hitl_poll_interval: float + prompt: str + task_id: str + log: Any + + +class HITLReviewMixin: + """ + Mixin that drives an iterative HITL review loop inside ``execute()``. + + After the operator generates its first output, the mixin: + + 1. Pushes session metadata and the first agent output to XCom. + 2. Polls the human action XCom (``airflow_hitl_review_human_action``) at ``hitl_poll_interval`` seconds. + 3. 
When a human sets action to ``changes_requested`` (via the plugin API), + calls :meth:`regenerate_with_feedback` and pushes the new agent output. + 4. When a human sets action to ``approved``, returns the output. + 5. When a human sets action to ``rejected``, raises a `HITLRejectException` + + The loop stops after ``hitl_timeout``. + + All agent outputs and human feedback are persisted as iteration-keyed + XCom entries (``airflow_hitl_review_agent_output_1``, ``airflow_hitl_review_human_feedback_1``, etc.) for full + auditability. + + Operators using this mixin must set: + + - ``enable_hitl_review`` (``bool``) + - ``hitl_timeout`` (``timedelta | None``) + - ``hitl_poll_interval`` (``float``, seconds) + - ``prompt`` (``str``) + + And must implement: meth:`regenerate_with_feedback`. + """ + + def run_hitl_review( + self: HITLReviewProtocol, + context: Context, + output: Any, + *, + message_history: Any = None, + ) -> str: + """ + Execute the full HITL review loop. + + :param context: Airflow task context. + :param output: Initial LLM output (str or BaseModel). + :param message_history: Provider-specific conversation state (e.g. + pydantic-ai ``list[ModelMessage]``). Passed to + :meth:`regenerate_with_feedback` on each iteration. + :returns: The final approved (or max-iteration) output as a string. 
+ """ + output_str = self._to_string(output) + ti = context["task_instance"] + + session = AgentSessionData( + status=SessionStatus.PENDING_REVIEW, + iteration=1, + prompt=self.prompt, + current_output=output_str, + ) + + ti.xcom_push(key=XCOM_AGENT_SESSION, value=session.model_dump(mode="json")) + ti.xcom_push(key=f"{XCOM_AGENT_OUTPUT_PREFIX}1", value=output_str) + + self.log.info( + "Feedback session created for %s/%s/%s (poll every %ds).", + ti.dag_id, + ti.run_id, + ti.task_id, + self.hitl_poll_interval, + ) + + deadline = time.monotonic() + self.hitl_timeout.total_seconds() if self.hitl_timeout else None + + return self._poll_loop( + ti=ti, + session=session, + message_history=message_history, + deadline=deadline, + ) + + def _poll_loop( + self: HITLReviewProtocol, Review Comment: I have updated it to fail the process when max_hitl_iterations is reached; I hope this is fine. Or would you suggest we return the last response instead? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
