Lee-W commented on code in PR #37969:
URL: https://github.com/apache/airflow/pull/37969#discussion_r1535091325
##########
airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py:
##########
@@ -408,3 +505,123 @@ def delete_pipeline_job(
metadata=metadata,
)
return result
+
+ @staticmethod
+ def extract_pipeline_job_id(obj: dict) -> str:
+ """Return unique id of a pipeline job from its name."""
+ return obj["name"].rpartition("/")[-1]
+
+
+class PipelineJobAsyncHook(GoogleBaseAsyncHook):
+ """Asynchronous hook for Google Cloud Vertex AI Pipeline Job APIs."""
+
+ sync_hook_class = PipelineJobHook
+ PIPELINE_COMPLETE_STATES = (
+ PipelineState.PIPELINE_STATE_CANCELLED,
+ PipelineState.PIPELINE_STATE_FAILED,
+ PipelineState.PIPELINE_STATE_PAUSED,
+ PipelineState.PIPELINE_STATE_SUCCEEDED,
+ )
+
+ def __init__(
+ self,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ gcp_conn_id=gcp_conn_id,
+ impersonation_chain=impersonation_chain,
+ **kwargs,
+ )
+
+ async def get_credentials(self) -> Credentials:
+ return (await self.get_sync_hook()).get_credentials()
+
Review Comment:
```suggestion
async def get_credentials(self) -> Credentials:
sync_hook = await self.get_sync_hook()
return sync_hook.get_credentials()
```
nitpick: unify how these `get` methods are implemented
##########
airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py:
##########
@@ -96,6 +101,10 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding
identity, with first
account from the list granting this role to the originating account
(templated).
+ :param deferrable: If True, run the task in the deferrable mode.
Review Comment:
```suggestion
:param deferrable: If True, run the task in the deferrable mode.
```
##########
airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py:
##########
@@ -243,15 +245,110 @@ def run_pipeline_job(
location=region,
failure_policy=failure_policy,
)
+ self._pipeline_job.submit(
+ service_account=service_account,
+ network=network,
+ create_request_timeout=create_request_timeout,
+ experiment=experiment,
+ )
+ self._pipeline_job.wait()
+
+ return self._pipeline_job
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def submit_pipeline_job(
+ self,
+ project_id: str,
+ region: str,
+ display_name: str,
+ template_path: str,
+ job_id: str | None = None,
+ pipeline_root: str | None = None,
+ parameter_values: dict[str, Any] | None = None,
+ input_artifacts: dict[str, str] | None = None,
+ enable_caching: bool | None = None,
+ encryption_spec_key_name: str | None = None,
+ labels: dict[str, str] | None = None,
+ failure_policy: str | None = None,
+ # START: run param
Review Comment:
I'm good with not adding it, either. What I had in mind was to prevent users from
passing the wrong argument because of incorrect positional ordering.
##########
tests/providers/google/cloud/operators/test_vertex_ai.py:
##########
@@ -2041,9 +2042,70 @@ def test_execute(self, mock_hook, to_dict_mock):
experiment=None,
)
+ @mock.patch(VERTEX_AI_PATH.format("pipeline_job.PipelineJobHook"))
+ def test_execute_enters_deferred_state(self, mock_hook):
+ task = RunPipelineJobOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ display_name=DISPLAY_NAME,
+ template_path=TEST_TEMPLATE_PATH,
+ job_id=TEST_PIPELINE_JOB_ID,
+ deferrable=True,
+ )
+ mock_hook.return_value.exists.return_value = False
+ with pytest.raises(TaskDeferred) as exc:
+ task.execute(context={"ti": mock.MagicMock()})
+ assert isinstance(exc.value.trigger, RunPipelineJobTrigger), "Trigger
is not a RunPipelineJobTrigger"
+
+ @mock.patch(
+
"airflow.providers.google.cloud.operators.vertex_ai.pipeline_job.RunPipelineJobOperator.xcom_push"
+ )
Review Comment:
```suggestion
@mock.patch(VERTEX_AI_PATH.format("pipeline_job.RunPipelineJobOperator.xcom_push"))
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]