This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new fd8a05739f Bugfix to correct GCSHook being called even when not
required with BeamRunPythonPipelineOperator (#38716)
fd8a05739f is described below
commit fd8a05739f945643b5023db15d51a97459109a02
Author: Zack Strathe <[email protected]>
AuthorDate: Fri Apr 19 03:40:19 2024 -0500
Bugfix to correct GCSHook being called even when not required with
BeamRunPythonPipelineOperator (#38716)
* Bugfix to correct GCSHook being called even when not required with
BeamRunPythonPipelineOperator
* remove unnecessary check for GCSHook and add unit test for
BeamRunPythonPipelineOperator to ensure that GCSHook is only called when
necessary
* Split out unit tests for TestBeamRunPythonPipelineOperator with GCSHook
'gs://' arg prefixes
* Fix formatting
---
airflow/providers/apache/beam/operators/beam.py | 3 +-
tests/providers/apache/beam/operators/test_beam.py | 73 ++++++++++++++++++++++
2 files changed, 75 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/apache/beam/operators/beam.py
b/airflow/providers/apache/beam/operators/beam.py
index e88923bc05..62f650f19a 100644
--- a/airflow/providers/apache/beam/operators/beam.py
+++ b/airflow/providers/apache/beam/operators/beam.py
@@ -364,11 +364,12 @@ class
BeamRunPythonPipelineOperator(BeamBasePipelineOperator):
def execute_sync(self, context: Context):
with ExitStack() as exit_stack:
- gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id)
if self.py_file.lower().startswith("gs://"):
+ gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id)
tmp_gcs_file =
exit_stack.enter_context(gcs_hook.provide_file(object_url=self.py_file))
self.py_file = tmp_gcs_file.name
if self.snake_case_pipeline_options.get("requirements_file",
"").startswith("gs://"):
+ gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id)
tmp_req_file = exit_stack.enter_context(
gcs_hook.provide_file(object_url=self.snake_case_pipeline_options["requirements_file"])
)
diff --git a/tests/providers/apache/beam/operators/test_beam.py
b/tests/providers/apache/beam/operators/test_beam.py
index f7ca9649fb..a6a4c31c77 100644
--- a/tests/providers/apache/beam/operators/test_beam.py
+++ b/tests/providers/apache/beam/operators/test_beam.py
@@ -256,6 +256,79 @@ class TestBeamRunPythonPipelineOperator:
op.on_kill()
dataflow_cancel_job.assert_not_called()
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
+ def test_execute_gcs_hook_not_called_without_gs_prefix(self,
mock_gcs_hook, _):
+ """
+ Test that execute method does not call GCSHook when neither py_file
nor requirements_file
+ starts with 'gs://'. (i.e., running pipeline entirely locally)
+ """
+ local_test_op_args = {
+ "task_id": TASK_ID,
+ "py_file": "local_file.py",
+ "py_options": ["-m"],
+ "default_pipeline_options": {
+ "project": TEST_PROJECT,
+ "requirements_file": "local_requirements.txt",
+ },
+ "pipeline_options": {"output": "test_local/output", "labels":
{"foo": "bar"}},
+ }
+
+ op = BeamRunPythonPipelineOperator(**local_test_op_args)
+ context_mock = mock.MagicMock()
+
+ op.execute(context_mock)
+ mock_gcs_hook.assert_not_called()
+
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
+ def test_execute_gcs_hook_called_with_gs_prefix_py_file(self,
mock_gcs_hook, _):
+ """
+ Test that execute method calls GCSHook when only 'py_file' starts with
'gs://'.
+ """
+ local_test_op_args = {
+ "task_id": TASK_ID,
+ "py_file": "gs://gcs_file.py",
+ "py_options": ["-m"],
+ "default_pipeline_options": {
+ "project": TEST_PROJECT,
+ "requirements_file": "local_requirements.txt",
+ },
+ "pipeline_options": {"output": "test_local/output", "labels":
{"foo": "bar"}},
+ }
+ op = BeamRunPythonPipelineOperator(**local_test_op_args)
+ context_mock = mock.MagicMock()
+
+ op.execute(context_mock)
+ mock_gcs_hook.assert_called_once()
+
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
+ def
test_execute_gcs_hook_called_with_gs_prefix_pipeline_requirements(self,
mock_gcs_hook, _):
+ """
+ Test that execute method calls GCSHook when only pipeline_options
'requirements_file' starts with
+ 'gs://'.
+ Note: "pipeline_options" is merged with and overrides keys in
"default_pipeline_options" when
+ BeamRunPythonPipelineOperator is instantiated, so testing GCS
'requirements_file' specified
+ in "pipeline_options"
+ """
+ local_test_op_args = {
+ "task_id": TASK_ID,
+ "py_file": "local_file.py",
+ "py_options": ["-m"],
+ "default_pipeline_options": {
+ "project": TEST_PROJECT,
+ "requirements_file": "gs://gcs_requirements.txt",
+ },
+ "pipeline_options": {"output": "test_local/output", "labels":
{"foo": "bar"}},
+ }
+
+ op = BeamRunPythonPipelineOperator(**local_test_op_args)
+ context_mock = mock.MagicMock()
+
+ op.execute(context_mock)
+ mock_gcs_hook.assert_called_once()
+
class TestBeamRunJavaPipelineOperator:
@pytest.fixture(autouse=True)