This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 90233bc7cb Added impersonation_chain for DataflowStartFlexTemplateOperator and DataflowStartSqlJobOperator (#24046)
90233bc7cb is described below
commit 90233bc7cbb95d7de6e4de3b7b1206eebf5ad28c
Author: Łukasz Wyszomirski <[email protected]>
AuthorDate: Sat Jun 4 23:24:35 2022 +0200
Added impersonation_chain for DataflowStartFlexTemplateOperator and DataflowStartSqlJobOperator (#24046)
---
airflow/providers/google/cloud/hooks/dataflow.py | 23 ++++++++++++++++++----
.../providers/google/cloud/operators/dataflow.py | 22 +++++++++++++++++++++
.../providers/google/cloud/hooks/test_dataflow.py | 17 +++-------------
.../google/cloud/operators/test_dataflow.py | 13 +++++++++++-
4 files changed, 56 insertions(+), 19 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py
index 66b259838e..60ffb1e8eb 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -972,16 +972,31 @@ class DataflowHook(GoogleBaseHook):
:param on_new_job_callback: Callback called when the job is known.
:return: the new job object
"""
+ gcp_options = [
+ f"--project={project_id}",
+ "--format=value(job.id)",
+ f"--job-name={job_name}",
+ f"--region={location}",
+ ]
+
+ if self.impersonation_chain:
+ if isinstance(self.impersonation_chain, str):
+ impersonation_account = self.impersonation_chain
+ elif len(self.impersonation_chain) == 1:
+ impersonation_account = self.impersonation_chain[0]
+ else:
+ raise AirflowException(
+ "Chained list of accounts is not supported, please specify only one service account"
+ )
+
+ gcp_options.append(f"--impersonate-service-account={impersonation_account}")
+
cmd = [
"gcloud",
"dataflow",
"sql",
"query",
query,
- f"--project={project_id}",
- "--format=value(job.id)",
- f"--job-name={job_name}",
- f"--region={location}",
+ *gcp_options,
*(beam_options_to_args(options)),
]
self.log.info("Executing command: %s", " ".join(shlex.quote(c) for c in cmd))
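The hunk above collapses the impersonation chain to a single account before appending it as a gcloud flag. A minimal standalone sketch of that branch logic (illustrative only, not part of the commit; the helper name resolve_impersonation_account is hypothetical):

    from typing import Sequence, Union

    def resolve_impersonation_account(chain: Union[str, Sequence[str]]) -> str:
        # Mirrors the branch added above: a plain string or a one-element
        # sequence collapses to a single account; longer chains are rejected,
        # since the hook passes exactly one account to the
        # --impersonate-service-account flag.
        if isinstance(chain, str):
            return chain
        if len(chain) == 1:
            return chain[0]
        raise ValueError("Chained list of accounts is not supported, please specify only one service account")

    assert resolve_impersonation_account("sa@proj.iam.gserviceaccount.com") == "sa@proj.iam.gserviceaccount.com"
    assert resolve_impersonation_account(["sa@proj.iam.gserviceaccount.com"]) == "sa@proj.iam.gserviceaccount.com"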
diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py
index 98a17d800c..677bc94940 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -727,6 +727,14 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
If you in your pipeline do not call the wait_for_pipeline method, and pass wait_until_finish=False
to the operator, the second loop will check once is job not in terminal state and exit the loop.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
"""
template_fields: Sequence[str] = ("body", "location", "project_id", "gcp_conn_id")
@@ -742,6 +750,7 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
drain_pipeline: bool = False,
cancel_timeout: Optional[int] = 10 * 60,
wait_until_finished: Optional[bool] = None,
+ impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
*args,
**kwargs,
) -> None:
@@ -756,6 +765,7 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
self.wait_until_finished = wait_until_finished
self.job = None
self.hook: Optional[DataflowHook] = None
+ self.impersonation_chain = impersonation_chain
def execute(self, context: 'Context'):
self.hook = DataflowHook(
@@ -764,6 +774,7 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
drain_pipeline=self.drain_pipeline,
cancel_timeout=self.cancel_timeout,
wait_until_finished=self.wait_until_finished,
+ impersonation_chain=self.impersonation_chain,
)
def set_current_job(current_job):
@@ -821,6 +832,14 @@ class DataflowStartSqlJobOperator(BaseOperator):
:param drain_pipeline: Optional, set to True if want to stop streaming job by draining it
instead of canceling during killing task instance. See:
https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
"""
template_fields: Sequence[str] = (
@@ -843,6 +862,7 @@ class DataflowStartSqlJobOperator(BaseOperator):
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
drain_pipeline: bool = False,
+ impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
*args,
**kwargs,
) -> None:
@@ -855,6 +875,7 @@ class DataflowStartSqlJobOperator(BaseOperator):
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.drain_pipeline = drain_pipeline
+ self.impersonation_chain = impersonation_chain
self.job = None
self.hook: Optional[DataflowHook] = None
@@ -863,6 +884,7 @@ class DataflowStartSqlJobOperator(BaseOperator):
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
drain_pipeline=self.drain_pipeline,
+ impersonation_chain=self.impersonation_chain,
)
def set_current_job(current_job):
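With both operators now forwarding the parameter to DataflowHook, usage looks roughly like this hedged sketch (the DAG id, location, body, and account values are placeholders, not from the commit):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.dataflow import (
        DataflowStartFlexTemplateOperator,
    )

    with DAG(
        dag_id="example_dataflow_impersonation",  # placeholder
        start_date=datetime(2022, 6, 1),
        schedule_interval=None,
    ) as dag:
        DataflowStartFlexTemplateOperator(
            task_id="start_flex_template",
            location="us-central1",  # placeholder
            body={"launchParameter": {"jobName": "example-job"}},  # placeholder
            # New in this commit: forwarded to DataflowHook.
            impersonation_chain="sa@my-project.iam.gserviceaccount.com",
        )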
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py
index 98aeebe2fb..b24eeb566b 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -173,20 +173,10 @@ class TestFallbackToVariables(unittest.TestCase):
FixtureFallback().test_fn({'project': "TEST"}, "TEST2")
-def mock_init(
- self,
- gcp_conn_id,
- delegate_to=None,
- impersonation_chain=None,
-):
- pass
-
-
class TestDataflowHook(unittest.TestCase):
def setUp(self):
- with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), new=mock_init):
- self.dataflow_hook = DataflowHook(gcp_conn_id='test')
- self.dataflow_hook.beam_hook = MagicMock()
+ self.dataflow_hook = DataflowHook(gcp_conn_id='google_cloud_default')
+ self.dataflow_hook.beam_hook = MagicMock()
@mock.patch("airflow.providers.google.cloud.hooks.dataflow.DataflowHook._authorize")
@mock.patch("airflow.providers.google.cloud.hooks.dataflow.build")
@@ -792,8 +782,7 @@ class TestDataflowHook(unittest.TestCase):
class TestDataflowTemplateHook(unittest.TestCase):
def setUp(self):
- with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), new=mock_init):
- self.dataflow_hook = DataflowHook(gcp_conn_id='test')
+ self.dataflow_hook = DataflowHook(gcp_conn_id='google_cloud_default')
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID)
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py
index 333bf7c272..2cc37cb589 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -496,6 +496,14 @@ class TestDataflowStartFlexTemplateOperator(unittest.TestCase):
location=TEST_LOCATION,
)
start_flex_template.execute(mock.MagicMock())
+ mock_dataflow.assert_called_once_with(
+ gcp_conn_id='google_cloud_default',
+ delegate_to=None,
+ drain_pipeline=False,
+ cancel_timeout=600,
+ wait_until_finished=None,
+ impersonation_chain=None,
+ )
mock_dataflow.return_value.start_flex_template.assert_called_once_with(
body={"launchParameter": TEST_FLEX_PARAMETERS},
location=TEST_LOCATION,
@@ -533,7 +541,10 @@ class TestDataflowSqlOperator(unittest.TestCase):
start_sql.execute(mock.MagicMock())
mock_hook.assert_called_once_with(
- gcp_conn_id='google_cloud_default', delegate_to=None, drain_pipeline=False
+ gcp_conn_id='google_cloud_default',
+ delegate_to=None,
+ drain_pipeline=False,
+ impersonation_chain=None,
)
mock_hook.return_value.start_sql_job.assert_called_once_with(
job_name=TEST_SQL_JOB_NAME,
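One possible follow-up assertion in the same style as the tests above, exercising a non-None chain (illustrative only; this test is not part of the commit):

    from unittest import mock

    from airflow.providers.google.cloud.operators.dataflow import (
        DataflowStartSqlJobOperator,
    )

    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
    def test_sql_operator_forwards_impersonation_chain(mock_hook):
        op = DataflowStartSqlJobOperator(
            task_id="start_sql",
            job_name="test-job",  # placeholder
            query="SELECT 1",     # placeholder
            options={},           # placeholder
            impersonation_chain="sa@my-project.iam.gserviceaccount.com",
        )
        op.execute(mock.MagicMock())
        # The chain should be passed through to the hook unchanged.
        mock_hook.assert_called_once_with(
            gcp_conn_id='google_cloud_default',
            delegate_to=None,
            drain_pipeline=False,
            impersonation_chain="sa@my-project.iam.gserviceaccount.com",
        )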