This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 90233bc7cb Added impersonation_chain for DataflowStartFlexTemplateOperator and DataflowStartSqlJobOperator (#24046)
90233bc7cb is described below

commit 90233bc7cbb95d7de6e4de3b7b1206eebf5ad28c
Author: Ɓukasz Wyszomirski <[email protected]>
AuthorDate: Sat Jun 4 23:24:35 2022 +0200

    Added impersonation_chain for DataflowStartFlexTemplateOperator and DataflowStartSqlJobOperator (#24046)
---
 airflow/providers/google/cloud/hooks/dataflow.py   | 23 ++++++++++++++++++----
 .../providers/google/cloud/operators/dataflow.py   | 22 +++++++++++++++++++++
 .../providers/google/cloud/hooks/test_dataflow.py  | 17 +++-------------
 .../google/cloud/operators/test_dataflow.py        | 13 +++++++++++-
 4 files changed, 56 insertions(+), 19 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py
index 66b259838e..60ffb1e8eb 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -972,16 +972,31 @@ class DataflowHook(GoogleBaseHook):
         :param on_new_job_callback: Callback called when the job is known.
         :return: the new job object
         """
+        gcp_options = [
+            f"--project={project_id}",
+            "--format=value(job.id)",
+            f"--job-name={job_name}",
+            f"--region={location}",
+        ]
+
+        if self.impersonation_chain:
+            if isinstance(self.impersonation_chain, str):
+                impersonation_account = self.impersonation_chain
+            elif len(self.impersonation_chain) == 1:
+                impersonation_account = self.impersonation_chain[0]
+            else:
+                raise AirflowException(
+                    "Chained list of accounts is not supported, please specify 
only one service account"
+                )
+            gcp_options.append(f"--impersonate-service-account={impersonation_account}")
+
         cmd = [
             "gcloud",
             "dataflow",
             "sql",
             "query",
             query,
-            f"--project={project_id}",
-            "--format=value(job.id)",
-            f"--job-name={job_name}",
-            f"--region={location}",
+            *gcp_options,
             *(beam_options_to_args(options)),
         ]
         self.log.info("Executing command: %s", " ".join(shlex.quote(c) for c in cmd))
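
For readers unfamiliar with this code path: the SQL job is started by shelling out to "gcloud dataflow sql query", and gcloud's --impersonate-service-account flag accepts exactly one account, which is why the hook rejects a multi-account chain. A minimal standalone sketch of the option-building logic above (names are placeholders, not Airflow code):

    from typing import List, Optional, Sequence, Union

    def build_gcp_options(
        project_id: str,
        job_name: str,
        location: str,
        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
    ) -> List[str]:
        # Base flags passed to "gcloud dataflow sql query".
        options = [
            f"--project={project_id}",
            "--format=value(job.id)",
            f"--job-name={job_name}",
            f"--region={location}",
        ]
        if impersonation_chain:
            if isinstance(impersonation_chain, str):
                account = impersonation_chain
            elif len(impersonation_chain) == 1:
                account = impersonation_chain[0]
            else:
                # gcloud takes a single account here, so a longer chain
                # cannot be forwarded to the CLI.
                raise ValueError("Chained list of accounts is not supported")
            options.append(f"--impersonate-service-account={account}")
        return options

    print(build_gcp_options("my-project", "my-job", "us-central1",
                            "sa@my-project.iam.gserviceaccount.com")[-1])
    # --impersonate-service-account=sa@my-project.iam.gserviceaccount.com
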
diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py
index 98a17d800c..677bc94940 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -727,6 +727,14 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
 
         If you in your pipeline do not call the wait_for_pipeline method, and pass wait_until_finish=False
         to the operator, the second loop will check once is job not in terminal state and exit the loop.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
     """
 
     template_fields: Sequence[str] = ("body", "location", "project_id", "gcp_conn_id")
@@ -742,6 +750,7 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
         drain_pipeline: bool = False,
         cancel_timeout: Optional[int] = 10 * 60,
         wait_until_finished: Optional[bool] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
         *args,
         **kwargs,
     ) -> None:
@@ -756,6 +765,7 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
         self.wait_until_finished = wait_until_finished
         self.job = None
         self.hook: Optional[DataflowHook] = None
+        self.impersonation_chain = impersonation_chain
 
     def execute(self, context: 'Context'):
         self.hook = DataflowHook(
@@ -764,6 +774,7 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
             drain_pipeline=self.drain_pipeline,
             cancel_timeout=self.cancel_timeout,
             wait_until_finished=self.wait_until_finished,
+            impersonation_chain=self.impersonation_chain,
         )
 
         def set_current_job(current_job):
@@ -821,6 +832,14 @@ class DataflowStartSqlJobOperator(BaseOperator):
     :param drain_pipeline: Optional, set to True if want to stop streaming job by draining it
         instead of canceling during killing task instance. See:
         https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
     """
 
     template_fields: Sequence[str] = (
@@ -843,6 +862,7 @@ class DataflowStartSqlJobOperator(BaseOperator):
         gcp_conn_id: str = "google_cloud_default",
         delegate_to: Optional[str] = None,
         drain_pipeline: bool = False,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
         *args,
         **kwargs,
     ) -> None:
@@ -855,6 +875,7 @@ class DataflowStartSqlJobOperator(BaseOperator):
         self.gcp_conn_id = gcp_conn_id
         self.delegate_to = delegate_to
         self.drain_pipeline = drain_pipeline
+        self.impersonation_chain = impersonation_chain
         self.job = None
         self.hook: Optional[DataflowHook] = None
 
@@ -863,6 +884,7 @@ class DataflowStartSqlJobOperator(BaseOperator):
             gcp_conn_id=self.gcp_conn_id,
             delegate_to=self.delegate_to,
             drain_pipeline=self.drain_pipeline,
+            impersonation_chain=self.impersonation_chain,
         )
 
         def set_current_job(current_job):
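
For anyone picking this up in a DAG, a hypothetical usage snippet for the new parameter; the project, bucket, job and service-account names below are placeholders, not part of this commit:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.dataflow import (
        DataflowStartFlexTemplateOperator,
    )

    with DAG(
        dag_id="dataflow_flex_impersonation",
        start_date=datetime(2022, 6, 1),
        schedule_interval=None,
    ) as dag:
        start_flex = DataflowStartFlexTemplateOperator(
            task_id="start_flex_template",
            project_id="my-project",
            location="us-central1",
            body={
                "launchParameter": {
                    "containerSpecGcsPath": "gs://my-bucket/templates/spec.json",
                    "jobName": "my-flex-job",
                    "parameters": {},
                }
            },
            # A single account, or a chain in which each account holds the
            # Service Account Token Creator role on the next, per the docstring.
            impersonation_chain="dataflow-runner@my-project.iam.gserviceaccount.com",
        )
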
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py
index 98aeebe2fb..b24eeb566b 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -173,20 +173,10 @@ class TestFallbackToVariables(unittest.TestCase):
             FixtureFallback().test_fn({'project': "TEST"}, "TEST2")
 
 
-def mock_init(
-    self,
-    gcp_conn_id,
-    delegate_to=None,
-    impersonation_chain=None,
-):
-    pass
-
-
 class TestDataflowHook(unittest.TestCase):
     def setUp(self):
-        with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), new=mock_init):
-            self.dataflow_hook = DataflowHook(gcp_conn_id='test')
-            self.dataflow_hook.beam_hook = MagicMock()
+        self.dataflow_hook = DataflowHook(gcp_conn_id='google_cloud_default')
+        self.dataflow_hook.beam_hook = MagicMock()
 
     @mock.patch("airflow.providers.google.cloud.hooks.dataflow.DataflowHook._authorize")
     @mock.patch("airflow.providers.google.cloud.hooks.dataflow.build")
@@ -792,8 +782,7 @@ class TestDataflowHook(unittest.TestCase):
 
 class TestDataflowTemplateHook(unittest.TestCase):
     def setUp(self):
-        with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), new=mock_init):
-            self.dataflow_hook = DataflowHook(gcp_conn_id='test')
+        self.dataflow_hook = DataflowHook(gcp_conn_id='google_cloud_default')
 
     @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID)
     @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
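
Side note on the test change: DataflowHook now forwards impersonation_chain to GoogleBaseHook.__init__, so the hand-rolled mock_init stub (whose signature would keep drifting) is no longer needed and the hook can be constructed directly. A hedged sketch, with a placeholder account:

    from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

    # 'google_cloud_default' is Airflow's stock connection id; the account
    # below is a placeholder, not taken from this commit.
    hook = DataflowHook(
        gcp_conn_id="google_cloud_default",
        impersonation_chain="sa@my-project.iam.gserviceaccount.com",
    )
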
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py
index 333bf7c272..2cc37cb589 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -496,6 +496,14 @@ class TestDataflowStartFlexTemplateOperator(unittest.TestCase):
             location=TEST_LOCATION,
         )
         start_flex_template.execute(mock.MagicMock())
+        mock_dataflow.assert_called_once_with(
+            gcp_conn_id='google_cloud_default',
+            delegate_to=None,
+            drain_pipeline=False,
+            cancel_timeout=600,
+            wait_until_finished=None,
+            impersonation_chain=None,
+        )
         mock_dataflow.return_value.start_flex_template.assert_called_once_with(
             body={"launchParameter": TEST_FLEX_PARAMETERS},
             location=TEST_LOCATION,
@@ -533,7 +541,10 @@ class TestDataflowSqlOperator(unittest.TestCase):
 
         start_sql.execute(mock.MagicMock())
         mock_hook.assert_called_once_with(
-            gcp_conn_id='google_cloud_default', delegate_to=None, drain_pipeline=False
+            gcp_conn_id='google_cloud_default',
+            delegate_to=None,
+            drain_pipeline=False,
+            impersonation_chain=None,
         )
         mock_hook.return_value.start_sql_job.assert_called_once_with(
             job_name=TEST_SQL_JOB_NAME,

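One caveat worth illustrating: because the SQL path goes through gcloud, passing a chain of two or more accounts to DataflowStartSqlJobOperator fails at execute() time rather than at construction. Hypothetical example (placeholder accounts):

    from airflow.providers.google.cloud.operators.dataflow import (
        DataflowStartSqlJobOperator,
    )

    start_sql = DataflowStartSqlJobOperator(
        task_id="start_sql",
        job_name="my-sql-job",
        query="SELECT 1",
        options={},
        location="us-central1",
        impersonation_chain=[
            "first@my-project.iam.gserviceaccount.com",
            "second@my-project.iam.gserviceaccount.com",
        ],
    )
    # start_sql.execute(context) would raise AirflowException:
    # "Chained list of accounts is not supported, please specify only one service account"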