This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1489cf7a03 Fix deferrable mode for BeamRunJavaPipelineOperator (#39371)
1489cf7a03 is described below
commit 1489cf7a0372898ab5f905fa7b56f3b1327d2cfe
Author: Maksim <[email protected]>
AuthorDate: Tue May 14 07:53:13 2024 -0700
Fix deferrable mode for BeamRunJavaPipelineOperator (#39371)
---
airflow/providers/apache/beam/operators/beam.py | 21 +++------------------
airflow/providers/apache/beam/triggers/beam.py | 22 ++++++++++++++++++++--
tests/providers/apache/beam/operators/test_beam.py | 10 ++--------
tests/providers/apache/beam/triggers/test_beam.py | 13 +++++++++++++
4 files changed, 38 insertions(+), 28 deletions(-)
diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py
index 62f650f19a..af338cdc6d 100644
--- a/airflow/providers/apache/beam/operators/beam.py
+++ b/airflow/providers/apache/beam/operators/beam.py
@@ -546,7 +546,7 @@ class BeamRunJavaPipelineOperator(BeamBasePipelineOperator):
if not self.beam_hook:
raise AirflowException("Beam hook is not defined.")
if self.deferrable:
- asyncio.run(self.execute_async(context))
+ self.execute_async(context)
else:
return self.execute_sync(context)
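Background on this hunk: `self.defer()` raises `TaskDeferred` synchronously, so the worker never needs an event loop and the `asyncio.run()` wrapper was unnecessary. A minimal sketch of the deferral pattern, using the stock `TimeDeltaTrigger` as a stand-in for the Beam trigger (illustrative, not the operator's actual code):

from datetime import timedelta

from airflow.models.baseoperator import BaseOperator
from airflow.triggers.temporal import TimeDeltaTrigger


class ExampleDeferringOperator(BaseOperator):
    def execute(self, context):
        # self.defer() raises TaskDeferred immediately; no event loop is
        # required on the worker, which is why asyncio.run() could go.
        self.defer(
            trigger=TimeDeltaTrigger(timedelta(seconds=30)),  # stand-in trigger
            method_name="execute_complete",
        )

    def execute_complete(self, context, event=None):
        # Resumed on a worker once the triggerer reports the event.
        return event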
@@ -605,23 +605,7 @@ class BeamRunJavaPipelineOperator(BeamBasePipelineOperator):
process_line_callback=self.process_line_callback,
)
- async def execute_async(self, context: Context):
- # Creating a new event loop to manage I/O operations asynchronously
- loop = asyncio.get_event_loop()
- if self.jar.lower().startswith("gs://"):
- gcs_hook = GCSHook(self.gcp_conn_id)
- # Running synchronous `enter_context()` method in a separate
- # thread using the default executor `None`. The `run_in_executor()` function returns the
- # file object, which is created using gcs function `provide_file()`, asynchronously.
- # This means we can perform asynchronous operations with this file.
- create_tmp_file_call = gcs_hook.provide_file(object_url=self.jar)
- tmp_gcs_file: IO[str] = await loop.run_in_executor(
- None,
- contextlib.ExitStack().enter_context, # type: ignore[arg-type]
- create_tmp_file_call,
- )
- self.jar = tmp_gcs_file.name
-
+ def execute_async(self, context: Context):
if self.is_dataflow and self.dataflow_hook:
DataflowJobLink.persist(
self,
@@ -657,6 +641,7 @@ class BeamRunJavaPipelineOperator(BeamBasePipelineOperator):
job_class=self.job_class,
runner=self.runner,
check_if_running=self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun,
+ gcp_conn_id=self.gcp_conn_id,
),
method_name="execute_complete",
)
diff --git a/airflow/providers/apache/beam/triggers/beam.py b/airflow/providers/apache/beam/triggers/beam.py
index 5b1f7a99d5..b160218f73 100644
--- a/airflow/providers/apache/beam/triggers/beam.py
+++ b/airflow/providers/apache/beam/triggers/beam.py
@@ -17,7 +17,8 @@
from __future__ import annotations
import asyncio
-from typing import Any, AsyncIterator, Sequence
+import contextlib
+from typing import IO, Any, AsyncIterator, Sequence
from deprecated import deprecated
from google.cloud.dataflow_v1beta3 import ListJobsRequest
@@ -25,6 +26,7 @@ from google.cloud.dataflow_v1beta3 import ListJobsRequest
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.apache.beam.hooks.beam import BeamAsyncHook
from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook
+from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -166,7 +168,7 @@ class BeamJavaPipelineTrigger(BeamPipelineBaseTrigger):
project_id: str | None = None,
location: str | None = None,
job_name: str | None = None,
- gcp_conn_id: str | None = None,
+ gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
poll_sleep: int = 10,
cancel_timeout: int | None = None,
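Changing the default from `None` to `"google_cloud_default"` aligns the trigger with `GCSHook`, whose own default is the same connection id, so `GCSHook(self.gcp_conn_id)` behaves like `GCSHook()` when the user passes nothing. A quick check one can run in an environment with the Google provider installed (assumption: only the attribute is inspected, no connection is resolved):

from airflow.providers.google.cloud.hooks.gcs import GCSHook

# GCSHook's own default gcp_conn_id is "google_cloud_default"; the trigger's
# new default simply mirrors it instead of passing None through.
hook = GCSHook(gcp_conn_id="google_cloud_default")
print(hook.gcp_conn_id)  # -> google_cloud_default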
@@ -233,6 +235,22 @@ class BeamJavaPipelineTrigger(BeamPipelineBaseTrigger):
if is_running:
await asyncio.sleep(self.poll_sleep)
try:
+ # Get the current running event loop to manage I/O operations asynchronously
+ loop = asyncio.get_running_loop()
+ if self.jar.lower().startswith("gs://"):
+ gcs_hook = GCSHook(self.gcp_conn_id)
+ # Running synchronous `enter_context()` method in a separate
+ # thread using the default executor `None`. The `run_in_executor()` function returns the
+ # file object, which is created using gcs function `provide_file()`, asynchronously.
+ # This means we can perform asynchronous operations with this file.
+ create_tmp_file_call = gcs_hook.provide_file(object_url=self.jar)
+ tmp_gcs_file: IO[str] = await loop.run_in_executor(
+ None,
+ contextlib.ExitStack().enter_context, # type: ignore[arg-type]
+ create_tmp_file_call,
+ )
+ self.jar = tmp_gcs_file.name
+
return_code = await hook.start_java_pipeline_async(
variables=self.variables, jar=self.jar, job_class=self.job_class
)
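The block added above is the jar-download logic relocated from the operator. The technique, handing a blocking context manager to the default thread-pool executor so the event loop stays responsive, can be sketched in isolation like this (`open_resource()` is a hypothetical stand-in for `GCSHook.provide_file()`):

import asyncio
import contextlib


def open_resource(url):
    # Hypothetical stand-in for a blocking context manager such as
    # GCSHook.provide_file(object_url=...), which stages a temp file.
    return contextlib.nullcontext(url)


async def fetch(url):
    loop = asyncio.get_running_loop()
    stack = contextlib.ExitStack()
    # enter_context() blocks while the resource is prepared, so run it on
    # the default executor (None) rather than the event loop thread.
    return await loop.run_in_executor(None, stack.enter_context, open_resource(url))


print(asyncio.run(fetch("gs://bucket/file.jar")))

As in the trigger code, the `ExitStack` here is never explicitly closed, so cleanup of the entered resource is deferred; that mirrors the pre-existing operator code rather than introducing new behavior.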
diff --git a/tests/providers/apache/beam/operators/test_beam.py b/tests/providers/apache/beam/operators/test_beam.py
index a6a4c31c77..15d5c9778a 100644
--- a/tests/providers/apache/beam/operators/test_beam.py
+++ b/tests/providers/apache/beam/operators/test_beam.py
@@ -1013,24 +1013,20 @@ class TestBeamRunJavaPipelineOperatorAsync:
), "Trigger is not a BeamPJavaPipelineTrigger"
@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
- @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
- def test_async_execute_direct_runner(self, gcs_hook, beam_hook_mock):
+ def test_async_execute_direct_runner(self, beam_hook_mock):
"""
Test BeamHook is created and the right args are passed to
start_java_pipeline when executing direct runner.
"""
- gcs_provide_file = gcs_hook.return_value.provide_file
op = BeamRunJavaPipelineOperator(**self.default_op_kwargs)
with pytest.raises(TaskDeferred):
op.execute(context=mock.MagicMock())
beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)
- gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
@mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
- @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
- def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock):
+ def test_exec_dataflow_runner(self, dataflow_hook_mock, beam_hook_mock, persist_link_mock):
"""
Test DataflowHook is created and the right args are passed to
start_java_pipeline when executing Dataflow runner.
@@ -1039,7 +1035,6 @@ class TestBeamRunJavaPipelineOperatorAsync:
op = BeamRunJavaPipelineOperator(
runner="DataflowRunner", dataflow_config=dataflow_config,
**self.default_op_kwargs
)
- gcs_provide_file = gcs_hook.return_value.provide_file
magic_mock = mock.MagicMock()
with pytest.raises(TaskDeferred):
op.execute(context=magic_mock)
@@ -1062,7 +1057,6 @@ class TestBeamRunJavaPipelineOperatorAsync:
"region": "us-central1",
"impersonate_service_account": TEST_IMPERSONATION_ACCOUNT,
}
- gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
persist_link_mock.assert_called_once_with(
op,
magic_mock,
diff --git a/tests/providers/apache/beam/triggers/test_beam.py b/tests/providers/apache/beam/triggers/test_beam.py
index 6bd1b4bc66..972e90161a 100644
--- a/tests/providers/apache/beam/triggers/test_beam.py
+++ b/tests/providers/apache/beam/triggers/test_beam.py
@@ -43,6 +43,7 @@ TEST_PY_REQUIREMENTS = ["apache-beam[gcp]==2.46.0"]
TEST_PY_PACKAGES = False
TEST_RUNNER = "DirectRunner"
TEST_JAR_FILE = "example.jar"
+TEST_GCS_JAR_FILE = "gs://my-bucket/example/test.jar"
TEST_JOB_CLASS = "TestClass"
TEST_CHECK_IF_RUNNING = False
TEST_JOB_NAME = "test_job_name"
@@ -215,3 +216,15 @@ class TestBeamJavaPipelineTrigger:
generator = java_trigger.run()
actual = await generator.asend(None)
assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.apache.beam.triggers.beam.GCSHook")
+ async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, gcs_hook, java_trigger):
+ """
+ Test that BeamJavaPipelineTrigger correctly downloads the jar file from GCS.
+ """
+ gcs_provide_file = gcs_hook.return_value.provide_file
+ java_trigger.jar = TEST_GCS_JAR_FILE
+ generator = java_trigger.run()
+ await generator.asend(None)
+ gcs_provide_file.assert_called_once_with(object_url=TEST_GCS_JAR_FILE)
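As a usage note, the same pattern generalizes to other trigger branches: patch the hooks the trigger constructs internally, then drive its async generator one step. A small hedged helper for that (illustrative only, not part of this commit):

import asyncio


async def first_event(trigger):
    # Advance the trigger's run() generator one step and return the first
    # TriggerEvent it yields; any hook mocks must already be patched in.
    gen = trigger.run()
    return await gen.asend(None)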