This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f2ebc292fe Provide missing project id and creds for TabularDataset (#31991)
f2ebc292fe is described below
commit f2ebc292fe63d2ddd0686d90c3acc0630f017a07
Author: Hussein Awala <[email protected]>
AuthorDate: Mon Jun 19 05:53:05 2023 +0200
Provide missing project id and creds for TabularDataset (#31991)
---
.../providers/google/cloud/operators/vertex_ai/auto_ml.py | 7 ++++++-
tests/providers/google/cloud/operators/test_vertex_ai.py | 12 ++++++++++--
2 files changed, 16 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
index 80e573a1d6..a3bcb2158d 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
@@ -352,11 +352,16 @@ class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
+ credentials, _ = self.hook.get_credentials_and_project_id()
model, training_id = self.hook.create_auto_ml_tabular_training_job(
project_id=self.project_id,
region=self.region,
display_name=self.display_name,
- dataset=datasets.TabularDataset(dataset_name=self.dataset_id),
+ dataset=datasets.TabularDataset(
+ dataset_name=self.dataset_id,
+ project=self.project_id,
+ credentials=credentials,
+ ),
target_column=self.target_column,
optimization_prediction_type=self.optimization_prediction_type,
optimization_objective=self.optimization_objective,
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py
index f4b3ad154d..f2921942ab 100644
--- a/tests/providers/google/cloud/operators/test_vertex_ai.py
+++ b/tests/providers/google/cloud/operators/test_vertex_ai.py
@@ -17,6 +17,7 @@
from __future__ import annotations
from unittest import mock
+from unittest.mock import MagicMock
from google.api_core.gapic_v1.method import DEFAULT
from google.api_core.retry import Retry
@@ -783,7 +784,12 @@ class TestVertexAICreateAutoMLTabularTrainingJobOperator:
@mock.patch("google.cloud.aiplatform.datasets.TabularDataset")
@mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
def test_execute(self, mock_hook, mock_dataset):
-
mock_hook.return_value.create_auto_ml_tabular_training_job.return_value =
(None, "training_id")
+ mock_hook.return_value = MagicMock(
+ **{
+ "create_auto_ml_tabular_training_job.return_value": (None,
"training_id"),
+ "get_credentials_and_project_id.return_value": ("creds",
"project_id"),
+ }
+ )
op = CreateAutoMLTabularTrainingJobOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -798,7 +804,9 @@ class TestVertexAICreateAutoMLTabularTrainingJobOperator:
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
- mock_dataset.assert_called_once_with(dataset_name=TEST_DATASET_ID)
+ mock_dataset.assert_called_once_with(
+ dataset_name=TEST_DATASET_ID, project=GCP_PROJECT,
credentials="creds"
+ )
mock_hook.return_value.create_auto_ml_tabular_training_job.assert_called_once_with(
project_id=GCP_PROJECT,
region=GCP_LOCATION,