kaxil commented on code in PR #62898: URL: https://github.com/apache/airflow/pull/62898#discussion_r2892123301
########## providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py: ########## @@ -0,0 +1,184 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import logging +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Protocol + +from pydantic import BaseModel + +from airflow.providers.common.compat.sdk import AirflowException + +log = logging.getLogger(__name__) + +if TYPE_CHECKING: + from airflow.sdk import Context + + +class ApprovalFailedException(AirflowException): + """Failed to approve.""" + + +class ApprovalRejectionException(AirflowException): + """Rejected by the reviewer.""" + + +class DeferForApprovalProtocol(Protocol): + """Protocol for defer for approval mixin.""" + + approval_timeout: timedelta | None + allow_modifications: bool + prompt: str + task_id: str + defer: Any + + +class LLMApprovalMixin: + """ + Mixin that pauses an operator for human review before returning output. + + When ``require_approval=True`` on the operator, the generated output is + presented to a human reviewer via the Airflow Human-in-the-Loop (HITL) + interface. The task defers until the reviewer approves or rejects. 
+ + If ``allow_modifications=True``, the reviewer can also edit the output + before approving. The (possibly modified) output is then returned as the + task result. + + Operators that use this mixin must set the following attributes: + + - ``require_approval`` (``bool``) + - ``allow_modifications`` (``bool``) + - ``approval_timeout`` (``timedelta | None``) + - ``prompt`` (``str``) + """ + + APPROVE = "Approve" + REJECT = "Reject" + + def defer_for_approval( + self: DeferForApprovalProtocol, + context: Context, + output: Any, + *, + subject: str | None = None, + body: str | None = None, + ) -> None: + """ + Write HITL detail, then defer to HITLTrigger for human review. + + :param context: Airflow task context. + :param output: The generated output to present for review. + :param subject: Headline shown on the Required Actions page. + Defaults to ``"Review output for task `<task_id>`"``. + :param body: Markdown body shown below the headline. + Defaults to the prompt and output wrapped in a code block. 
+ """ + from airflow.providers.standard.triggers.hitl import HITLTrigger + from airflow.sdk.execution_time.hitl import upsert_hitl_detail + from airflow.sdk.timezone import utcnow + + if isinstance(output, BaseModel): + output = output.model_dump_json() + if not isinstance(output, str): + # Always make string output so that when comparing in the execute_complete matches + output = str(output) + + ti_id = context["task_instance"].id + timeout_datetime = utcnow() + self.approval_timeout if self.approval_timeout else None + + if subject is None: + subject = f"Review output for task `{self.task_id}`" + + if body is None: + body = f"```\nPrompt: {self.prompt}\n\n{output}\n```" + + hitl_params: dict[str, dict[str, Any]] = {} + if self.allow_modifications: + hitl_params = { + "output": { + "value": output, + "description": "Edit the output before approving (optional).", + "schema": {"type": "string"}, + }, + } + + upsert_hitl_detail( + ti_id=ti_id, + options=[LLMApprovalMixin.APPROVE, LLMApprovalMixin.REJECT], + subject=subject, + body=body, + defaults=None, + multiple=False, + params=hitl_params, + ) + + self.defer( + trigger=HITLTrigger( + ti_id=ti_id, + options=[LLMApprovalMixin.APPROVE, LLMApprovalMixin.REJECT], + defaults=None, + params=hitl_params, + multiple=False, + timeout_datetime=timeout_datetime, + ), + method_name="execute_complete", + kwargs={"generated_output": output}, + timeout=self.approval_timeout, + ) + + def execute_complete(self, context: Context, generated_output: str, event: dict[str, Any]) -> str: + """ + Resume after human review. + + Called automatically by Airflow when the HITL trigger fires. + Returns the original or reviewer-modified output on approval. + + :param context: Airflow task context. + :param generated_output: The output that was deferred for review. + :param event: Trigger event payload containing ``chosen_options``, + ``params_input``, and ``responded_by_user``. 
+ :raises ApprovalRejectionException: If the reviewer rejected the output. + :raises ApprovalFailedException: If the trigger reported an error. + :raises HITLTimeoutError: If the approval timed out. + """ + from airflow.providers.standard.exceptions import HITLTimeoutError + + if "error" in event: + error_type = event.get("error_type", "unknown") + if error_type == "timeout": + raise HITLTimeoutError(f"Approval timed out: {event['error']}") + raise ApprovalFailedException(f"Approval failed: {event['error']}") + + responded_by_user = event.get("responded_by_user") + chosen = event.get("chosen_options", []) + if self.APPROVE not in chosen: + raise ApprovalRejectionException(f"Output was rejected by the reviewer {responded_by_user}.") + + output = generated_output + params_input: dict[str, Any] = event.get("params_input") or {} + + # If params has data means technically its allowed modifying see above defer call + if params_input: + modified = params_input.get("output") + if modified and modified != generated_output: Review Comment: If a reviewer intentionally clears the output to an empty string, the truthiness check on `modified` causes their edit to be silently discarded and the original output is returned instead. If this is intentional (preventing accidental blank submissions), a comment would help. Otherwise, consider checking `modified is not None` instead of just `modified`. ########## providers/common/ai/src/airflow/providers/common/ai/operators/llm.py: ########## @@ -96,6 +112,11 @@ def execute(self, context: Context) -> Any: log_run_summary(self.log, result) output = result.output + if self.require_approval: + self.defer_for_approval(context, output) + return None Review Comment: `self.defer()` raises `TaskDeferred` (a `BaseException`), so this `return None` is unreachable. The `-> str | None` return type on `LLMSQLQueryOperator.execute` is also misleading since the function never actually returns `None`. 
You can drop the `return None` lines and keep the original return types. Same applies to `llm_sql.py:160`. ########## providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py: ########## @@ -0,0 +1,184 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import logging +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Protocol + +from pydantic import BaseModel + +from airflow.providers.common.compat.sdk import AirflowException + +log = logging.getLogger(__name__) + +if TYPE_CHECKING: + from airflow.sdk import Context + + +class ApprovalFailedException(AirflowException): + """Failed to approve.""" + + +class ApprovalRejectionException(AirflowException): + """Rejected by the reviewer.""" + + +class DeferForApprovalProtocol(Protocol): + """Protocol for defer for approval mixin.""" + + approval_timeout: timedelta | None + allow_modifications: bool + prompt: str + task_id: str + defer: Any + + +class LLMApprovalMixin: + """ + Mixin that pauses an operator for human review before returning output. + + When ``require_approval=True`` on the operator, the generated output is + presented to a human reviewer via the Airflow Human-in-the-Loop (HITL) + interface. 
The task defers until the reviewer approves or rejects. + + If ``allow_modifications=True``, the reviewer can also edit the output + before approving. The (possibly modified) output is then returned as the + task result. + + Operators that use this mixin must set the following attributes: + + - ``require_approval`` (``bool``) + - ``allow_modifications`` (``bool``) + - ``approval_timeout`` (``timedelta | None``) + - ``prompt`` (``str``) + """ + + APPROVE = "Approve" + REJECT = "Reject" + + def defer_for_approval( + self: DeferForApprovalProtocol, + context: Context, + output: Any, + *, + subject: str | None = None, + body: str | None = None, + ) -> None: + """ + Write HITL detail, then defer to HITLTrigger for human review. + + :param context: Airflow task context. + :param output: The generated output to present for review. + :param subject: Headline shown on the Required Actions page. + Defaults to ``"Review output for task `<task_id>`"``. + :param body: Markdown body shown below the headline. + Defaults to the prompt and output wrapped in a code block. 
+ """ + from airflow.providers.standard.triggers.hitl import HITLTrigger + from airflow.sdk.execution_time.hitl import upsert_hitl_detail + from airflow.sdk.timezone import utcnow + + if isinstance(output, BaseModel): + output = output.model_dump_json() + if not isinstance(output, str): + # Always make string output so that when comparing in the execute_complete matches + output = str(output) + + ti_id = context["task_instance"].id + timeout_datetime = utcnow() + self.approval_timeout if self.approval_timeout else None + + if subject is None: + subject = f"Review output for task `{self.task_id}`" + + if body is None: + body = f"```\nPrompt: {self.prompt}\n\n{output}\n```" + + hitl_params: dict[str, dict[str, Any]] = {} + if self.allow_modifications: + hitl_params = { + "output": { + "value": output, + "description": "Edit the output before approving (optional).", + "schema": {"type": "string"}, + }, + } + + upsert_hitl_detail( + ti_id=ti_id, + options=[LLMApprovalMixin.APPROVE, LLMApprovalMixin.REJECT], + subject=subject, + body=body, + defaults=None, + multiple=False, + params=hitl_params, + ) + + self.defer( + trigger=HITLTrigger( + ti_id=ti_id, + options=[LLMApprovalMixin.APPROVE, LLMApprovalMixin.REJECT], + defaults=None, + params=hitl_params, + multiple=False, + timeout_datetime=timeout_datetime, + ), + method_name="execute_complete", + kwargs={"generated_output": output}, + timeout=self.approval_timeout, + ) + + def execute_complete(self, context: Context, generated_output: str, event: dict[str, Any]) -> str: + """ + Resume after human review. + + Called automatically by Airflow when the HITL trigger fires. + Returns the original or reviewer-modified output on approval. + + :param context: Airflow task context. + :param generated_output: The output that was deferred for review. + :param event: Trigger event payload containing ``chosen_options``, + ``params_input``, and ``responded_by_user``. 
+ :raises ApprovalRejectionException: If the reviewer rejected the output. + :raises ApprovalFailedException: If the trigger reported an error. + :raises HITLTimeoutError: If the approval timed out. + """ + from airflow.providers.standard.exceptions import HITLTimeoutError + + if "error" in event: + error_type = event.get("error_type", "unknown") + if error_type == "timeout": + raise HITLTimeoutError(f"Approval timed out: {event['error']}") + raise ApprovalFailedException(f"Approval failed: {event['error']}") + + responded_by_user = event.get("responded_by_user") Review Comment: A malformed trigger event missing `chosen_options` would silently be treated as a rejection (empty list → APPROVE not in it). The standard `HITLOperator.execute_complete` uses `event["chosen_options"]` (no default) to fail fast on bad events. Consider using `event["chosen_options"]` to match the standard pattern and surface bugs early. ########## providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm.py: ########## @@ -114,3 +114,24 @@ def extract(text: str): # [END howto_decorator_llm_structured] example_llm_decorator_structured() + + +# [START howto_operator_llm_approval] Review Comment: Nit: `timedelta` is already used at the top level in other example DAGs. Move this import to the top of the file for consistency. ########## providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py: ########## @@ -0,0 +1,184 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import logging +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Protocol + +from pydantic import BaseModel + +from airflow.providers.common.compat.sdk import AirflowException + +log = logging.getLogger(__name__) + +if TYPE_CHECKING: + from airflow.sdk import Context + + +class ApprovalFailedException(AirflowException): + """Failed to approve.""" + + +class ApprovalRejectionException(AirflowException): + """Rejected by the reviewer.""" + + +class DeferForApprovalProtocol(Protocol): + """Protocol for defer for approval mixin.""" + + approval_timeout: timedelta | None + allow_modifications: bool + prompt: str + task_id: str + defer: Any + + +class LLMApprovalMixin: + """ + Mixin that pauses an operator for human review before returning output. + + When ``require_approval=True`` on the operator, the generated output is + presented to a human reviewer via the Airflow Human-in-the-Loop (HITL) + interface. The task defers until the reviewer approves or rejects. + + If ``allow_modifications=True``, the reviewer can also edit the output + before approving. The (possibly modified) output is then returned as the + task result. 
+ + Operators that use this mixin must set the following attributes: + + - ``require_approval`` (``bool``) + - ``allow_modifications`` (``bool``) + - ``approval_timeout`` (``timedelta | None``) + - ``prompt`` (``str``) + """ + + APPROVE = "Approve" + REJECT = "Reject" + + def defer_for_approval( + self: DeferForApprovalProtocol, + context: Context, + output: Any, + *, + subject: str | None = None, + body: str | None = None, + ) -> None: + """ + Write HITL detail, then defer to HITLTrigger for human review. + + :param context: Airflow task context. + :param output: The generated output to present for review. + :param subject: Headline shown on the Required Actions page. + Defaults to ``"Review output for task `<task_id>`"``. + :param body: Markdown body shown below the headline. + Defaults to the prompt and output wrapped in a code block. + """ + from airflow.providers.standard.triggers.hitl import HITLTrigger + from airflow.sdk.execution_time.hitl import upsert_hitl_detail + from airflow.sdk.timezone import utcnow + + if isinstance(output, BaseModel): + output = output.model_dump_json() + if not isinstance(output, str): + # Always make string output so that when comparing in the execute_complete matches + output = str(output) + + ti_id = context["task_instance"].id + timeout_datetime = utcnow() + self.approval_timeout if self.approval_timeout else None + + if subject is None: + subject = f"Review output for task `{self.task_id}`" + + if body is None: + body = f"```\nPrompt: {self.prompt}\n\n{output}\n```" + + hitl_params: dict[str, dict[str, Any]] = {} + if self.allow_modifications: + hitl_params = { + "output": { + "value": output, + "description": "Edit the output before approving (optional).", + "schema": {"type": "string"}, + }, + } + + upsert_hitl_detail( + ti_id=ti_id, + options=[LLMApprovalMixin.APPROVE, LLMApprovalMixin.REJECT], + subject=subject, + body=body, + defaults=None, + multiple=False, + params=hitl_params, + ) + + self.defer( + 
trigger=HITLTrigger( + ti_id=ti_id, + options=[LLMApprovalMixin.APPROVE, LLMApprovalMixin.REJECT], + defaults=None, + params=hitl_params, + multiple=False, + timeout_datetime=timeout_datetime, + ), + method_name="execute_complete", + kwargs={"generated_output": output}, + timeout=self.approval_timeout, + ) + + def execute_complete(self, context: Context, generated_output: str, event: dict[str, Any]) -> str: + """ + Resume after human review. + + Called automatically by Airflow when the HITL trigger fires. + Returns the original or reviewer-modified output on approval. + + :param context: Airflow task context. + :param generated_output: The output that was deferred for review. + :param event: Trigger event payload containing ``chosen_options``, + ``params_input``, and ``responded_by_user``. + :raises ApprovalRejectionException: If the reviewer rejected the output. + :raises ApprovalFailedException: If the trigger reported an error. + :raises HITLTimeoutError: If the approval timed out. + """ + from airflow.providers.standard.exceptions import HITLTimeoutError + + if "error" in event: + error_type = event.get("error_type", "unknown") + if error_type == "timeout": + raise HITLTimeoutError(f"Approval timed out: {event['error']}") + raise ApprovalFailedException(f"Approval failed: {event['error']}") + + responded_by_user = event.get("responded_by_user") + chosen = event.get("chosen_options", []) + if self.APPROVE not in chosen: + raise ApprovalRejectionException(f"Output was rejected by the reviewer {responded_by_user}.") + + output = generated_output + params_input: dict[str, Any] = event.get("params_input") or {} + + # If params has data means technically its allowed modifying see above defer call Review Comment: Nit: this reads like auto-generated text. Suggest: "If the reviewer provided modified output, return their version." 
########## providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py: ########## @@ -0,0 +1,184 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import logging +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Protocol + +from pydantic import BaseModel + +from airflow.providers.common.compat.sdk import AirflowException + +log = logging.getLogger(__name__) + +if TYPE_CHECKING: + from airflow.sdk import Context + + +class ApprovalFailedException(AirflowException): + """Failed to approve.""" + + +class ApprovalRejectionException(AirflowException): Review Comment: The standard provider already has `HITLRejectException` in `providers/standard/exceptions.py` for the same purpose. Reusing it would give users a single exception type to catch for HITL rejections regardless of whether it came from `ApprovalOperator` or `LLMOperator`. Is there a reason to have a separate exception here? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org
