Lee-W commented on code in PR #37969:
URL: https://github.com/apache/airflow/pull/37969#discussion_r1526769734
##########
airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py:
##########
@@ -243,15 +245,110 @@ def run_pipeline_job(
location=region,
failure_policy=failure_policy,
)
+ self._pipeline_job.submit(
+ service_account=service_account,
+ network=network,
+ create_request_timeout=create_request_timeout,
+ experiment=experiment,
+ )
+ self._pipeline_job.wait()
+
+ return self._pipeline_job
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def submit_pipeline_job(
+ self,
+ project_id: str,
+ region: str,
+ display_name: str,
+ template_path: str,
+ job_id: str | None = None,
+ pipeline_root: str | None = None,
+ parameter_values: dict[str, Any] | None = None,
+ input_artifacts: dict[str, str] | None = None,
+ enable_caching: bool | None = None,
+ encryption_spec_key_name: str | None = None,
+ labels: dict[str, str] | None = None,
+ failure_policy: str | None = None,
+ # START: run param
+ service_account: str | None = None,
+ network: str | None = None,
+ create_request_timeout: float | None = None,
+ experiment: str | experiment_resources.Experiment | None = None,
+ # END: run param
+ ) -> PipelineJob:
+ """
+ Create and start a PipelineJob run.
+ For more info about the client method please see:
+
https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.PipelineJob#google_cloud_aiplatform_PipelineJob_submit
+
+ :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
Review Comment:
Not sure whether we should keep `Required`, `Optional` here, as we've already
specified them in the type annotations
##########
airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py:
##########
@@ -407,3 +504,123 @@ def delete_pipeline_job(
metadata=metadata,
)
return result
+
+ @staticmethod
+ def extract_pipeline_job_id(obj: dict) -> str:
+ """Return unique id of a pipeline job from its name."""
+ return obj["name"].rpartition("/")[-1]
+
+
+class PipelineJobAsyncHook(GoogleBaseAsyncHook):
+ """Asynchronous hook for Google Cloud Vertex AI Pipeline Job APIs."""
+
+ sync_hook_class = PipelineJobHook
+ PIPELINE_COMPLETE_STATES = (
+ PipelineState.PIPELINE_STATE_CANCELLED,
+ PipelineState.PIPELINE_STATE_FAILED,
+ PipelineState.PIPELINE_STATE_PAUSED,
+ PipelineState.PIPELINE_STATE_SUCCEEDED,
+ )
+
+ def __init__(
+ self,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ gcp_conn_id=gcp_conn_id,
+ impersonation_chain=impersonation_chain,
+ **kwargs,
+ )
+
+ async def get_credentials(self) -> Credentials:
+ return (await self.get_sync_hook()).get_credentials()
+
+ async def get_project_id(self) -> str:
+ sync_hook = await self.get_sync_hook()
+ return sync_hook.project_id
+
+ async def get_location(self) -> str:
+ sync_hook = await self.get_sync_hook()
+ return sync_hook.location
+
+ async def get_pipeline_service_client(
+ self,
+ region: str | None = None,
+ ) -> PipelineServiceAsyncClient:
+ if region and region != "global":
+ client_options =
ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
+ else:
+ client_options = ClientOptions()
+ return PipelineServiceAsyncClient(
+ credentials=await self.get_credentials(),
+ client_info=CLIENT_INFO,
+ client_options=client_options,
+ )
+
+ async def get_pipeline_job(
+ self,
+ project_id: str,
+ location: str,
+ job_id: str,
+ retry: AsyncRetry | _MethodDefault = DEFAULT,
+ timeout: float | _MethodDefault | None = DEFAULT,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> types.PipelineJob:
+ """
+ Get a PipelineJob proto message from PipelineServiceAsyncClient.
+
+ :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the
service belongs to.
+ :param job_id: Required. The ID of the PipelineJob resource.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request
as metadata.
+ """
+ client = await self.get_pipeline_service_client(region=location)
+ pipeline_job_name = client.pipeline_job_path(
+ project=project_id,
+ location=location,
+ pipeline_job=job_id,
+ )
+ response: types.PipelineJob = await client.get_pipeline_job(
+ request={"name": pipeline_job_name},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return response
Review Comment:
```suggestion
pipeline_job: types.PipelineJob = await client.get_pipeline_job(
request={"name": pipeline_job_name},
retry=retry,
timeout=timeout,
metadata=metadata,
)
return pipeline_job
```
##########
airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py:
##########
@@ -173,20 +181,46 @@ def execute(self, context: Context):
create_request_timeout=self.create_request_timeout,
experiment=self.experiment,
)
-
- pipeline_job = result.to_dict()
- pipeline_job_id = self.hook.extract_pipeline_job_id(pipeline_job)
+ pipeline_job_id = pipeline_job_obj.job_id
self.log.info("Pipeline job was created. Job id: %s", pipeline_job_id)
-
self.xcom_push(context, key="pipeline_job_id", value=pipeline_job_id)
VertexAIPipelineJobLink.persist(context=context, task_instance=self,
pipeline_id=pipeline_job_id)
+
+ if self.deferrable:
+ pipeline_job_obj.wait_for_resource_creation()
+ self.defer(
+ trigger=RunPipelineJobTrigger(
+ conn_id=self.gcp_conn_id,
+ project_id=self.project_id,
+ location=pipeline_job_obj.location,
+ job_id=pipeline_job_id,
+ poll_interval=self.poll_interval,
+ impersonation_chain=self.impersonation_chain,
+ ),
+ method_name="execute_complete",
+ )
+
+ pipeline_job_obj.wait()
+ pipeline_job = pipeline_job_obj.to_dict()
return pipeline_job
+ def execute_complete(self, context: Context, event: dict[str, Any]) ->
None:
Review Comment:
It seems the type is wrongly annotated
##########
tests/providers/google/cloud/triggers/test_vertex_ai.py:
##########
@@ -46,6 +55,44 @@
BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
+# PYTEST FIXTURES
Review Comment:
```suggestion
```
Not sure whether we really need this comment, as they're obviously pytest
fixtures and no other module seems to follow this convention
##########
airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py:
##########
@@ -243,15 +245,110 @@ def run_pipeline_job(
location=region,
failure_policy=failure_policy,
)
+ self._pipeline_job.submit(
+ service_account=service_account,
+ network=network,
+ create_request_timeout=create_request_timeout,
+ experiment=experiment,
+ )
+ self._pipeline_job.wait()
+
+ return self._pipeline_job
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def submit_pipeline_job(
+ self,
+ project_id: str,
+ region: str,
+ display_name: str,
+ template_path: str,
+ job_id: str | None = None,
+ pipeline_root: str | None = None,
+ parameter_values: dict[str, Any] | None = None,
+ input_artifacts: dict[str, str] | None = None,
+ enable_caching: bool | None = None,
+ encryption_spec_key_name: str | None = None,
+ labels: dict[str, str] | None = None,
+ failure_policy: str | None = None,
+ # START: run param
Review Comment:
Not sure whether it makes more sense to make them keyword-only args
```suggestion
        *,
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]