This is an automated email from the ASF dual-hosted git repository. kamilbregula pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 41e94b475e Support impersonation service account parameter for Dataflow runner (#23961) 41e94b475e is described below commit 41e94b475e06f63db39b0943c9d9a7476367083c Author: Łukasz Wyszomirski <wyszomir...@google.com> AuthorDate: Tue May 31 10:21:38 2022 +0200 Support impersonation service account parameter for Dataflow runner (#23961) --- airflow/providers/apache/beam/operators/beam.py | 15 +++++++++++++++ setup.py | 2 +- tests/providers/apache/beam/operators/test_beam.py | 9 ++++++--- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py index 082a9a7f6e..e7f7af5e23 100644 --- a/airflow/providers/apache/beam/operators/beam.py +++ b/airflow/providers/apache/beam/operators/beam.py @@ -51,6 +51,7 @@ class BeamDataflowMixin(metaclass=ABCMeta): dataflow_config: DataflowConfiguration gcp_conn_id: str delegate_to: Optional[str] + dataflow_support_impersonation: bool = True def _set_dataflow( self, @@ -91,6 +92,13 @@ class BeamDataflowMixin(metaclass=ABCMeta): pipeline_options[job_name_key] = job_name if self.dataflow_config.service_account: pipeline_options["serviceAccount"] = self.dataflow_config.service_account + if self.dataflow_support_impersonation and self.dataflow_config.impersonation_chain: + if isinstance(self.dataflow_config.impersonation_chain, list): + pipeline_options["impersonateServiceAccount"] = ",".join( + self.dataflow_config.impersonation_chain + ) + else: + pipeline_options["impersonateServiceAccount"] = self.dataflow_config.impersonation_chain pipeline_options["project"] = self.dataflow_config.project_id pipeline_options["region"] = self.dataflow_config.location pipeline_options.setdefault("labels", {}).update( @@ -550,6 +558,13 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator): **kwargs, ) + if self.dataflow_config.impersonation_chain: + self.log.info( + 
"Impersonation chain parameter is not supported for Apache Beam GO SDK and will be skipped " + "in the execution" + ) + self.dataflow_support_impersonation = False + self.go_file = go_file self.should_init_go_module = False self.pipeline_options.setdefault("labels", {}).update( diff --git a/setup.py b/setup.py index 3094dca782..e73a88d6e7 100644 --- a/setup.py +++ b/setup.py @@ -202,7 +202,7 @@ amazon = [ 'mypy-boto3-redshift-data>=1.21.0', ] apache_beam = [ - 'apache-beam>=2.33.0', + 'apache-beam>=2.39.0', ] arangodb = ['python-arango>=7.3.2'] asana = ['asana>=0.10'] diff --git a/tests/providers/apache/beam/operators/test_beam.py b/tests/providers/apache/beam/operators/test_beam.py index 6fd83d7268..23700dc71b 100644 --- a/tests/providers/apache/beam/operators/test_beam.py +++ b/tests/providers/apache/beam/operators/test_beam.py @@ -47,6 +47,7 @@ EXPECTED_ADDITIONAL_OPTIONS = { 'output': 'gs://test/output', 'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION}, } +TEST_IMPERSONATION_ACCOUNT = "t...@impersonation.com" class TestBeamRunPythonPipelineOperator(unittest.TestCase): @@ -104,7 +105,7 @@ class TestBeamRunPythonPipelineOperator(unittest.TestCase): """Test DataflowHook is created and the right args are passed to start_python_dataflow. 
""" - dataflow_config = DataflowConfiguration() + dataflow_config = DataflowConfiguration(impersonation_chain=TEST_IMPERSONATION_ACCOUNT) self.operator.runner = "DataflowRunner" self.operator.dataflow_config = dataflow_config gcs_provide_file = gcs_hook.return_value.provide_file @@ -126,6 +127,7 @@ class TestBeamRunPythonPipelineOperator(unittest.TestCase): 'output': 'gs://test/output', 'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION}, 'region': 'us-central1', + 'impersonate_service_account': TEST_IMPERSONATION_ACCOUNT, } gcs_provide_file.assert_called_once_with(object_url=PY_FILE) persist_link_mock.assert_called_once_with( @@ -223,7 +225,7 @@ class TestBeamRunJavaPipelineOperator(unittest.TestCase): """Test DataflowHook is created and the right args are passed to start_java_dataflow. """ - dataflow_config = DataflowConfiguration() + dataflow_config = DataflowConfiguration(impersonation_chain="t...@impersonation.com") self.operator.runner = "DataflowRunner" self.operator.dataflow_config = dataflow_config gcs_provide_file = gcs_hook.return_value.provide_file @@ -248,6 +250,7 @@ class TestBeamRunJavaPipelineOperator(unittest.TestCase): 'region': 'us-central1', 'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION}, 'output': 'gs://test/output', + 'impersonateServiceAccount': TEST_IMPERSONATION_ACCOUNT, } persist_link_mock.assert_called_once_with( self.operator, @@ -374,7 +377,7 @@ class TestBeamRunGoPipelineOperator(unittest.TestCase): """Test DataflowHook is created and the right args are passed to start_go_dataflow. """ - dataflow_config = DataflowConfiguration() + dataflow_config = DataflowConfiguration(impersonation_chain="t...@impersonation.com") self.operator.runner = "DataflowRunner" self.operator.dataflow_config = dataflow_config gcs_provide_file = gcs_hook.return_value.provide_file