Lee-W commented on code in PR #36594:
URL: https://github.com/apache/airflow/pull/36594#discussion_r1451428350


##########
airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py:
##########
@@ -413,3 +428,106 @@ def delete_hyperparameter_tuning_job(
             metadata=metadata,
         )
         return result
+
+
+class HyperparameterTuningJobAsyncHook(GoogleBaseHook):
+    """Async hook for Google Cloud Vertex AI Hyperparameter Tuning Job APIs."""
+
+    def __init__(
+        self,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ):
+        super().__init__(
+            gcp_conn_id=gcp_conn_id,
+            impersonation_chain=impersonation_chain,
+            **kwargs,
+        )
+        self._client: JobServiceAsyncClient | None = None
+
+    def get_job_service_client(self, region: str | None = None) -> 
JobServiceAsyncClient:
+        """
+        Retrieves Vertex AI async client.
+
+        :return: Google Cloud Vertex AI client object.
+        """
+        if not self._client:
+            endpoint = f"{region}-aiplatform.googleapis.com:443" if region and 
region != "global" else None
+            self._client = JobServiceAsyncClient(
+                credentials=self.get_credentials(),
+                client_info=CLIENT_INFO,
+                client_options=ClientOptions(api_endpoint=endpoint),
+            )
+        return self._client
+
+    async def get_hyperparameter_tuning_job(
+        self,
+        project_id: str,
+        location: str,
+        job_id: str,
+        retry: AsyncRetry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> types.HyperparameterTuningJob:
+        """
+        Retrieves a hyperparameter tuning job.
+
+        :param project_id: Required. The ID of the Google Cloud project that 
the job belongs to.
+        :param location: Required. The ID of the Google Cloud region that the 
job belongs to.
+        :param job_id: Required. The hyperparameter tuning job id.
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request 
as metadata.
+        """
+        client: JobServiceAsyncClient = 
self.get_job_service_client(region=location)
+        job_name = client.hyperparameter_tuning_job_path(project_id, location, 
job_id)
+
+        result = await client.get_hyperparameter_tuning_job(
+            request={
+                "name": job_name,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        return result
+
+    async def wait_hyperparameter_tuning_job(
+        self,
+        project_id: str,
+        location: str,
+        job_id: str,
+        retry: AsyncRetry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+        poll_interval: int = 10,
+    ) -> types.HyperparameterTuningJob:
+        statuses_complete = {
+            JobState.JOB_STATE_CANCELLED,
+            JobState.JOB_STATE_FAILED,
+            JobState.JOB_STATE_PAUSED,
+            JobState.JOB_STATE_SUCCEEDED,
+        }
+        while True:
+            try:
+                self.log.info("Requesting hyperparameter tuning job with id 
%s", job_id)
+                job: types.HyperparameterTuningJob = await 
self.get_hyperparameter_tuning_job(
+                    project_id=project_id,
+                    location=location,
+                    job_id=job_id,
+                    retry=retry,
+                    timeout=timeout,
+                    metadata=metadata,
+                )
+            except Exception as ex:
+                self.log.exception("Exception occurred while requesting job 
%s", job_id)
+                raise AirflowException(ex)
+
+            self.log.info("Status of the hyperparameter tuning job %s is %s", 
job.name, job.state.name)
+            if job.state in statuses_complete:
+                return job
+            else:
+                self.log.info("Sleeping for %s seconds.", poll_interval)
+                await asyncio.sleep(poll_interval)

Review Comment:
   ```suggestion
   
               self.log.info("Sleeping for %s seconds.", poll_interval)
               await asyncio.sleep(poll_interval)
   ```



##########
airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py:
##########
@@ -413,3 +428,106 @@ def delete_hyperparameter_tuning_job(
             metadata=metadata,
         )
         return result
+
+
+class HyperparameterTuningJobAsyncHook(GoogleBaseHook):
+    """Async hook for Google Cloud Vertex AI Hyperparameter Tuning Job APIs."""
+
+    def __init__(
+        self,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ):
+        super().__init__(
+            gcp_conn_id=gcp_conn_id,
+            impersonation_chain=impersonation_chain,
+            **kwargs,
+        )
+        self._client: JobServiceAsyncClient | None = None
+
+    def get_job_service_client(self, region: str | None = None) -> 
JobServiceAsyncClient:

Review Comment:
   Should we try something like the approach used here?
https://github.com/apache/airflow/blob/c7ade012cb81158e2bc11b36febcfae82bc759ab/airflow/providers/amazon/aws/operators/batch.py#L224-L231



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

Reply via email to