This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f2ebc292fe Provide missing project id and creds for TabularDataset (#31991)
f2ebc292fe is described below

commit f2ebc292fe63d2ddd0686d90c3acc0630f017a07
Author: Hussein Awala <[email protected]>
AuthorDate: Mon Jun 19 05:53:05 2023 +0200

    Provide missing project id and creds for TabularDataset (#31991)
---
 .../providers/google/cloud/operators/vertex_ai/auto_ml.py    |  7 ++++++-
 tests/providers/google/cloud/operators/test_vertex_ai.py     | 12 ++++++++++--
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
index 80e573a1d6..a3bcb2158d 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
@@ -352,11 +352,16 @@ class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
+        credentials, _ = self.hook.get_credentials_and_project_id()
         model, training_id = self.hook.create_auto_ml_tabular_training_job(
             project_id=self.project_id,
             region=self.region,
             display_name=self.display_name,
-            dataset=datasets.TabularDataset(dataset_name=self.dataset_id),
+            dataset=datasets.TabularDataset(
+                dataset_name=self.dataset_id,
+                project=self.project_id,
+                credentials=credentials,
+            ),
             target_column=self.target_column,
             optimization_prediction_type=self.optimization_prediction_type,
             optimization_objective=self.optimization_objective,
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py
index f4b3ad154d..f2921942ab 100644
--- a/tests/providers/google/cloud/operators/test_vertex_ai.py
+++ b/tests/providers/google/cloud/operators/test_vertex_ai.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 from unittest import mock
+from unittest.mock import MagicMock
 
 from google.api_core.gapic_v1.method import DEFAULT
 from google.api_core.retry import Retry
@@ -783,7 +784,12 @@ class TestVertexAICreateAutoMLTabularTrainingJobOperator:
     @mock.patch("google.cloud.aiplatform.datasets.TabularDataset")
     @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
     def test_execute(self, mock_hook, mock_dataset):
-        mock_hook.return_value.create_auto_ml_tabular_training_job.return_value = (None, "training_id")
+        mock_hook.return_value = MagicMock(
+            **{
+                "create_auto_ml_tabular_training_job.return_value": (None, "training_id"),
+                "get_credentials_and_project_id.return_value": ("creds", "project_id"),
+            }
+        )
         op = CreateAutoMLTabularTrainingJobOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -798,7 +804,9 @@ class TestVertexAICreateAutoMLTabularTrainingJobOperator:
         )
         op.execute(context={"ti": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
-        mock_dataset.assert_called_once_with(dataset_name=TEST_DATASET_ID)
+        mock_dataset.assert_called_once_with(
+            dataset_name=TEST_DATASET_ID, project=GCP_PROJECT, credentials="creds"
+        )
         mock_hook.return_value.create_auto_ml_tabular_training_job.assert_called_once_with(
             project_id=GCP_PROJECT,
             region=GCP_LOCATION,

Reply via email to