This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1489cf7a03 Fix deferrable mode for BeamRunJavaPipelineOperator (#39371)
1489cf7a03 is described below

commit 1489cf7a0372898ab5f905fa7b56f3b1327d2cfe
Author: Maksim <[email protected]>
AuthorDate: Tue May 14 07:53:13 2024 -0700

    Fix deferrable mode for BeamRunJavaPipelineOperator (#39371)
---
 airflow/providers/apache/beam/operators/beam.py    | 21 +++------------------
 airflow/providers/apache/beam/triggers/beam.py     | 22 ++++++++++++++++++++--
 tests/providers/apache/beam/operators/test_beam.py | 10 ++--------
 tests/providers/apache/beam/triggers/test_beam.py  | 13 +++++++++++++
 4 files changed, 38 insertions(+), 28 deletions(-)
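
In effect: BeamRunJavaPipelineOperator.execute_async() previously downloaded a gs:// jar on the worker before deferring; this change moves that staging into BeamJavaPipelineTrigger.run(), so it happens inside the triggerer's event loop, and passes gcp_conn_id through to the trigger so it can build a GCSHook. A minimal sketch of the deferrable path from a DAG follows; the task id, jar path, and job class are placeholders, not from this commit:

    from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator

    # Sketch only: with deferrable=True the operator defers to
    # BeamJavaPipelineTrigger, which now stages the gs:// jar itself.
    run_java_pipeline = BeamRunJavaPipelineOperator(
        task_id="run_java_pipeline",
        jar="gs://example-bucket/pipelines/wordcount.jar",
        job_class="org.example.WordCount",
        deferrable=True,
    )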

diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py
index 62f650f19a..af338cdc6d 100644
--- a/airflow/providers/apache/beam/operators/beam.py
+++ b/airflow/providers/apache/beam/operators/beam.py
@@ -546,7 +546,7 @@ class BeamRunJavaPipelineOperator(BeamBasePipelineOperator):
         if not self.beam_hook:
             raise AirflowException("Beam hook is not defined.")
         if self.deferrable:
-            asyncio.run(self.execute_async(context))
+            self.execute_async(context)
         else:
             return self.execute_sync(context)
 
@@ -605,23 +605,7 @@ class BeamRunJavaPipelineOperator(BeamBasePipelineOperator):
                     process_line_callback=self.process_line_callback,
                 )
 
-    async def execute_async(self, context: Context):
-        # Creating a new event loop to manage I/O operations asynchronously
-        loop = asyncio.get_event_loop()
-        if self.jar.lower().startswith("gs://"):
-            gcs_hook = GCSHook(self.gcp_conn_id)
-            # Running synchronous `enter_context()` method in a separate
-            # thread using the default executor `None`. The `run_in_executor()` function returns the
-            # file object, which is created using gcs function `provide_file()`, asynchronously.
-            # This means we can perform asynchronous operations with this file.
-            create_tmp_file_call = gcs_hook.provide_file(object_url=self.jar)
-            tmp_gcs_file: IO[str] = await loop.run_in_executor(
-                None,
-                contextlib.ExitStack().enter_context,  # type: ignore[arg-type]
-                create_tmp_file_call,
-            )
-            self.jar = tmp_gcs_file.name
-
+    def execute_async(self, context: Context):
         if self.is_dataflow and self.dataflow_hook:
             DataflowJobLink.persist(
                 self,
@@ -657,6 +641,7 @@ class BeamRunJavaPipelineOperator(BeamBasePipelineOperator):
                     job_class=self.job_class,
                     runner=self.runner,
                     check_if_running=self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun,
+                    gcp_conn_id=self.gcp_conn_id,
                 ),
                 method_name="execute_complete",
             )
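
When the trigger fires, Airflow resumes the task by calling the method named in method_name above. That handler is unchanged by this commit; as a hedged sketch, a typical deferrable-operator resume callback has roughly this shape (not the operator's verbatim code):

    # Illustrative only; the real execute_complete lives on the operator
    # class and is outside this diff.
    def execute_complete(self, context, event):
        if event["status"] == "error":
            raise AirflowException(event["message"])
        self.log.info("%s completed with response %s", self.task_id, event["message"])
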
diff --git a/airflow/providers/apache/beam/triggers/beam.py b/airflow/providers/apache/beam/triggers/beam.py
index 5b1f7a99d5..b160218f73 100644
--- a/airflow/providers/apache/beam/triggers/beam.py
+++ b/airflow/providers/apache/beam/triggers/beam.py
@@ -17,7 +17,8 @@
 from __future__ import annotations
 
 import asyncio
-from typing import Any, AsyncIterator, Sequence
+import contextlib
+from typing import IO, Any, AsyncIterator, Sequence
 
 from deprecated import deprecated
 from google.cloud.dataflow_v1beta3 import ListJobsRequest
@@ -25,6 +26,7 @@ from google.cloud.dataflow_v1beta3 import ListJobsRequest
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.apache.beam.hooks.beam import BeamAsyncHook
 from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook
+from airflow.providers.google.cloud.hooks.gcs import GCSHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
@@ -166,7 +168,7 @@ class BeamJavaPipelineTrigger(BeamPipelineBaseTrigger):
         project_id: str | None = None,
         location: str | None = None,
         job_name: str | None = None,
-        gcp_conn_id: str | None = None,
+        gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
         poll_sleep: int = 10,
         cancel_timeout: int | None = None,
@@ -233,6 +235,22 @@ class BeamJavaPipelineTrigger(BeamPipelineBaseTrigger):
                 if is_running:
                     await asyncio.sleep(self.poll_sleep)
         try:
+            # Get the current running event loop to manage I/O operations asynchronously
+            loop = asyncio.get_running_loop()
+            if self.jar.lower().startswith("gs://"):
+                gcs_hook = GCSHook(self.gcp_conn_id)
+                # Running synchronous `enter_context()` method in a separate
+                # thread using the default executor `None`. The `run_in_executor()` function returns the
+                # file object, which is created using gcs function `provide_file()`, asynchronously.
+                # This means we can perform asynchronous operations with this file.
+                create_tmp_file_call = gcs_hook.provide_file(object_url=self.jar)
+                tmp_gcs_file: IO[str] = await loop.run_in_executor(
+                    None,
+                    contextlib.ExitStack().enter_context,  # type: ignore[arg-type]
+                    create_tmp_file_call,
+                )
+                self.jar = tmp_gcs_file.name
+
             return_code = await hook.start_java_pipeline_async(
                 variables=self.variables, jar=self.jar, job_class=self.job_class
             )
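
The run_in_executor()/ExitStack pattern above deserves a note: GCSHook.provide_file() returns a synchronous context manager that stages the object in a local temporary file, so the trigger enters it on the default thread-pool executor rather than blocking the event loop. A standalone sketch of the same pattern, with illustrative names only:

    import asyncio
    import contextlib

    # Generic illustration, not code from this commit: enter a blocking
    # context manager in a worker thread and return its __enter__() result
    # (here, a file-like object) without stalling the running loop.
    async def enter_blocking_cm(cm):
        loop = asyncio.get_running_loop()
        stack = contextlib.ExitStack()
        return await loop.run_in_executor(None, stack.enter_context, cm)
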
diff --git a/tests/providers/apache/beam/operators/test_beam.py b/tests/providers/apache/beam/operators/test_beam.py
index a6a4c31c77..15d5c9778a 100644
--- a/tests/providers/apache/beam/operators/test_beam.py
+++ b/tests/providers/apache/beam/operators/test_beam.py
@@ -1013,24 +1013,20 @@ class TestBeamRunJavaPipelineOperatorAsync:
         ), "Trigger is not a BeamPJavaPipelineTrigger"
 
     @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
-    @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
-    def test_async_execute_direct_runner(self, gcs_hook, beam_hook_mock):
+    def test_async_execute_direct_runner(self, beam_hook_mock):
         """
         Test BeamHook is created and the right args are passed to
         start_java_pipeline when executing direct runner.
         """
-        gcs_provide_file = gcs_hook.return_value.provide_file
         op = BeamRunJavaPipelineOperator(**self.default_op_kwargs)
         with pytest.raises(TaskDeferred):
             op.execute(context=mock.MagicMock())
         beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)
-        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
 
     @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
     @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
     @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
-    @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
-    def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock):
+    def test_exec_dataflow_runner(self, dataflow_hook_mock, beam_hook_mock, persist_link_mock):
         """
         Test DataflowHook is created and the right args are passed to
         start_java_pipeline when executing Dataflow runner.
@@ -1039,7 +1035,6 @@ class TestBeamRunJavaPipelineOperatorAsync:
         op = BeamRunJavaPipelineOperator(
             runner="DataflowRunner", dataflow_config=dataflow_config, 
**self.default_op_kwargs
         )
-        gcs_provide_file = gcs_hook.return_value.provide_file
         magic_mock = mock.MagicMock()
         with pytest.raises(TaskDeferred):
             op.execute(context=magic_mock)
@@ -1062,7 +1057,6 @@ class TestBeamRunJavaPipelineOperatorAsync:
             "region": "us-central1",
             "impersonate_service_account": TEST_IMPERSONATION_ACCOUNT,
         }
-        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
         persist_link_mock.assert_called_once_with(
             op,
             magic_mock,
diff --git a/tests/providers/apache/beam/triggers/test_beam.py b/tests/providers/apache/beam/triggers/test_beam.py
index 6bd1b4bc66..972e90161a 100644
--- a/tests/providers/apache/beam/triggers/test_beam.py
+++ b/tests/providers/apache/beam/triggers/test_beam.py
@@ -43,6 +43,7 @@ TEST_PY_REQUIREMENTS = ["apache-beam[gcp]==2.46.0"]
 TEST_PY_PACKAGES = False
 TEST_RUNNER = "DirectRunner"
 TEST_JAR_FILE = "example.jar"
+TEST_GCS_JAR_FILE = "gs://my-bucket/example/test.jar"
 TEST_JOB_CLASS = "TestClass"
 TEST_CHECK_IF_RUNNING = False
 TEST_JOB_NAME = "test_job_name"
@@ -215,3 +216,15 @@ class TestBeamJavaPipelineTrigger:
         generator = java_trigger.run()
         actual = await generator.asend(None)
         assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.apache.beam.triggers.beam.GCSHook")
+    async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, gcs_hook, java_trigger):
+        """
+        Test that BeamJavaPipelineTrigger correctly downloads the jar file provided from GCS.
+        """
+        gcs_provide_file = gcs_hook.return_value.provide_file
+        java_trigger.jar = TEST_GCS_JAR_FILE
+        generator = java_trigger.run()
+        await generator.asend(None)
+        gcs_provide_file.assert_called_once_with(object_url=TEST_GCS_JAR_FILE)
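
The java_trigger fixture driven above is defined earlier in test_beam.py and is not part of this diff; reconstructed from the constants and trigger arguments visible in this commit, a plausible minimal version looks roughly like:

    @pytest.fixture
    def java_trigger():
        # Reconstruction for illustration; the real fixture may differ.
        return BeamJavaPipelineTrigger(
            variables={},  # assumed; the real test variables are not shown here
            jar=TEST_JAR_FILE,
            job_class=TEST_JOB_CLASS,
            runner=TEST_RUNNER,
            check_if_running=TEST_CHECK_IF_RUNNING,
            job_name=TEST_JOB_NAME,
        )

Awaiting generator.asend(None) advances the trigger's async generator to its first yield, which is far enough for provide_file() to have been called.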
