This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fd8a05739f Bugfix to correct GCSHook being called even when not required with BeamRunPythonPipelineOperator (#38716)
fd8a05739f is described below

commit fd8a05739f945643b5023db15d51a97459109a02
Author: Zack Strathe <59071005+zstra...@users.noreply.github.com>
AuthorDate: Fri Apr 19 03:40:19 2024 -0500

    Bugfix to correct GCSHook being called even when not required with BeamRunPythonPipelineOperator (#38716)
    
    * Bugfix to correct GCSHook being called even when not required with BeamRunPythonPipelineOperator
    
    * remove unnecessary check for GCSHook and add unit test for BeamRunPythonPipelineOperator to ensure that GCSHook is only called when necessary
    
    * Split out unit tests for TestBeamRunPythonPipelineOperator with GCSHook 'gs://' arg prefixes
    
    * Fix formatting
---
 airflow/providers/apache/beam/operators/beam.py    |  3 +-
 tests/providers/apache/beam/operators/test_beam.py | 73 ++++++++++++++++++++++
 2 files changed, 75 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/apache/beam/operators/beam.py 
b/airflow/providers/apache/beam/operators/beam.py
index e88923bc05..62f650f19a 100644
--- a/airflow/providers/apache/beam/operators/beam.py
+++ b/airflow/providers/apache/beam/operators/beam.py
@@ -364,11 +364,12 @@ class 
BeamRunPythonPipelineOperator(BeamBasePipelineOperator):
 
     def execute_sync(self, context: Context):
         with ExitStack() as exit_stack:
-            gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id)
             if self.py_file.lower().startswith("gs://"):
+                gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id)
                 tmp_gcs_file = 
exit_stack.enter_context(gcs_hook.provide_file(object_url=self.py_file))
                 self.py_file = tmp_gcs_file.name
             if self.snake_case_pipeline_options.get("requirements_file", 
"").startswith("gs://"):
+                gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id)
                 tmp_req_file = exit_stack.enter_context(
                     
gcs_hook.provide_file(object_url=self.snake_case_pipeline_options["requirements_file"])
                 )
diff --git a/tests/providers/apache/beam/operators/test_beam.py 
b/tests/providers/apache/beam/operators/test_beam.py
index f7ca9649fb..a6a4c31c77 100644
--- a/tests/providers/apache/beam/operators/test_beam.py
+++ b/tests/providers/apache/beam/operators/test_beam.py
@@ -256,6 +256,79 @@ class TestBeamRunPythonPipelineOperator:
         op.on_kill()
         dataflow_cancel_job.assert_not_called()
 
+    @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+    @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
+    def test_execute_gcs_hook_not_called_without_gs_prefix(self, 
mock_gcs_hook, _):
+        """
+        Test that execute method does not call GCSHook when neither py_file 
nor requirements_file
+        starts with 'gs://'. (i.e., running pipeline entirely locally)
+        """
+        local_test_op_args = {
+            "task_id": TASK_ID,
+            "py_file": "local_file.py",
+            "py_options": ["-m"],
+            "default_pipeline_options": {
+                "project": TEST_PROJECT,
+                "requirements_file": "local_requirements.txt",
+            },
+            "pipeline_options": {"output": "test_local/output", "labels": 
{"foo": "bar"}},
+        }
+
+        op = BeamRunPythonPipelineOperator(**local_test_op_args)
+        context_mock = mock.MagicMock()
+
+        op.execute(context_mock)
+        mock_gcs_hook.assert_not_called()
+
+    @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+    @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
+    def test_execute_gcs_hook_called_with_gs_prefix_py_file(self, 
mock_gcs_hook, _):
+        """
+        Test that execute method calls GCSHook when only 'py_file' starts with 
'gs://'.
+        """
+        local_test_op_args = {
+            "task_id": TASK_ID,
+            "py_file": "gs://gcs_file.py",
+            "py_options": ["-m"],
+            "default_pipeline_options": {
+                "project": TEST_PROJECT,
+                "requirements_file": "local_requirements.txt",
+            },
+            "pipeline_options": {"output": "test_local/output", "labels": 
{"foo": "bar"}},
+        }
+        op = BeamRunPythonPipelineOperator(**local_test_op_args)
+        context_mock = mock.MagicMock()
+
+        op.execute(context_mock)
+        mock_gcs_hook.assert_called_once()
+
+    @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+    @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
+    def 
test_execute_gcs_hook_called_with_gs_prefix_pipeline_requirements(self, 
mock_gcs_hook, _):
+        """
+        Test that execute method calls GCSHook when only pipeline_options 
'requirements_file' starts with
+        'gs://'.
+        Note: "pipeline_options" is merged with and overrides keys in 
"default_pipeline_options" when
+              BeamRunPythonPipelineOperator is instantiated, so testing GCS 
'requirements_file' specified
+              in "pipeline_options"
+        """
+        local_test_op_args = {
+            "task_id": TASK_ID,
+            "py_file": "local_file.py",
+            "py_options": ["-m"],
+            "default_pipeline_options": {
+                "project": TEST_PROJECT,
+                "requirements_file": "gs://gcs_requirements.txt",
+            },
+            "pipeline_options": {"output": "test_local/output", "labels": 
{"foo": "bar"}},
+        }
+
+        op = BeamRunPythonPipelineOperator(**local_test_op_args)
+        context_mock = mock.MagicMock()
+
+        op.execute(context_mock)
+        mock_gcs_hook.assert_called_once()
+
 
 class TestBeamRunJavaPipelineOperator:
     @pytest.fixture(autouse=True)

Reply via email to