AlejandroMorgante commented on code in PR #68479:
URL: https://github.com/apache/airflow/pull/68479#discussion_r3447432936
##########
providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py:
##########
@@ -126,6 +130,213 @@ def _serialize_job(self, job: Any) -> Any:
return self.job_serializer_class.to_dict(job)
+class AgentEngineDeleteTrigger(BaseTrigger):
+ """Trigger that waits until a Vertex AI Agent Engine no longer exists."""
+
+ def __init__(
+ self,
+ project_id: str,
+ location: str,
+ agent_engine_id: str,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ poll_interval: float = 30,
+ timeout: float | None = None,
+ operation_name: str | None = None,
+ ):
+ super().__init__()
+ self.project_id = project_id
+ self.location = location
+ self.agent_engine_id = agent_engine_id
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.poll_interval = poll_interval
+ self.timeout = timeout
+ self.operation_name = operation_name
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+
"airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineDeleteTrigger",
+ {
+ "project_id": self.project_id,
+ "location": self.location,
+ "agent_engine_id": self.agent_engine_id,
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "poll_interval": self.poll_interval,
+ "timeout": self.timeout,
+ "operation_name": self.operation_name,
+ },
+ )
+
+ @cached_property
+ def async_hook(self) -> AgentEngineAsyncHook:
+ return AgentEngineAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ if not self.operation_name:
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": "Delete Agent Engine operation name is
required.",
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+
+ start_time = time.monotonic()
+ try:
+ while True:
+ operation = await self.async_hook.get_agent_engine_operation(
+ location=self.location,
+ operation_name=self.operation_name,
+ )
+ if operation.get("done"):
+ if operation.get("error"):
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": str(operation["error"]),
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "message": "Agent Engine deleted",
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+
+ if self.timeout is not None and time.monotonic() - start_time
>= self.timeout:
+ yield TriggerEvent(
+ {
+ "status": "timeout",
+ "message": (
+ f"Timed out waiting for Agent Engine
{self.agent_engine_id} to be deleted"
+ ),
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+
+ self.log.info("Waiting for Agent Engine %s to be deleted.",
self.agent_engine_id)
+ await asyncio.sleep(self.poll_interval)
+ except Exception as err:
+ self.log.exception("Exception occurred while waiting for Agent
Engine deletion.")
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": str(err),
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+
+
+class AgentEngineQueryJobTrigger(BaseTrigger):
+ """Trigger that waits until a Vertex AI Agent Engine query job
completes."""
+
+ def __init__(
+ self,
+ project_id: str,
+ location: str,
+ operation_name: str,
+ config: vertexai_types.CheckQueryJobAgentEngineConfigOrDict | None =
None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ poll_interval: float = 30,
+ timeout: float | None = None,
+ ):
+ super().__init__()
+ self.project_id = project_id
+ self.location = location
+ self.operation_name = operation_name
+ self.config = config
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.poll_interval = poll_interval
+ self.timeout = timeout
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+
"airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineQueryJobTrigger",
+ {
+ "project_id": self.project_id,
+ "location": self.location,
+ "operation_name": self.operation_name,
+ "config": _serialize_value(self.config),
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "poll_interval": self.poll_interval,
+ "timeout": self.timeout,
+ },
+ )
+
+ @cached_property
+ def async_hook(self) -> AgentEngineAsyncHook:
+ return AgentEngineAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ start_time = time.monotonic()
+ try:
+ while True:
+ query_job = await self.async_hook.check_query_agent_engine_job(
+ project_id=self.project_id,
+ location=self.location,
+ operation_name=self.operation_name,
+ config=self.config,
+ )
+ status = getattr(query_job, "status", None)
+ if status == "SUCCESS":
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "message": "Agent Engine query job completed",
+ "query_job": _serialize_value(query_job),
+ }
+ )
+ return
+ if status == "FAILED":
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": f"Agent Engine query job
{self.operation_name} failed.",
+ "query_job": _serialize_value(query_job),
+ }
+ )
+ return
Review Comment:
Good catch. I updated both paths now:
- the deferrable trigger treats any status other than `None`/`RUNNING` as
terminal and returns an error event
- the synchronous `wait_for_query_agent_engine_job()` path now raises
immediately for unexpected statuses instead of warning and polling until timeout
I also added regression coverage for `CANCELLED` in both paths.
##########
providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py:
##########
@@ -126,6 +130,213 @@ def _serialize_job(self, job: Any) -> Any:
return self.job_serializer_class.to_dict(job)
+class AgentEngineDeleteTrigger(BaseTrigger):
+ """Trigger that waits until a Vertex AI Agent Engine no longer exists."""
+
+ def __init__(
+ self,
+ project_id: str,
+ location: str,
+ agent_engine_id: str,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ poll_interval: float = 30,
+ timeout: float | None = None,
+ operation_name: str | None = None,
+ ):
+ super().__init__()
+ self.project_id = project_id
+ self.location = location
+ self.agent_engine_id = agent_engine_id
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.poll_interval = poll_interval
+ self.timeout = timeout
+ self.operation_name = operation_name
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+
"airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineDeleteTrigger",
+ {
+ "project_id": self.project_id,
+ "location": self.location,
+ "agent_engine_id": self.agent_engine_id,
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "poll_interval": self.poll_interval,
+ "timeout": self.timeout,
+ "operation_name": self.operation_name,
+ },
+ )
+
+ @cached_property
+ def async_hook(self) -> AgentEngineAsyncHook:
+ return AgentEngineAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ if not self.operation_name:
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": "Delete Agent Engine operation name is
required.",
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+
+ start_time = time.monotonic()
+ try:
+ while True:
+ operation = await self.async_hook.get_agent_engine_operation(
+ location=self.location,
+ operation_name=self.operation_name,
+ )
+ if operation.get("done"):
+ if operation.get("error"):
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": str(operation["error"]),
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "message": "Agent Engine deleted",
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+
+ if self.timeout is not None and time.monotonic() - start_time
>= self.timeout:
+ yield TriggerEvent(
+ {
+ "status": "timeout",
+ "message": (
+ f"Timed out waiting for Agent Engine
{self.agent_engine_id} to be deleted"
+ ),
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+
+ self.log.info("Waiting for Agent Engine %s to be deleted.",
self.agent_engine_id)
+ await asyncio.sleep(self.poll_interval)
+ except Exception as err:
+ self.log.exception("Exception occurred while waiting for Agent
Engine deletion.")
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": str(err),
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+
+
+class AgentEngineQueryJobTrigger(BaseTrigger):
+ """Trigger that waits until a Vertex AI Agent Engine query job
completes."""
+
+ def __init__(
+ self,
+ project_id: str,
+ location: str,
+ operation_name: str,
+ config: vertexai_types.CheckQueryJobAgentEngineConfigOrDict | None =
None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ poll_interval: float = 30,
+ timeout: float | None = None,
+ ):
+ super().__init__()
+ self.project_id = project_id
+ self.location = location
+ self.operation_name = operation_name
+ self.config = config
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.poll_interval = poll_interval
+ self.timeout = timeout
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+
"airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineQueryJobTrigger",
+ {
+ "project_id": self.project_id,
+ "location": self.location,
+ "operation_name": self.operation_name,
+ "config": _serialize_value(self.config),
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "poll_interval": self.poll_interval,
+ "timeout": self.timeout,
+ },
+ )
+
+ @cached_property
+ def async_hook(self) -> AgentEngineAsyncHook:
+ return AgentEngineAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ start_time = time.monotonic()
+ try:
+ while True:
+ query_job = await self.async_hook.check_query_agent_engine_job(
+ project_id=self.project_id,
+ location=self.location,
+ operation_name=self.operation_name,
+ config=self.config,
+ )
+ status = getattr(query_job, "status", None)
+ if status == "SUCCESS":
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "message": "Agent Engine query job completed",
+ "query_job": _serialize_value(query_job),
+ }
+ )
+ return
+ if status == "FAILED":
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": f"Agent Engine query job
{self.operation_name} failed.",
+ "query_job": _serialize_value(query_job),
+ }
+ )
+ return
Review Comment:
Good catch. I updated both paths now:
- the deferrable trigger treats any status other than `None`/`RUNNING` as
terminal and returns an error event
- the synchronous `wait_for_query_agent_engine_job()` path now raises
immediately for unexpected statuses instead of warning and polling until timeout
I also added regression coverage for `CANCELLED` in both paths.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]