AlejandroMorgante commented on code in PR #68479:
URL: https://github.com/apache/airflow/pull/68479#discussion_r3447446916


##########
providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py:
##########
@@ -126,6 +130,213 @@ def _serialize_job(self, job: Any) -> Any:
         return self.job_serializer_class.to_dict(job)
 
 
+class AgentEngineDeleteTrigger(BaseTrigger):
+    """Trigger that waits until a Vertex AI Agent Engine no longer exists."""
+
+    def __init__(
+        self,
+        project_id: str,
+        location: str,
+        agent_engine_id: str,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        poll_interval: float = 30,
+        timeout: float | None = None,
+        operation_name: str | None = None,
+    ):
+        super().__init__()
+        self.project_id = project_id
+        self.location = location
+        self.agent_engine_id = agent_engine_id
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.poll_interval = poll_interval
+        self.timeout = timeout
+        self.operation_name = operation_name
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            "airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineDeleteTrigger",
+            {
+                "project_id": self.project_id,
+                "location": self.location,
+                "agent_engine_id": self.agent_engine_id,
+                "gcp_conn_id": self.gcp_conn_id,
+                "impersonation_chain": self.impersonation_chain,
+                "poll_interval": self.poll_interval,
+                "timeout": self.timeout,
+                "operation_name": self.operation_name,
+            },
+        )
+
+    @cached_property
+    def async_hook(self) -> AgentEngineAsyncHook:
+        return AgentEngineAsyncHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        if not self.operation_name:
+            yield TriggerEvent(
+                {
+                    "status": "error",
+                    "message": "Delete Agent Engine operation name is required.",
+                    "agent_engine_id": self.agent_engine_id,
+                }
+            )
+            return
+
+        start_time = time.monotonic()
+        try:
+            while True:
+                operation = await self.async_hook.get_agent_engine_operation(
+                    location=self.location,
+                    operation_name=self.operation_name,
+                )
+                if operation.get("done"):
+                    if operation.get("error"):
+                        yield TriggerEvent(
+                            {
+                                "status": "error",
+                                "message": str(operation["error"]),
+                                "agent_engine_id": self.agent_engine_id,
+                            }
+                        )
+                        return
+                    yield TriggerEvent(
+                        {
+                            "status": "success",
+                            "message": "Agent Engine deleted",
+                            "agent_engine_id": self.agent_engine_id,
+                        }
+                    )
+                    return
+
+                if self.timeout is not None and time.monotonic() - start_time >= self.timeout:
+                    yield TriggerEvent(
+                        {
+                            "status": "timeout",
+                            "message": (
+                                f"Timed out waiting for Agent Engine {self.agent_engine_id} to be deleted"
+                            ),
+                            "agent_engine_id": self.agent_engine_id,
+                        }
+                    )
+                    return
+
+                self.log.info("Waiting for Agent Engine %s to be deleted.", self.agent_engine_id)
+                await asyncio.sleep(self.poll_interval)
+        except Exception as err:
+            self.log.exception("Exception occurred while waiting for Agent Engine deletion.")
+            yield TriggerEvent(
+                {
+                    "status": "error",
+                    "message": str(err),
+                    "agent_engine_id": self.agent_engine_id,
+                }
+            )
+
+
+class AgentEngineQueryJobTrigger(BaseTrigger):
+    """Trigger that waits until a Vertex AI Agent Engine query job completes."""
+
+    def __init__(
+        self,
+        project_id: str,
+        location: str,
+        operation_name: str,
+        config: vertexai_types.CheckQueryJobAgentEngineConfigOrDict | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        poll_interval: float = 30,
+        timeout: float | None = None,
+    ):
+        super().__init__()
+        self.project_id = project_id
+        self.location = location
+        self.operation_name = operation_name
+        self.config = config
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.poll_interval = poll_interval
+        self.timeout = timeout
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            "airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineQueryJobTrigger",
+            {
+                "project_id": self.project_id,
+                "location": self.location,
+                "operation_name": self.operation_name,
+                "config": _serialize_value(self.config),
+                "gcp_conn_id": self.gcp_conn_id,
+                "impersonation_chain": self.impersonation_chain,
+                "poll_interval": self.poll_interval,
+                "timeout": self.timeout,
+            },
+        )
+
+    @cached_property
+    def async_hook(self) -> AgentEngineAsyncHook:
+        return AgentEngineAsyncHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        start_time = time.monotonic()
+        try:
+            while True:
+                query_job = await self.async_hook.check_query_agent_engine_job(
+                    project_id=self.project_id,
+                    location=self.location,
+                    operation_name=self.operation_name,
+                    config=self.config,
+                )
+                status = getattr(query_job, "status", None)
+                if status == "SUCCESS":
+                    yield TriggerEvent(
+                        {
+                            "status": "success",
+                            "message": "Agent Engine query job completed",
+                            "query_job": _serialize_value(query_job),
+                        }
+                    )
+                    return
+                if status == "FAILED":
+                    yield TriggerEvent(
+                        {
+                            "status": "error",
+                            "message": f"Agent Engine query job {self.operation_name} failed.",
+                            "query_job": _serialize_value(query_job),
+                        }
+                    )
+                    return
+
+                if self.timeout is not None and time.monotonic() - start_time >= self.timeout:
+                    yield TriggerEvent(
+                        {
+                            "status": "timeout",
+                            "message": f"Timed out waiting for Agent Engine query job {self.operation_name}",
+                            "operation_name": self.operation_name,
+                        }
+                    )
+                    return
+
+                self.log.info("Waiting for Agent Engine query job %s to complete.", self.operation_name)
+                await asyncio.sleep(self.poll_interval)
+        except Exception as err:
+            self.log.exception("Exception occurred while waiting for Agent Engine query job.")
+            yield TriggerEvent(
+                {
+                    "status": "error",
+                    "message": str(err),

Review Comment:
   Done: added a contextual error message and updated the test coverage.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to