This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new fd8a05739f Bugfix to correct GCSHook being called even when not required with BeamRunPythonPipelineOperator (#38716) fd8a05739f is described below commit fd8a05739f945643b5023db15d51a97459109a02 Author: Zack Strathe <59071005+zstra...@users.noreply.github.com> AuthorDate: Fri Apr 19 03:40:19 2024 -0500 Bugfix to correct GCSHook being called even when not required with BeamRunPythonPipelineOperator (#38716) * Bugfix to correct GCSHook being called even when not required with BeamRunPythonPipelineOperator * remove unnecessary check for GCSHook and add unit test for BeamRunPythonPipelineOperator to ensure that GCSHook is only called when necessary * Split out unit tests for TestBeamRunPythonPipelineOperator with GCSHook 'gs://' arg prefixes * Fix formatting --- airflow/providers/apache/beam/operators/beam.py | 3 +- tests/providers/apache/beam/operators/test_beam.py | 73 ++++++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py index e88923bc05..62f650f19a 100644 --- a/airflow/providers/apache/beam/operators/beam.py +++ b/airflow/providers/apache/beam/operators/beam.py @@ -364,11 +364,12 @@ class BeamRunPythonPipelineOperator(BeamBasePipelineOperator): def execute_sync(self, context: Context): with ExitStack() as exit_stack: - gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id) if self.py_file.lower().startswith("gs://"): + gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id) tmp_gcs_file = exit_stack.enter_context(gcs_hook.provide_file(object_url=self.py_file)) self.py_file = tmp_gcs_file.name if self.snake_case_pipeline_options.get("requirements_file", "").startswith("gs://"): + gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id) tmp_req_file = exit_stack.enter_context( gcs_hook.provide_file(object_url=self.snake_case_pipeline_options["requirements_file"]) ) diff --git 
a/tests/providers/apache/beam/operators/test_beam.py b/tests/providers/apache/beam/operators/test_beam.py index f7ca9649fb..a6a4c31c77 100644 --- a/tests/providers/apache/beam/operators/test_beam.py +++ b/tests/providers/apache/beam/operators/test_beam.py @@ -256,6 +256,79 @@ class TestBeamRunPythonPipelineOperator: op.on_kill() dataflow_cancel_job.assert_not_called() + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + def test_execute_gcs_hook_not_called_without_gs_prefix(self, mock_gcs_hook, _): + """ + Test that execute method does not call GCSHook when neither py_file nor requirements_file + starts with 'gs://'. (i.e., running pipeline entirely locally) + """ + local_test_op_args = { + "task_id": TASK_ID, + "py_file": "local_file.py", + "py_options": ["-m"], + "default_pipeline_options": { + "project": TEST_PROJECT, + "requirements_file": "local_requirements.txt", + }, + "pipeline_options": {"output": "test_local/output", "labels": {"foo": "bar"}}, + } + + op = BeamRunPythonPipelineOperator(**local_test_op_args) + context_mock = mock.MagicMock() + + op.execute(context_mock) + mock_gcs_hook.assert_not_called() + + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + def test_execute_gcs_hook_called_with_gs_prefix_py_file(self, mock_gcs_hook, _): + """ + Test that execute method calls GCSHook when only 'py_file' starts with 'gs://'. 
+ """ + local_test_op_args = { + "task_id": TASK_ID, + "py_file": "gs://gcs_file.py", + "py_options": ["-m"], + "default_pipeline_options": { + "project": TEST_PROJECT, + "requirements_file": "local_requirements.txt", + }, + "pipeline_options": {"output": "test_local/output", "labels": {"foo": "bar"}}, + } + op = BeamRunPythonPipelineOperator(**local_test_op_args) + context_mock = mock.MagicMock() + + op.execute(context_mock) + mock_gcs_hook.assert_called_once() + + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + def test_execute_gcs_hook_called_with_gs_prefix_pipeline_requirements(self, mock_gcs_hook, _): + """ + Test that execute method calls GCSHook when only pipeline_options 'requirements_file' starts with + 'gs://'. + Note: "pipeline_options" is merged with and overrides keys in "default_pipeline_options" when + BeamRunPythonPipelineOperator is instantiated, so testing GCS 'requirements_file' specified + in "pipeline_options" + """ + local_test_op_args = { + "task_id": TASK_ID, + "py_file": "local_file.py", + "py_options": ["-m"], + "default_pipeline_options": { + "project": TEST_PROJECT, + "requirements_file": "gs://gcs_requirements.txt", + }, + "pipeline_options": {"output": "test_local/output", "labels": {"foo": "bar"}}, + } + + op = BeamRunPythonPipelineOperator(**local_test_op_args) + context_mock = mock.MagicMock() + + op.execute(context_mock) + mock_gcs_hook.assert_called_once() + class TestBeamRunJavaPipelineOperator: @pytest.fixture(autouse=True)