SameerMesiah97 commented on code in PR #68479:
URL: https://github.com/apache/airflow/pull/68479#discussion_r3444874580
##########
providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py:
##########
@@ -126,6 +130,213 @@ def _serialize_job(self, job: Any) -> Any:
return self.job_serializer_class.to_dict(job)
+class AgentEngineDeleteTrigger(BaseTrigger):
+ """Trigger that waits until a Vertex AI Agent Engine no longer exists."""
+
+ def __init__(
+ self,
+ project_id: str,
+ location: str,
+ agent_engine_id: str,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ poll_interval: float = 30,
+ timeout: float | None = None,
+ operation_name: str | None = None,
+ ):
+ super().__init__()
+ self.project_id = project_id
Review Comment:
Where is `project_id` used in this trigger? I don't think it is needed.
##########
providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py:
##########
@@ -126,6 +130,213 @@ def _serialize_job(self, job: Any) -> Any:
return self.job_serializer_class.to_dict(job)
+class AgentEngineDeleteTrigger(BaseTrigger):
+ """Trigger that waits until a Vertex AI Agent Engine no longer exists."""
+
+ def __init__(
+ self,
+ project_id: str,
+ location: str,
+ agent_engine_id: str,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ poll_interval: float = 30,
+ timeout: float | None = None,
+ operation_name: str | None = None,
+ ):
+ super().__init__()
+ self.project_id = project_id
+ self.location = location
+ self.agent_engine_id = agent_engine_id
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.poll_interval = poll_interval
+ self.timeout = timeout
+ self.operation_name = operation_name
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+
"airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineDeleteTrigger",
+ {
+ "project_id": self.project_id,
+ "location": self.location,
+ "agent_engine_id": self.agent_engine_id,
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "poll_interval": self.poll_interval,
+ "timeout": self.timeout,
+ "operation_name": self.operation_name,
+ },
+ )
+
+ @cached_property
+ def async_hook(self) -> AgentEngineAsyncHook:
+ return AgentEngineAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ if not self.operation_name:
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": "Delete Agent Engine operation name is
required.",
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+
+ start_time = time.monotonic()
+ try:
+ while True:
+ operation = await self.async_hook.get_agent_engine_operation(
+ location=self.location,
+ operation_name=self.operation_name,
+ )
+ if operation.get("done"):
+ if operation.get("error"):
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": str(operation["error"]),
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "message": "Agent Engine deleted",
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+
+ if self.timeout is not None and time.monotonic() - start_time
>= self.timeout:
+ yield TriggerEvent(
+ {
+ "status": "timeout",
+ "message": (
+ f"Timed out waiting for Agent Engine
{self.agent_engine_id} to be deleted"
+ ),
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+
+ self.log.info("Waiting for Agent Engine %s to be deleted.",
self.agent_engine_id)
+ await asyncio.sleep(self.poll_interval)
+ except Exception as err:
+ self.log.exception("Exception occurred while waiting for Agent
Engine deletion.")
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": str(err),
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+
+
+class AgentEngineQueryJobTrigger(BaseTrigger):
+ """Trigger that waits until a Vertex AI Agent Engine query job
completes."""
+
+ def __init__(
+ self,
+ project_id: str,
+ location: str,
+ operation_name: str,
+ config: vertexai_types.CheckQueryJobAgentEngineConfigOrDict | None =
None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ poll_interval: float = 30,
+ timeout: float | None = None,
+ ):
+ super().__init__()
+ self.project_id = project_id
+ self.location = location
+ self.operation_name = operation_name
+ self.config = config
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.poll_interval = poll_interval
+ self.timeout = timeout
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+
"airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineQueryJobTrigger",
+ {
+ "project_id": self.project_id,
+ "location": self.location,
+ "operation_name": self.operation_name,
+ "config": _serialize_value(self.config),
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "poll_interval": self.poll_interval,
+ "timeout": self.timeout,
+ },
+ )
+
+ @cached_property
+ def async_hook(self) -> AgentEngineAsyncHook:
+ return AgentEngineAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ start_time = time.monotonic()
+ try:
+ while True:
+ query_job = await self.async_hook.check_query_agent_engine_job(
+ project_id=self.project_id,
+ location=self.location,
+ operation_name=self.operation_name,
+ config=self.config,
+ )
+ status = getattr(query_job, "status", None)
+ if status == "SUCCESS":
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "message": "Agent Engine query job completed",
+ "query_job": _serialize_value(query_job),
+ }
+ )
+ return
+ if status == "FAILED":
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": f"Agent Engine query job
{self.operation_name} failed.",
+ "query_job": _serialize_value(query_job),
+ }
+ )
+ return
+
+ if self.timeout is not None and time.monotonic() - start_time
>= self.timeout:
+ yield TriggerEvent(
+ {
+ "status": "timeout",
+ "message": f"Timed out waiting for Agent Engine
query job {self.operation_name}",
+ "operation_name": self.operation_name,
Review Comment:
The `TriggerEvent` payload has a different shape for the timeout and exception
events; for example, the timeout event carries `operation_name` instead of
`query_job`. Why not include `query_job` in all of them? Consistency is the key point here.
##########
providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py:
##########
@@ -0,0 +1,268 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.providers.google.cloud.triggers.vertex_ai import (
+ AgentEngineDeleteTrigger,
+ AgentEngineQueryJobTrigger,
+)
+from airflow.triggers.base import TriggerEvent
+
+GCP_PROJECT = "test-project"
+GCP_LOCATION = "us-central1"
+GCP_CONN_ID = "test-conn"
+IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
+AGENT_ENGINE_ID = "123"
+OPERATION_NAME =
"projects/test-project/locations/us-central1/operations/delete-123"
+QUERY_OPERATION_NAME =
"projects/test-project/locations/us-central1/operations/query-123"
+CHECK_QUERY_CONFIG = {"retrieve_result": True}
+
+
+class FakeModel:
+ def __init__(self, payload):
+ self.payload = payload
+ for key, value in payload.items():
+ setattr(self, key, value)
+
+ def model_dump(self, mode="json"):
+ return self.payload
+
+
+@pytest.fixture
+def delete_trigger():
+ return AgentEngineDeleteTrigger(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ agent_engine_id=AGENT_ENGINE_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ poll_interval=1,
+ timeout=60,
+ operation_name=OPERATION_NAME,
+ )
+
+
+@pytest.fixture
+def query_job_trigger():
+ return AgentEngineQueryJobTrigger(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ config=CHECK_QUERY_CONFIG,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ poll_interval=1,
+ timeout=60,
+ )
+
+
Review Comment:
I would add tests for both triggers that poll through non-terminal states until
success, similar to the ones I suggested for the hook methods.
##########
providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py:
##########
@@ -126,6 +130,213 @@ def _serialize_job(self, job: Any) -> Any:
return self.job_serializer_class.to_dict(job)
+class AgentEngineDeleteTrigger(BaseTrigger):
+ """Trigger that waits until a Vertex AI Agent Engine no longer exists."""
+
+ def __init__(
+ self,
+ project_id: str,
+ location: str,
+ agent_engine_id: str,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ poll_interval: float = 30,
+ timeout: float | None = None,
+ operation_name: str | None = None,
+ ):
+ super().__init__()
+ self.project_id = project_id
+ self.location = location
+ self.agent_engine_id = agent_engine_id
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.poll_interval = poll_interval
+ self.timeout = timeout
+ self.operation_name = operation_name
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+
"airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineDeleteTrigger",
+ {
+ "project_id": self.project_id,
+ "location": self.location,
+ "agent_engine_id": self.agent_engine_id,
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "poll_interval": self.poll_interval,
+ "timeout": self.timeout,
+ "operation_name": self.operation_name,
+ },
+ )
+
+ @cached_property
+ def async_hook(self) -> AgentEngineAsyncHook:
+ return AgentEngineAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ if not self.operation_name:
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": "Delete Agent Engine operation name is
required.",
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+
+ start_time = time.monotonic()
+ try:
+ while True:
+ operation = await self.async_hook.get_agent_engine_operation(
+ location=self.location,
+ operation_name=self.operation_name,
+ )
+ if operation.get("done"):
+ if operation.get("error"):
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": str(operation["error"]),
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "message": "Agent Engine deleted",
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+
+ if self.timeout is not None and time.monotonic() - start_time
>= self.timeout:
+ yield TriggerEvent(
+ {
+ "status": "timeout",
+ "message": (
+ f"Timed out waiting for Agent Engine
{self.agent_engine_id} to be deleted"
+ ),
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+
+ self.log.info("Waiting for Agent Engine %s to be deleted.",
self.agent_engine_id)
+ await asyncio.sleep(self.poll_interval)
+ except Exception as err:
+ self.log.exception("Exception occurred while waiting for Agent
Engine deletion.")
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": str(err),
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+
+
+class AgentEngineQueryJobTrigger(BaseTrigger):
+ """Trigger that waits until a Vertex AI Agent Engine query job
completes."""
+
+ def __init__(
+ self,
+ project_id: str,
+ location: str,
+ operation_name: str,
+ config: vertexai_types.CheckQueryJobAgentEngineConfigOrDict | None =
None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ poll_interval: float = 30,
+ timeout: float | None = None,
+ ):
+ super().__init__()
+ self.project_id = project_id
+ self.location = location
+ self.operation_name = operation_name
+ self.config = config
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.poll_interval = poll_interval
+ self.timeout = timeout
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+
"airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineQueryJobTrigger",
+ {
+ "project_id": self.project_id,
+ "location": self.location,
+ "operation_name": self.operation_name,
+ "config": _serialize_value(self.config),
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "poll_interval": self.poll_interval,
+ "timeout": self.timeout,
+ },
+ )
+
+ @cached_property
+ def async_hook(self) -> AgentEngineAsyncHook:
+ return AgentEngineAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ start_time = time.monotonic()
+ try:
+ while True:
+ query_job = await self.async_hook.check_query_agent_engine_job(
+ project_id=self.project_id,
+ location=self.location,
+ operation_name=self.operation_name,
+ config=self.config,
+ )
+ status = getattr(query_job, "status", None)
+ if status == "SUCCESS":
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "message": "Agent Engine query job completed",
+ "query_job": _serialize_value(query_job),
+ }
+ )
+ return
+ if status == "FAILED":
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": f"Agent Engine query job
{self.operation_name} failed.",
+ "query_job": _serialize_value(query_job),
+ }
+ )
+ return
Review Comment:
Are "SUCCESS" and "FAILED" the only terminal states? I am not a domain
expert here, but if there are other states such as "CANCELLED" or "UNKNOWN", this
trigger will poll forever when no timeout is set. I think this is potentially a very
serious issue with this trigger.
##########
providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py:
##########
@@ -126,6 +130,213 @@ def _serialize_job(self, job: Any) -> Any:
return self.job_serializer_class.to_dict(job)
+class AgentEngineDeleteTrigger(BaseTrigger):
+ """Trigger that waits until a Vertex AI Agent Engine no longer exists."""
+
+ def __init__(
+ self,
+ project_id: str,
+ location: str,
+ agent_engine_id: str,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ poll_interval: float = 30,
+ timeout: float | None = None,
+ operation_name: str | None = None,
+ ):
+ super().__init__()
+ self.project_id = project_id
+ self.location = location
+ self.agent_engine_id = agent_engine_id
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.poll_interval = poll_interval
+ self.timeout = timeout
+ self.operation_name = operation_name
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+
"airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineDeleteTrigger",
+ {
+ "project_id": self.project_id,
+ "location": self.location,
+ "agent_engine_id": self.agent_engine_id,
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "poll_interval": self.poll_interval,
+ "timeout": self.timeout,
+ "operation_name": self.operation_name,
+ },
+ )
+
+ @cached_property
+ def async_hook(self) -> AgentEngineAsyncHook:
+ return AgentEngineAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ if not self.operation_name:
Review Comment:
I am not sure why this check is needed here when `operation_name` is already
validated in the operator.
##########
providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py:
##########
@@ -0,0 +1,260 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import
AgentEngineHook
+
+from unit.google.cloud.utils.base_gcp_mock import
mock_base_gcp_hook_default_project_id
+
+BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
+AGENT_ENGINE_STRING =
"airflow.providers.google.cloud.hooks.vertex_ai.agent_engine.{}"
+
+TEST_GCP_CONN_ID = "test-gcp-conn-id"
+GCP_PROJECT = "test-project"
+GCP_LOCATION = "us-central1"
+AGENT_ENGINE_ID = "123"
+AGENT_ENGINE_NAME =
"projects/test-project/locations/us-central1/reasoningEngines/123"
+OPERATION_NAME =
"projects/test-project/locations/us-central1/operations/delete-123"
+QUERY_OPERATION_NAME =
"projects/test-project/locations/us-central1/operations/query-123"
+CONFIG = {"display_name": "test-agent-engine"}
+QUERY_CONFIG = {"query": "hello", "output_gcs_uri":
"gs://test-bucket/query-output/"}
+CHECK_QUERY_CONFIG = {"retrieve_result": True}
+
+
+class TestAgentEngineHookWithDefaultProjectId:
+ def setup_method(self):
+ with mock.patch(
+ BASE_STRING.format("GoogleBaseHook.__init__"),
new=mock_base_gcp_hook_default_project_id
+ ):
+ self.hook = AgentEngineHook(gcp_conn_id=TEST_GCP_CONN_ID)
+
+ @mock.patch(AGENT_ENGINE_STRING.format("Client"), autospec=True)
+ def test_get_agent_engine_client(self, mock_client):
+ self.hook.get_credentials =
mock.Mock(return_value=mock.sentinel.credentials, spec=())
+
+ result = self.hook.get_agent_engine_client(project_id=GCP_PROJECT,
location=GCP_LOCATION)
+
+ mock_client.assert_called_once_with(
+ project=GCP_PROJECT,
+ location=GCP_LOCATION,
+ credentials=mock.sentinel.credentials,
+ )
+ assert result == mock_client.return_value.agent_engines
+
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"),
autospec=True)
+ def test_create_agent_engine(self, mock_get_client):
+ result = self.hook.create_agent_engine(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ config=CONFIG,
+ )
+
+ mock_get_client.assert_called_once_with(self.hook,
project_id=GCP_PROJECT, location=GCP_LOCATION)
+ mock_get_client.return_value.create.assert_called_once_with(
+ agent=None,
+ config=CONFIG,
+ )
+ assert result == mock_get_client.return_value.create.return_value
+
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"),
autospec=True)
+ def test_get_agent_engine(self, mock_get_client):
+ result = self.hook.get_agent_engine(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ agent_engine_id=AGENT_ENGINE_ID,
+ )
+
+
mock_get_client.return_value.get.assert_called_once_with(name=AGENT_ENGINE_NAME)
+ assert result == mock_get_client.return_value.get.return_value
+
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"),
autospec=True)
+ def test_query_agent_engine(self, mock_get_client):
+ result = self.hook.query_agent_engine(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ agent_engine_id=AGENT_ENGINE_ID,
+ config=QUERY_CONFIG,
+ )
+
+ mock_get_client.return_value.run_query_job.assert_called_once_with(
+ name=AGENT_ENGINE_NAME,
+ config=QUERY_CONFIG,
+ )
+ assert result ==
mock_get_client.return_value.run_query_job.return_value
+
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"),
autospec=True)
+ def test_check_query_agent_engine_job(self, mock_get_client):
+ result = self.hook.check_query_agent_engine_job(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ config=CHECK_QUERY_CONFIG,
+ )
+
+ mock_get_client.return_value.check_query_job.assert_called_once_with(
+ name=QUERY_OPERATION_NAME,
+ config=CHECK_QUERY_CONFIG,
+ )
+ assert result ==
mock_get_client.return_value.check_query_job.return_value
+
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"),
autospec=True)
+ def test_wait_for_query_agent_engine_job_returns_when_successful(self,
mock_check_query_job):
+ mock_check_query_job.return_value.status = "SUCCESS"
+
+ result = self.hook.wait_for_query_agent_engine_job(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ config=CHECK_QUERY_CONFIG,
+ )
+
+ mock_check_query_job.assert_called_once_with(
+ self.hook,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ config=CHECK_QUERY_CONFIG,
+ )
+ assert result == mock_check_query_job.return_value
+
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"),
autospec=True)
+ def test_wait_for_query_agent_engine_job_raises_on_failed_status(self,
mock_check_query_job):
+ mock_check_query_job.return_value.status = "FAILED"
+
+ with pytest.raises(RuntimeError, match="Agent Engine query job .*
failed"):
+ self.hook.wait_for_query_agent_engine_job(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ config=CHECK_QUERY_CONFIG,
+ )
+
+ @mock.patch(AGENT_ENGINE_STRING.format("time.sleep"), autospec=True)
+ @mock.patch(AGENT_ENGINE_STRING.format("time.monotonic"), autospec=True)
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"),
autospec=True)
+ def test_wait_for_query_agent_engine_job_times_out(
+ self, mock_check_query_job, mock_monotonic, mock_sleep
+ ):
+ mock_check_query_job.return_value.status = "RUNNING"
+ mock_monotonic.side_effect = [1, 3]
+
+ with pytest.raises(TimeoutError, match="Timed out waiting for Agent
Engine query job"):
+ self.hook.wait_for_query_agent_engine_job(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ config=CHECK_QUERY_CONFIG,
+ timeout=1,
+ )
+
+ mock_sleep.assert_not_called()
+
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"),
autospec=True)
+ def test_update_agent_engine(self, mock_get_client):
+ result = self.hook.update_agent_engine(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ agent_engine_id=AGENT_ENGINE_ID,
+ config=CONFIG,
+ )
+
+ mock_get_client.return_value.update.assert_called_once_with(
+ name=AGENT_ENGINE_NAME,
+ agent=None,
+ config=CONFIG,
+ )
+ assert result == mock_get_client.return_value.update.return_value
+
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"),
autospec=True)
+ def test_delete_agent_engine(self, mock_get_client):
+ result = self.hook.delete_agent_engine(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ agent_engine_id=AGENT_ENGINE_ID,
+ force=True,
+ config=CONFIG,
+ )
+
+ mock_get_client.return_value.delete.assert_called_once_with(
+ name=AGENT_ENGINE_NAME,
+ force=True,
+ config=CONFIG,
+ )
+ assert result == mock_get_client.return_value.delete.return_value
+
+
@mock.patch(AGENT_ENGINE_STRING.format("google.auth.transport.requests.AuthorizedSession"),
autospec=True)
+ def test_get_agent_engine_operation(self, mock_session):
+ self.hook.get_credentials =
mock.Mock(return_value=mock.sentinel.credentials, spec=())
+ mock_session.return_value.get.return_value.json.return_value =
{"name": OPERATION_NAME, "done": True}
+
+ result = self.hook.get_agent_engine_operation(
+ location=GCP_LOCATION,
+ operation_name=OPERATION_NAME,
+ )
+
+ mock_session.assert_called_once_with(mock.sentinel.credentials)
+ mock_session.return_value.get.assert_called_once_with(
+
f"https://{GCP_LOCATION}-aiplatform.googleapis.com/v1beta1/{OPERATION_NAME}"
+ )
+
mock_session.return_value.get.return_value.raise_for_status.assert_called_once_with()
+ assert result == {"name": OPERATION_NAME, "done": True}
+
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_operation"),
autospec=True)
+ def test_wait_for_agent_engine_operation_returns_when_done(self,
mock_get_operation):
+ mock_get_operation.return_value = {"name": OPERATION_NAME, "done":
True}
+
+ self.hook.wait_for_agent_engine_operation(
+ location=GCP_LOCATION,
+ operation_name=OPERATION_NAME,
+ )
+
+ mock_get_operation.assert_called_once_with(
+ self.hook,
+ location=GCP_LOCATION,
+ operation_name=OPERATION_NAME,
+ )
+
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_operation"),
autospec=True)
+ def test_wait_for_agent_engine_operation_raises_on_error(self,
mock_get_operation):
+ mock_get_operation.return_value = {"name": OPERATION_NAME, "done":
True, "error": {"message": "boom"}}
+
+ with pytest.raises(RuntimeError, match="Agent Engine operation .*
failed"):
+ self.hook.wait_for_agent_engine_operation(
+ location=GCP_LOCATION,
+ operation_name=OPERATION_NAME,
+ )
+
+ @mock.patch(AGENT_ENGINE_STRING.format("time.sleep"), autospec=True)
+ @mock.patch(AGENT_ENGINE_STRING.format("time.monotonic"), autospec=True)
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_operation"),
autospec=True)
+ def test_wait_for_agent_engine_operation_times_out(self,
mock_get_operation, mock_monotonic, mock_sleep):
+ mock_get_operation.return_value = {"name": OPERATION_NAME, "done":
False}
+ mock_monotonic.side_effect = [1, 3]
+
+ with pytest.raises(TimeoutError, match="Timed out waiting for Agent
Engine operation"):
+ self.hook.wait_for_agent_engine_operation(
+ location=GCP_LOCATION,
+ operation_name=OPERATION_NAME,
+ timeout=1,
+ )
+
+ mock_sleep.assert_not_called()
Review Comment:
I would add a polling test for `wait_for_agent_engine_operation` as well,
following a similar structure to the one I suggested for
`wait_for_query_agent_engine_job`.
##########
providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py:
##########
@@ -0,0 +1,458 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.providers.common.compat.sdk import TaskDeferred
+from airflow.providers.google.cloud.operators.vertex_ai.agent_engine import (
+ CheckQueryAgentEngineOperator,
+ CreateAgentEngineOperator,
+ DeleteAgentEngineOperator,
+ GetAgentEngineOperator,
+ QueryAgentEngineOperator,
+ UpdateAgentEngineOperator,
+)
+
+AGENT_ENGINE_PATH =
"airflow.providers.google.cloud.operators.vertex_ai.agent_engine.{}"
+
+TASK_ID = "test_task_id"
+GCP_PROJECT = "test-project"
+GCP_LOCATION = "us-central1"
+GCP_CONN_ID = "test-conn"
+IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
+AGENT_ENGINE_ID = "123"
+AGENT_ENGINE_NAME =
"projects/test-project/locations/us-central1/reasoningEngines/123"
+CONFIG = {"display_name": "test-agent-engine"}
+QUERY_CONFIG = {"query": "hello", "output_gcs_uri":
"gs://test-bucket/query-output/"}
+CHECK_QUERY_CONFIG = {"retrieve_result": True}
+OPERATION = {"name": "operations/delete-123", "done": False}
+QUERY_OPERATION_NAME = "operations/query-123"
+
+
+class FakeModel:
+ def __init__(self, payload):
+ self.payload = payload
+ for key, value in payload.items():
+ setattr(self, key, value)
+
+ def model_dump(self, mode="json"):
+ return self.payload
+
+
+class FakeAgentEngine:
+ def __init__(self, payload):
+ self.api_resource = FakeModel(payload)
+
+
[email protected]
+def context():
+ return {"ti": mock.Mock(spec_set=["xcom_push"])}
+
+
+def assert_hook_created(mock_hook):
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+
+
+class TestCreateAgentEngineOperator:
+ @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True)
+ def test_execute(self, mock_hook, context):
+ mock_hook.return_value.create_agent_engine.return_value =
FakeAgentEngine(
+ {"name": AGENT_ENGINE_NAME, "display_name": "test-agent-engine"}
+ )
+ op = CreateAgentEngineOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ config=CONFIG,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+
+ result = op.execute(context=context)
+
+ assert_hook_created(mock_hook)
+ mock_hook.return_value.create_agent_engine.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ agent=None,
+ config=CONFIG,
+ )
+ assert result == {"name": AGENT_ENGINE_NAME, "display_name":
"test-agent-engine"}
+
+
+class TestGetAgentEngineOperator:
+ @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True)
+ def test_execute(self, mock_hook, context):
+ mock_hook.return_value.get_agent_engine.return_value =
FakeAgentEngine({"name": AGENT_ENGINE_NAME})
+ op = GetAgentEngineOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ agent_engine_id=AGENT_ENGINE_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+
+ result = op.execute(context=context)
+
+ mock_hook.return_value.get_agent_engine.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ agent_engine_id=AGENT_ENGINE_ID,
+ )
+ assert result == {"name": AGENT_ENGINE_NAME}
+
+
+class TestQueryAgentEngineOperator:
+ @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True)
+ def test_execute(self, mock_hook, context):
+ result_payload = {
+ "job_name": "operations/query-123",
+ "input_gcs_uri": "gs://test-bucket/query-output/input.json",
+ "output_gcs_uri": "gs://test-bucket/query-output/output.json",
+ }
+ mock_hook.return_value.query_agent_engine.return_value =
FakeModel(result_payload)
+ op = QueryAgentEngineOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ agent_engine_id=AGENT_ENGINE_ID,
+ config=QUERY_CONFIG,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+
+ result = op.execute(context=context)
+
+ mock_hook.return_value.query_agent_engine.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ agent_engine_id=AGENT_ENGINE_ID,
+ config=QUERY_CONFIG,
+ )
+ assert result == result_payload
+
+
+class TestCheckQueryAgentEngineOperator:
+ @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True)
+ def test_execute(self, mock_hook, context):
+ result_payload = {
+ "operation_name": QUERY_OPERATION_NAME,
+ "output_gcs_uri": "gs://test-bucket/query-output/output.json",
+ "status": "SUCCESS",
+ "result": "done",
+ }
+ mock_hook.return_value.wait_for_query_agent_engine_job.return_value =
FakeModel(result_payload)
+ op = CheckQueryAgentEngineOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ config=CHECK_QUERY_CONFIG,
+ poll_interval=1,
+ timeout=60,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+
+ result = op.execute(context=context)
+
+
mock_hook.return_value.wait_for_query_agent_engine_job.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ config=CHECK_QUERY_CONFIG,
+ poll_interval=1,
+ timeout=60,
+ )
+ assert result == result_payload
+
+ @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineQueryJobTrigger"),
autospec=True)
+ @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True)
+ def test_execute_deferrable(self, mock_hook, mock_trigger, context):
+ op = CheckQueryAgentEngineOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ config=CHECK_QUERY_CONFIG,
+ poll_interval=1,
+ timeout=60,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ deferrable=True,
+ )
+
+ with pytest.raises(TaskDeferred):
+ op.execute(context=context)
+
+
mock_hook.return_value.wait_for_query_agent_engine_job.assert_not_called()
+ mock_trigger.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ config=CHECK_QUERY_CONFIG,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ poll_interval=1,
+ timeout=60,
+ )
+
+ def test_execute_complete_success(self, context):
+ op = CheckQueryAgentEngineOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ )
+ query_job = {"operation_name": QUERY_OPERATION_NAME, "status":
"SUCCESS"}
+
+ result = op.execute_complete(
+ context=context,
+ event={"status": "success", "message": "done", "query_job":
query_job},
+ )
+
+ assert result == query_job
+
+ def test_execute_complete_error(self, context):
+ op = CheckQueryAgentEngineOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ )
+
+ with pytest.raises(RuntimeError, match="boom"):
+ op.execute_complete(context=context, event={"status": "error",
"message": "boom"})
+
+ def test_execute_complete_timeout(self, context):
+ op = CheckQueryAgentEngineOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ )
+
+ with pytest.raises(TimeoutError, match="timed out"):
+ op.execute_complete(context=context, event={"status": "timeout",
"message": "timed out"})
+
+ def test_execute_complete_without_event(self, context):
+ op = CheckQueryAgentEngineOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ )
+
+ with pytest.raises(RuntimeError, match="No event received in trigger
callback"):
+ op.execute_complete(context=context)
+
+
+class TestUpdateAgentEngineOperator:
+ @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True)
+ def test_execute(self, mock_hook, context):
+ mock_hook.return_value.update_agent_engine.return_value =
FakeAgentEngine(
+ {"name": AGENT_ENGINE_NAME, "display_name": "updated-agent-engine"}
+ )
+ op = UpdateAgentEngineOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ agent_engine_id=AGENT_ENGINE_ID,
+ config=CONFIG,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+
+ result = op.execute(context=context)
+
+ mock_hook.return_value.update_agent_engine.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ agent_engine_id=AGENT_ENGINE_ID,
+ agent=None,
+ config=CONFIG,
+ )
+ assert result == {"name": AGENT_ENGINE_NAME, "display_name":
"updated-agent-engine"}
+
+
+class TestDeleteAgentEngineOperator:
+ @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True)
+ def test_execute_without_wait(self, mock_hook, context):
+ mock_hook.return_value.delete_agent_engine.return_value =
FakeModel(OPERATION)
+ op = DeleteAgentEngineOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ agent_engine_id=AGENT_ENGINE_ID,
+ force=True,
+ config=CONFIG,
+ wait_for_completion=False,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+
+ result = op.execute(context=context)
+
+ mock_hook.return_value.delete_agent_engine.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ agent_engine_id=AGENT_ENGINE_ID,
+ force=True,
+ config=CONFIG,
+ )
+
mock_hook.return_value.wait_for_agent_engine_operation.assert_not_called()
+ assert result == OPERATION
+
+ @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True)
+ def test_execute_waits_until_deleted(self, mock_hook, context):
+ mock_hook.return_value.delete_agent_engine.return_value =
FakeModel(OPERATION)
+ op = DeleteAgentEngineOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ agent_engine_id=AGENT_ENGINE_ID,
+ wait_for_completion=True,
+ poll_interval=1,
+ timeout=60,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+
+ result = op.execute(context=context)
+
+
mock_hook.return_value.wait_for_agent_engine_operation.assert_called_once_with(
+ location=GCP_LOCATION,
+ operation_name=OPERATION["name"],
+ poll_interval=1,
+ timeout=60,
+ )
+ assert result == OPERATION
+
+ @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True)
+ def test_execute_does_not_wait_when_delete_operation_is_done(self,
mock_hook, context):
+ operation = {"name": "operations/delete-123", "done": True}
+ mock_hook.return_value.delete_agent_engine.return_value =
FakeModel(operation)
+ op = DeleteAgentEngineOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ agent_engine_id=AGENT_ENGINE_ID,
+ wait_for_completion=True,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+
+ result = op.execute(context=context)
+
+
mock_hook.return_value.wait_for_agent_engine_operation.assert_not_called()
+ assert result == operation
+
Review Comment:
Add a test covering the error path where the Delete Agent Engine operation
does not include an operation name.
##########
providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py:
##########
@@ -126,6 +130,213 @@ def _serialize_job(self, job: Any) -> Any:
return self.job_serializer_class.to_dict(job)
+class AgentEngineDeleteTrigger(BaseTrigger):
+ """Trigger that waits until a Vertex AI Agent Engine no longer exists."""
+
+ def __init__(
+ self,
+ project_id: str,
+ location: str,
+ agent_engine_id: str,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ poll_interval: float = 30,
+ timeout: float | None = None,
+ operation_name: str | None = None,
+ ):
+ super().__init__()
+ self.project_id = project_id
+ self.location = location
+ self.agent_engine_id = agent_engine_id
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.poll_interval = poll_interval
+ self.timeout = timeout
+ self.operation_name = operation_name
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+
"airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineDeleteTrigger",
+ {
+ "project_id": self.project_id,
+ "location": self.location,
+ "agent_engine_id": self.agent_engine_id,
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "poll_interval": self.poll_interval,
+ "timeout": self.timeout,
+ "operation_name": self.operation_name,
+ },
+ )
+
+ @cached_property
+ def async_hook(self) -> AgentEngineAsyncHook:
+ return AgentEngineAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ if not self.operation_name:
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": "Delete Agent Engine operation name is
required.",
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+
+ start_time = time.monotonic()
+ try:
+ while True:
+ operation = await self.async_hook.get_agent_engine_operation(
+ location=self.location,
+ operation_name=self.operation_name,
+ )
+ if operation.get("done"):
+ if operation.get("error"):
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": str(operation["error"]),
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "message": "Agent Engine deleted",
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+
+ if self.timeout is not None and time.monotonic() - start_time
>= self.timeout:
+ yield TriggerEvent(
+ {
+ "status": "timeout",
+ "message": (
+ f"Timed out waiting for Agent Engine
{self.agent_engine_id} to be deleted"
+ ),
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+ return
+
+ self.log.info("Waiting for Agent Engine %s to be deleted.",
self.agent_engine_id)
+ await asyncio.sleep(self.poll_interval)
+ except Exception as err:
+ self.log.exception("Exception occurred while waiting for Agent
Engine deletion.")
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": str(err),
+ "agent_engine_id": self.agent_engine_id,
+ }
+ )
+
+
+class AgentEngineQueryJobTrigger(BaseTrigger):
+ """Trigger that waits until a Vertex AI Agent Engine query job
completes."""
+
+ def __init__(
+ self,
+ project_id: str,
+ location: str,
+ operation_name: str,
+ config: vertexai_types.CheckQueryJobAgentEngineConfigOrDict | None =
None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ poll_interval: float = 30,
+ timeout: float | None = None,
+ ):
+ super().__init__()
+ self.project_id = project_id
+ self.location = location
+ self.operation_name = operation_name
+ self.config = config
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.poll_interval = poll_interval
+ self.timeout = timeout
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+
"airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineQueryJobTrigger",
+ {
+ "project_id": self.project_id,
+ "location": self.location,
+ "operation_name": self.operation_name,
+ "config": _serialize_value(self.config),
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "poll_interval": self.poll_interval,
+ "timeout": self.timeout,
+ },
+ )
+
+ @cached_property
+ def async_hook(self) -> AgentEngineAsyncHook:
+ return AgentEngineAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ start_time = time.monotonic()
+ try:
+ while True:
+ query_job = await self.async_hook.check_query_agent_engine_job(
+ project_id=self.project_id,
+ location=self.location,
+ operation_name=self.operation_name,
+ config=self.config,
+ )
+ status = getattr(query_job, "status", None)
+ if status == "SUCCESS":
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "message": "Agent Engine query job completed",
+ "query_job": _serialize_value(query_job),
+ }
+ )
+ return
+ if status == "FAILED":
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": f"Agent Engine query job
{self.operation_name} failed.",
+ "query_job": _serialize_value(query_job),
+ }
+ )
+ return
+
+ if self.timeout is not None and time.monotonic() - start_time
>= self.timeout:
+ yield TriggerEvent(
+ {
+ "status": "timeout",
+ "message": f"Timed out waiting for Agent Engine
query job {self.operation_name}",
+ "operation_name": self.operation_name,
+ }
+ )
+ return
+
+ self.log.info("Waiting for Agent Engine query job %s to
complete.", self.operation_name)
+ await asyncio.sleep(self.poll_interval)
+ except Exception as err:
+ self.log.exception("Exception occurred while waiting for Agent
Engine query job.")
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": str(err),
Review Comment:
Since you are using broad `Exception` handling, it would be better to include
more context in the message, like this:
`"message": f"Failed while polling Agent Engine query job: {err}"`
##########
providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py:
##########
@@ -0,0 +1,260 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import
AgentEngineHook
+
+from unit.google.cloud.utils.base_gcp_mock import
mock_base_gcp_hook_default_project_id
+
+BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
+AGENT_ENGINE_STRING =
"airflow.providers.google.cloud.hooks.vertex_ai.agent_engine.{}"
+
+TEST_GCP_CONN_ID = "test-gcp-conn-id"
+GCP_PROJECT = "test-project"
+GCP_LOCATION = "us-central1"
+AGENT_ENGINE_ID = "123"
+AGENT_ENGINE_NAME =
"projects/test-project/locations/us-central1/reasoningEngines/123"
+OPERATION_NAME =
"projects/test-project/locations/us-central1/operations/delete-123"
+QUERY_OPERATION_NAME =
"projects/test-project/locations/us-central1/operations/query-123"
+CONFIG = {"display_name": "test-agent-engine"}
+QUERY_CONFIG = {"query": "hello", "output_gcs_uri":
"gs://test-bucket/query-output/"}
+CHECK_QUERY_CONFIG = {"retrieve_result": True}
+
+
+class TestAgentEngineHookWithDefaultProjectId:
Review Comment:
I think you should add tests for the async hook as well. They don't have to
be exhaustive; they just need to assert that the underlying sync methods are
called by the async hook methods.
##########
providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py:
##########
@@ -0,0 +1,260 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import
AgentEngineHook
+
+from unit.google.cloud.utils.base_gcp_mock import
mock_base_gcp_hook_default_project_id
+
+BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
+AGENT_ENGINE_STRING =
"airflow.providers.google.cloud.hooks.vertex_ai.agent_engine.{}"
+
+TEST_GCP_CONN_ID = "test-gcp-conn-id"
+GCP_PROJECT = "test-project"
+GCP_LOCATION = "us-central1"
+AGENT_ENGINE_ID = "123"
+AGENT_ENGINE_NAME =
"projects/test-project/locations/us-central1/reasoningEngines/123"
+OPERATION_NAME =
"projects/test-project/locations/us-central1/operations/delete-123"
+QUERY_OPERATION_NAME =
"projects/test-project/locations/us-central1/operations/query-123"
+CONFIG = {"display_name": "test-agent-engine"}
+QUERY_CONFIG = {"query": "hello", "output_gcs_uri":
"gs://test-bucket/query-output/"}
+CHECK_QUERY_CONFIG = {"retrieve_result": True}
+
+
+class TestAgentEngineHookWithDefaultProjectId:
+ def setup_method(self):
+ with mock.patch(
+ BASE_STRING.format("GoogleBaseHook.__init__"),
new=mock_base_gcp_hook_default_project_id
+ ):
+ self.hook = AgentEngineHook(gcp_conn_id=TEST_GCP_CONN_ID)
+
+ @mock.patch(AGENT_ENGINE_STRING.format("Client"), autospec=True)
+ def test_get_agent_engine_client(self, mock_client):
+ self.hook.get_credentials =
mock.Mock(return_value=mock.sentinel.credentials, spec=())
+
+ result = self.hook.get_agent_engine_client(project_id=GCP_PROJECT,
location=GCP_LOCATION)
+
+ mock_client.assert_called_once_with(
+ project=GCP_PROJECT,
+ location=GCP_LOCATION,
+ credentials=mock.sentinel.credentials,
+ )
+ assert result == mock_client.return_value.agent_engines
+
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"),
autospec=True)
+ def test_create_agent_engine(self, mock_get_client):
+ result = self.hook.create_agent_engine(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ config=CONFIG,
+ )
+
+ mock_get_client.assert_called_once_with(self.hook,
project_id=GCP_PROJECT, location=GCP_LOCATION)
+ mock_get_client.return_value.create.assert_called_once_with(
+ agent=None,
+ config=CONFIG,
+ )
+ assert result == mock_get_client.return_value.create.return_value
+
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"),
autospec=True)
+ def test_get_agent_engine(self, mock_get_client):
+ result = self.hook.get_agent_engine(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ agent_engine_id=AGENT_ENGINE_ID,
+ )
+
+
mock_get_client.return_value.get.assert_called_once_with(name=AGENT_ENGINE_NAME)
+ assert result == mock_get_client.return_value.get.return_value
+
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"),
autospec=True)
+ def test_query_agent_engine(self, mock_get_client):
+ result = self.hook.query_agent_engine(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ agent_engine_id=AGENT_ENGINE_ID,
+ config=QUERY_CONFIG,
+ )
+
+ mock_get_client.return_value.run_query_job.assert_called_once_with(
+ name=AGENT_ENGINE_NAME,
+ config=QUERY_CONFIG,
+ )
+ assert result ==
mock_get_client.return_value.run_query_job.return_value
+
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"),
autospec=True)
+ def test_check_query_agent_engine_job(self, mock_get_client):
+ result = self.hook.check_query_agent_engine_job(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ config=CHECK_QUERY_CONFIG,
+ )
+
+ mock_get_client.return_value.check_query_job.assert_called_once_with(
+ name=QUERY_OPERATION_NAME,
+ config=CHECK_QUERY_CONFIG,
+ )
+ assert result ==
mock_get_client.return_value.check_query_job.return_value
+
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"),
autospec=True)
+ def test_wait_for_query_agent_engine_job_returns_when_successful(self,
mock_check_query_job):
+ mock_check_query_job.return_value.status = "SUCCESS"
+
+ result = self.hook.wait_for_query_agent_engine_job(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ config=CHECK_QUERY_CONFIG,
+ )
+
+ mock_check_query_job.assert_called_once_with(
+ self.hook,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ config=CHECK_QUERY_CONFIG,
+ )
+ assert result == mock_check_query_job.return_value
+
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"),
autospec=True)
+ def test_wait_for_query_agent_engine_job_raises_on_failed_status(self,
mock_check_query_job):
+ mock_check_query_job.return_value.status = "FAILED"
+
+ with pytest.raises(RuntimeError, match="Agent Engine query job .*
failed"):
+ self.hook.wait_for_query_agent_engine_job(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ config=CHECK_QUERY_CONFIG,
+ )
+
+ @mock.patch(AGENT_ENGINE_STRING.format("time.sleep"), autospec=True)
+ @mock.patch(AGENT_ENGINE_STRING.format("time.monotonic"), autospec=True)
+
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"),
autospec=True)
+ def test_wait_for_query_agent_engine_job_times_out(
+ self, mock_check_query_job, mock_monotonic, mock_sleep
+ ):
+ mock_check_query_job.return_value.status = "RUNNING"
+ mock_monotonic.side_effect = [1, 3]
+
+ with pytest.raises(TimeoutError, match="Timed out waiting for Agent
Engine query job"):
+ self.hook.wait_for_query_agent_engine_job(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ operation_name=QUERY_OPERATION_NAME,
+ config=CHECK_QUERY_CONFIG,
+ timeout=1,
+ )
+
+ mock_sleep.assert_not_called()
+
Review Comment:
I would add another test to cover the polling behaviour. Below is a
suggested structure you can follow:
```
@mock.patch(AGENT_ENGINE_STRING.format("time.sleep"), autospec=True)
@mock.patch(
AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"),
autospec=True,
)
def test_wait_for_query_agent_engine_job_polls_until_success(
self,
mock_check_query_job,
mock_sleep,
):
running_job = mock.Mock(status="RUNNING")
success_job = mock.Mock(status="SUCCESS")
mock_check_query_job.side_effect = [
running_job,
running_job,
success_job,
]
result = self.hook.wait_for_query_agent_engine_job(
project_id=GCP_PROJECT,
location=GCP_LOCATION,
operation_name=QUERY_OPERATION_NAME,
config=CHECK_QUERY_CONFIG,
poll_interval=10,
)
assert result is success_job
assert mock_check_query_job.call_count == 3
assert mock_sleep.call_count == 2
mock_sleep.assert_called_with(10)
```
Keep in mind you will have to test and validate it yourself.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]