This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 46666af9ec Refactor CreateHyperparameterTuningJobOperator (#37938)
46666af9ec is described below
commit 46666af9ecc0f183d7bf0845a646f24fbd91c697
Author: max <[email protected]>
AuthorDate: Thu Mar 7 23:52:22 2024 +0100
Refactor CreateHyperparameterTuningJobOperator (#37938)
---
.../vertex_ai/hyperparameter_tuning_job.py | 77 ++++++++--------------
.../google/cloud/operators/test_vertex_ai.py | 25 ++-----
2 files changed, 33 insertions(+), 69 deletions(-)
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py b/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py
index 5e5e20aa9f..2da2bf3394 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py
@@ -20,14 +20,15 @@
from __future__ import annotations
+import warnings
from typing import TYPE_CHECKING, Any, Sequence
from google.api_core.exceptions import NotFound
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
-from google.cloud.aiplatform_v1.types import HyperparameterTuningJob
+from google.cloud.aiplatform_v1 import types
from airflow.configuration import conf
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job
import (
HyperparameterTuningJobHook,
)
@@ -40,7 +41,7 @@ from airflow.providers.google.cloud.triggers.vertex_ai import
CreateHyperparamet
if TYPE_CHECKING:
from google.api_core.retry import Retry
- from google.cloud.aiplatform import gapic, hyperparameter_tuning
+ from google.cloud.aiplatform import HyperparameterTuningJob, gapic,
hyperparameter_tuning
from airflow.utils.context import Context
@@ -127,8 +128,8 @@ class
CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
`service_account` is required with provided `tensorboard`. For more
information on configuring
your service account please visit:
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
- :param sync: Whether to execute this method synchronously. If False, this
method will unblock, and it
- will be executed in a concurrent Future.
+ :param sync: (Deprecated) Whether to execute this method synchronously. If
False, this method will
+ unblock, and it will be executed in a concurrent Future.
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using
short-term
credentials, or chained list of accounts required to get the
access_token
@@ -138,8 +139,7 @@ class
CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding
identity, with first
account from the list granting this role to the originating account
(templated).
- :param deferrable: Run operator in the deferrable mode. Note that it
requires calling the operator
- with `sync=False` parameter.
+ :param deferrable: Run operator in the deferrable mode.
:param poll_interval: Interval size which defines how often job status is
checked in deferrable mode.
"""
@@ -221,19 +221,18 @@ class
CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
self.poll_interval = poll_interval
def execute(self, context: Context):
- if self.deferrable and self.sync:
- raise AirflowException(
- "Deferrable mode can be used only with sync=False option. "
- "If you are willing to run the operator in deferrable mode,
please, set sync=False. "
- "Otherwise, disable deferrable mode `deferrable=False`."
- )
+ warnings.warn(
+ "The 'sync' parameter is deprecated and will be removed after 01.09.2024.",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
self.log.info("Creating Hyperparameter Tuning job")
self.hook = HyperparameterTuningJobHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
- result = self.hook.create_hyperparameter_tuning_job(
+ hyperparameter_tuning_job: HyperparameterTuningJob =
self.hook.create_hyperparameter_tuning_job(
project_id=self.project_id,
region=self.region,
display_name=self.display_name,
@@ -259,14 +258,19 @@ class
CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
restart_job_on_worker_restart=self.restart_job_on_worker_restart,
enable_web_access=self.enable_web_access,
tensorboard=self.tensorboard,
- sync=self.sync,
- wait_job_completed=not self.deferrable,
+ sync=False,
+ wait_job_completed=False,
)
- hyperparameter_tuning_job = result.to_dict()
- hyperparameter_tuning_job_id =
self.hook.extract_hyperparameter_tuning_job_id(
- hyperparameter_tuning_job
+ hyperparameter_tuning_job.wait_for_resource_creation()
+ hyperparameter_tuning_job_id = hyperparameter_tuning_job.name
+ self.log.info("Hyperparameter Tuning job was created. Job id: %s",
hyperparameter_tuning_job_id)
+
+ self.xcom_push(context, key="hyperparameter_tuning_job_id",
value=hyperparameter_tuning_job_id)
+ VertexAITrainingLink.persist(
+ context=context, task_instance=self,
training_id=hyperparameter_tuning_job_id
)
+
if self.deferrable:
self.defer(
trigger=CreateHyperparameterTuningJobTrigger(
@@ -279,14 +283,10 @@ class
CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
),
method_name="execute_complete",
)
+ return
- self.log.info("Hyperparameter Tuning job was created. Job id: %s",
hyperparameter_tuning_job_id)
-
- self.xcom_push(context, key="hyperparameter_tuning_job_id",
value=hyperparameter_tuning_job_id)
- VertexAITrainingLink.persist(
- context=context, task_instance=self,
training_id=hyperparameter_tuning_job_id
- )
- return hyperparameter_tuning_job
+ hyperparameter_tuning_job.wait_for_completion()
+ return hyperparameter_tuning_job.to_dict()
def on_kill(self) -> None:
"""Act as a callback called when the operator is killed; cancel any
running job."""
@@ -298,26 +298,7 @@ class
CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
raise AirflowException(event["message"])
job: dict[str, Any] = event["job"]
self.log.info("Hyperparameter tuning job %s created and completed
successfully.", job["name"])
- hook = HyperparameterTuningJobHook(
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- )
- job_id = hook.extract_hyperparameter_tuning_job_id(job)
- self.xcom_push(
- context,
- key="hyperparameter_tuning_job_id",
- value=job_id,
- )
- self.xcom_push(
- context,
- key="training_conf",
- value={
- "training_conf_id": job_id,
- "region": self.region,
- "project_id": self.project_id,
- },
- )
- return event["job"]
+ return job
class GetHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
@@ -387,7 +368,7 @@ class
GetHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
context=context, task_instance=self,
training_id=self.hyperparameter_tuning_job_id
)
self.log.info("Hyperparameter tuning job was gotten.")
- return HyperparameterTuningJob.to_dict(result)
+ return types.HyperparameterTuningJob.to_dict(result)
except NotFound:
self.log.info(
"The Hyperparameter tuning job %s does not exist.",
self.hyperparameter_tuning_job_id
@@ -532,4 +513,4 @@ class
ListHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
metadata=self.metadata,
)
VertexAIHyperparameterTuningJobListLink.persist(context=context,
task_instance=self)
- return [HyperparameterTuningJob.to_dict(result) for result in results]
+ return [types.HyperparameterTuningJob.to_dict(result) for result in
results]
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py
index 0c52ee7379..57864a0a02 100644
--- a/tests/providers/google/cloud/operators/test_vertex_ai.py
+++ b/tests/providers/google/cloud/operators/test_vertex_ai.py
@@ -1419,7 +1419,7 @@ class TestVertexAIUndeployModelOperator:
class TestVertexAICreateHyperparameterTuningJobOperator:
- @mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.HyperparameterTuningJob.to_dict"))
+ @mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.types.HyperparameterTuningJob.to_dict"))
@mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.HyperparameterTuningJobHook"))
def test_execute(self, mock_hook, to_dict_mock):
op = CreateHyperparameterTuningJobOperator(
@@ -1464,7 +1464,7 @@ class TestVertexAICreateHyperparameterTuningJobOperator:
enable_web_access=False,
tensorboard=None,
sync=False,
- wait_job_completed=True,
+ wait_job_completed=False,
)
@mock.patch(
@@ -1511,11 +1511,8 @@ class TestVertexAICreateHyperparameterTuningJobOperator:
with pytest.raises(AirflowException):
op.execute(context={"ti": mock.MagicMock()})
- @mock.patch(
-
VERTEX_AI_PATH.format("hyperparameter_tuning_job.CreateHyperparameterTuningJobOperator.xcom_push")
- )
@mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.HyperparameterTuningJobHook"))
- def test_execute_complete(self, mock_hook, mock_xcom_push):
+ def test_execute_complete(self, mock_hook):
test_job_id = "test_job_id"
test_job = {"name": f"test/{test_job_id}"}
event = {
@@ -1544,20 +1541,6 @@ class TestVertexAICreateHyperparameterTuningJobOperator:
result = op.execute_complete(context=mock_context, event=event)
- mock_xcom_push.assert_has_calls(
- [
- call(mock_context, key="hyperparameter_tuning_job_id",
value=test_job_id),
- call(
- mock_context,
- key="training_conf",
- value={
- "training_conf_id": test_job_id,
- "region": GCP_LOCATION,
- "project_id": GCP_PROJECT,
- },
- ),
- ]
- )
assert result == test_job
def test_execute_complete_error(self):
@@ -1587,7 +1570,7 @@ class TestVertexAICreateHyperparameterTuningJobOperator:
class TestVertexAIGetHyperparameterTuningJobOperator:
- @mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.HyperparameterTuningJob.to_dict"))
+ @mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.types.HyperparameterTuningJob.to_dict"))
@mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.HyperparameterTuningJobHook"))
def test_execute(self, mock_hook, to_dict_mock):
op = GetHyperparameterTuningJobOperator(