This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8c4303e1ac Add support for running a Beam Go pipeline with an executable binary (#28764)
8c4303e1ac is described below
commit 8c4303e1ace0774244b556a8d86a19058af2b16d
Author: Johanna Öjeling <[email protected]>
AuthorDate: Wed Jan 18 18:52:05 2023 +0100
Add support for running a Beam Go pipeline with an executable binary (#28764)
The `BeamRunGoPipelineOperator` currently has a `go_file` parameter, which
represents the path to a Go source file with the pipeline code. The operator
starts the pipeline with `go run`, i.e. it compiles the code into a temporary
binary and executes it.
This PR adds support for starting the pipeline from an already compiled
binary, as an alternative to the source file approach (see the usage sketch
below). It introduces two new parameters:
1. `launcher_binary`: path to a binary compiled for the launching platform,
i.e. the platform where Airflow is deployed
2. `worker_binary` (optional): path to a binary compiled for the worker
platform, if using a remote runner
Some motivations for introducing this feature:
- It does not require a Go installation on the system where Airflow is run
(which is more similar to how the `BeamRunJavaPipelineOperator` works, running
a jar)
- It does not involve the extra steps of initializing a Go module,
installing dependencies and compiling the code on every task run, which is
what currently happens when the Go source file is downloaded from GCS
- In the current implementation only a single Go source file can be downloaded
from GCS. This can be limiting if the project comprises multiple files
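
For illustration (not part of this commit), a minimal DAG sketch of the new
parameters; the bucket, binary paths and pipeline options are hypothetical
placeholders:

    import datetime

    from airflow import DAG
    from airflow.providers.apache.beam.operators.beam import BeamRunGoPipelineOperator

    with DAG(
        dag_id="beam_go_binary_example",
        start_date=datetime.datetime(2023, 1, 1),
        schedule_interval=None,
    ) as dag:
        # The launcher binary is built for the platform where Airflow runs;
        # the worker binary is cross-compiled for the remote worker platform.
        BeamRunGoPipelineOperator(
            task_id="run_go_pipeline",
            runner="DataflowRunner",
            launcher_binary="gs://my-bucket/launcher-main",
            worker_binary="gs://my-bucket/worker-main",
            pipeline_options={"output": "gs://my-bucket/output"},
        )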
---
airflow/providers/apache/beam/hooks/beam.py | 35 ++-
airflow/providers/apache/beam/operators/beam.py | 179 +++++++++++--
docs/spelling_wordlist.txt | 1 +
tests/providers/apache/beam/hooks/test_beam.py | 29 ++
tests/providers/apache/beam/operators/test_beam.py | 295 +++++++++++++++++++--
5 files changed, 497 insertions(+), 42 deletions(-)
diff --git a/airflow/providers/apache/beam/hooks/beam.py b/airflow/providers/apache/beam/hooks/beam.py
index 28a5abc0c6..c318d17363 100644
--- a/airflow/providers/apache/beam/hooks/beam.py
+++ b/airflow/providers/apache/beam/hooks/beam.py
@@ -19,6 +19,7 @@
from __future__ import annotations
import contextlib
+import copy
import json
import os
import select
@@ -310,11 +311,10 @@ class BeamHook(BaseHook):
should_init_module: bool = False,
) -> None:
"""
- Starts Apache Beam Go pipeline.
+ Starts Apache Beam Go pipeline with a source file.
:param variables: Variables passed to the job.
:param go_file: Path to the Go file with your beam pipeline.
- :param go_file:
:param process_line_callback: (optional) Callback that can be used to process each line of
the stdout and stderr file descriptors.
:param should_init_module: If False (default), will just execute a `go run` command. If True, will
@@ -346,3 +346,34 @@ class BeamHook(BaseHook):
process_line_callback=process_line_callback,
working_directory=working_directory,
)
+
+ def start_go_pipeline_with_binary(
+ self,
+ variables: dict,
+ launcher_binary: str,
+ worker_binary: str,
+ process_line_callback: Callable[[str], None] | None = None,
+ ) -> None:
+ """
+ Starts Apache Beam Go pipeline with an executable binary.
+
+ :param variables: Variables passed to the job.
+ :param launcher_binary: Path to the binary compiled for the launching platform.
+ :param worker_binary: Path to the binary compiled for the worker platform.
+ :param process_line_callback: (optional) Callback that can be used to process each line of
+ the stdout and stderr file descriptors.
+ """
+ job_variables = copy.deepcopy(variables)
+
+ if "labels" in job_variables:
+ job_variables["labels"] = json.dumps(job_variables["labels"],
separators=(",", ":"))
+
+ job_variables["worker_binary"] = worker_binary
+
+ command_prefix = [launcher_binary]
+
+ self._start_pipeline(
+ variables=job_variables,
+ command_prefix=command_prefix,
+ process_line_callback=process_line_callback,
+ )
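
For context, a hedged sketch of calling the new hook method directly (binary
paths are hypothetical). Per the test expectations below, the launcher binary
becomes the command prefix and the variables are rendered as flags, e.g.
`/path/to/launcher-main --runner=DirectRunner --output=gs://test/output
--worker_binary=/path/to/worker-main`:

    from airflow.providers.apache.beam.hooks.beam import BeamHook

    hook = BeamHook(runner="DirectRunner")
    hook.start_go_pipeline_with_binary(
        variables={"output": "gs://test/output"},
        launcher_binary="/path/to/launcher-main",  # hypothetical path
        worker_binary="/path/to/worker-main",      # hypothetical path
    )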
diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py
index e9395a8dde..b041dbdc24 100644
--- a/airflow/providers/apache/beam/operators/beam.py
+++ b/airflow/providers/apache/beam/operators/beam.py
@@ -19,9 +19,13 @@
from __future__ import annotations
import copy
+import os
+import stat
import tempfile
-from abc import ABC, ABCMeta
+from abc import ABC, ABCMeta, abstractmethod
+from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import ExitStack
+from functools import partial
from typing import TYPE_CHECKING, Callable, Sequence
from airflow import AirflowException
@@ -31,10 +35,10 @@ from airflow.providers.google.cloud.hooks.dataflow import (
DataflowHook,
process_line_and_extract_dataflow_job_id_callback,
)
-from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
from airflow.providers.google.cloud.links.dataflow import DataflowJobLink
from airflow.providers.google.cloud.operators.dataflow import CheckJobRunning, DataflowConfiguration
-from airflow.utils.helpers import convert_camel_to_snake
+from airflow.utils.helpers import convert_camel_to_snake, exactly_one
from airflow.version import version
if TYPE_CHECKING:
@@ -520,12 +524,27 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator):
For more detail on Apache Beam have a look at the reference:
https://beam.apache.org/documentation/
- :param go_file: Reference to the Go Apache Beam pipeline e.g.,
- /some/local/file/path/to/your/go/pipeline/file.go
+ :param go_file: Reference to the Apache Beam pipeline Go source file,
+ e.g. /local/path/to/main.go or gs://bucket/path/to/main.go.
+ Exactly one of go_file and launcher_binary must be provided.
+
+ :param launcher_binary: Reference to the Apache Beam pipeline Go binary compiled for the launching
+ platform, e.g. /local/path/to/launcher-main or gs://bucket/path/to/launcher-main.
+ Exactly one of go_file and launcher_binary must be provided.
+
+ :param worker_binary: Reference to the Apache Beam pipeline Go binary compiled for the worker platform,
+ e.g. /local/path/to/worker-main or gs://bucket/path/to/worker-main.
+ Needed if the OS or architecture of the workers running the pipeline is different from that
+ of the platform launching the pipeline. For more information, see the Apache Beam documentation
+ for Go cross compilation: https://beam.apache.org/documentation/sdks/go-cross-compilation/.
+ If launcher_binary is not set, providing a worker_binary will have no effect. If launcher_binary is
+ set and worker_binary is not, worker_binary will default to the value of launcher_binary.
"""
template_fields = [
"go_file",
+ "launcher_binary",
+ "worker_binary",
"runner",
"pipeline_options",
"default_pipeline_options",
@@ -537,7 +556,9 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator):
def __init__(
self,
*,
- go_file: str,
+ go_file: str = "",
+ launcher_binary: str = "",
+ worker_binary: str = "",
runner: str = "DirectRunner",
default_pipeline_options: dict | None = None,
pipeline_options: dict | None = None,
@@ -563,8 +584,13 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator):
)
self.dataflow_support_impersonation = False
+ if not exactly_one(go_file, launcher_binary):
+ raise ValueError("Exactly one of `go_file` and `launcher_binary`
must be set")
+
self.go_file = go_file
- self.should_init_go_module = False
+ self.launcher_binary = launcher_binary
+ self.worker_binary = worker_binary or launcher_binary
+
self.pipeline_options.setdefault("labels", {}).update(
{"airflow-version": "v" + version.replace(".", "-").replace("+",
"-")}
)
@@ -581,24 +607,24 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator):
if not self.beam_hook:
raise AirflowException("Beam hook is not defined.")
+ go_artifact: _GoArtifact = (
+ _GoFile(file=self.go_file)
+ if self.go_file
+ else _GoBinary(launcher=self.launcher_binary, worker=self.worker_binary)
+ )
+
with ExitStack() as exit_stack:
- if self.go_file.lower().startswith("gs://"):
+ if go_artifact.is_located_on_gcs():
gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
-
tmp_dir = exit_stack.enter_context(tempfile.TemporaryDirectory(prefix="apache-beam-go"))
- tmp_gcs_file = exit_stack.enter_context(
- gcs_hook.provide_file(object_url=self.go_file, dir=tmp_dir)
- )
- self.go_file = tmp_gcs_file.name
- self.should_init_go_module = True
+ go_artifact.download_from_gcs(gcs_hook=gcs_hook, tmp_dir=tmp_dir)
if is_dataflow and self.dataflow_hook:
with self.dataflow_hook.provide_authorized_gcloud():
- self.beam_hook.start_go_pipeline(
+ go_artifact.start_pipeline(
+ beam_hook=self.beam_hook,
variables=snake_case_pipeline_options,
- go_file=self.go_file,
process_line_callback=process_line_callback,
- should_init_module=self.should_init_go_module,
)
DataflowJobLink.persist(
@@ -618,11 +644,10 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator):
)
return {"dataflow_job_id": self.dataflow_job_id}
else:
- self.beam_hook.start_go_pipeline(
+ go_artifact.start_pipeline(
+ beam_hook=self.beam_hook,
variables=snake_case_pipeline_options,
- go_file=self.go_file,
process_line_callback=process_line_callback,
- should_init_module=self.should_init_go_module,
)
def on_kill(self) -> None:
@@ -632,3 +657,117 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator):
job_id=self.dataflow_job_id,
project_id=self.dataflow_config.project_id,
)
+
+
+class _GoArtifact(ABC):
+ @abstractmethod
+ def is_located_on_gcs(self) -> bool:
+ ...
+
+ @abstractmethod
+ def download_from_gcs(self, gcs_hook: GCSHook, tmp_dir: str) -> None:
+ ...
+
+ @abstractmethod
+ def start_pipeline(
+ self,
+ beam_hook: BeamHook,
+ variables: dict,
+ process_line_callback: Callable[[str], None] | None = None,
+ ) -> None:
+ ...
+
+
+class _GoFile(_GoArtifact):
+ def __init__(self, file: str) -> None:
+ self.file = file
+ self.should_init_go_module = False
+
+ def is_located_on_gcs(self) -> bool:
+ return _object_is_located_on_gcs(self.file)
+
+ def download_from_gcs(self, gcs_hook: GCSHook, tmp_dir: str) -> None:
+ self.file = _download_object_from_gcs(gcs_hook=gcs_hook, uri=self.file, tmp_dir=tmp_dir)
+ self.should_init_go_module = True
+
+ def start_pipeline(
+ self,
+ beam_hook: BeamHook,
+ variables: dict,
+ process_line_callback: Callable[[str], None] | None = None,
+ ) -> None:
+ beam_hook.start_go_pipeline(
+ variables=variables,
+ go_file=self.file,
+ process_line_callback=process_line_callback,
+ should_init_module=self.should_init_go_module,
+ )
+
+
+class _GoBinary(_GoArtifact):
+ def __init__(self, launcher: str, worker: str) -> None:
+ self.launcher = launcher
+ self.worker = worker
+
+ def is_located_on_gcs(self) -> bool:
+ return any(_object_is_located_on_gcs(path) for path in (self.launcher, self.worker))
+
+ def download_from_gcs(self, gcs_hook: GCSHook, tmp_dir: str) -> None:
+ binaries_are_equal = self.launcher == self.worker
+
+ binaries_to_download = []
+
+ if _object_is_located_on_gcs(self.launcher):
+ binaries_to_download.append("launcher")
+
+ if not binaries_are_equal and _object_is_located_on_gcs(self.worker):
+ binaries_to_download.append("worker")
+
+ download_fn = partial(_download_object_from_gcs, gcs_hook=gcs_hook, tmp_dir=tmp_dir)
+
+ with ThreadPoolExecutor(max_workers=len(binaries_to_download)) as executor:
+ futures = {
+ executor.submit(download_fn, uri=getattr(self, binary), tmp_prefix=f"{binary}-"): binary
+ for binary in binaries_to_download
+ }
+
+ for future in as_completed(futures):
+ binary = futures[future]
+ tmp_path = future.result()
+ _make_executable(tmp_path)
+ setattr(self, binary, tmp_path)
+
+ if binaries_are_equal:
+ self.worker = self.launcher
+
+ def start_pipeline(
+ self,
+ beam_hook: BeamHook,
+ variables: dict,
+ process_line_callback: Callable[[str], None] | None = None,
+ ) -> None:
+ beam_hook.start_go_pipeline_with_binary(
+ variables=variables,
+ launcher_binary=self.launcher,
+ worker_binary=self.worker,
+ process_line_callback=process_line_callback,
+ )
+
+
+def _object_is_located_on_gcs(path: str) -> bool:
+ return path.lower().startswith("gs://")
+
+
+def _download_object_from_gcs(gcs_hook: GCSHook, uri: str, tmp_dir: str, tmp_prefix: str = "") -> str:
+ tmp_name = f"{tmp_prefix}{os.path.basename(uri)}"
+ tmp_path = os.path.join(tmp_dir, tmp_name)
+
+ bucket, prefix = _parse_gcs_url(uri)
+ gcs_hook.download(bucket_name=bucket, object_name=prefix, filename=tmp_path)
+
+ return tmp_path
+
+
+def _make_executable(path: str) -> None:
+ st = os.stat(path)
+ os.chmod(path, st.st_mode | stat.S_IEXEC)
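
To illustrate the helpers above, a standalone sketch (no GCS access; the URI
is a hypothetical placeholder) of the temporary path naming and the
executable-bit handling:

    import os
    import stat
    import tempfile

    uri = "gs://bucket/path/to/main"  # hypothetical object URI

    with tempfile.TemporaryDirectory(prefix="apache-beam-go") as tmp_dir:
        # _download_object_from_gcs(..., tmp_prefix="launcher-") writes to:
        tmp_path = os.path.join(tmp_dir, "launcher-" + os.path.basename(uri))
        open(tmp_path, "wb").close()  # stand-in for gcs_hook.download(...)

        # Same as _make_executable: add the owner execute bit.
        st = os.stat(tmp_path)
        os.chmod(tmp_path, st.st_mode | stat.S_IEXEC)
        assert os.access(tmp_path, os.X_OK)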
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 68b19e1192..28a34b76dc 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1269,6 +1269,7 @@ schedulable
schedulername
schemas
sdk
+sdks
searchpath
SearchResultGenerator
SecretManagerClient
diff --git a/tests/providers/apache/beam/hooks/test_beam.py b/tests/providers/apache/beam/hooks/test_beam.py
index f509c18b0a..80cf26687d 100644
--- a/tests/providers/apache/beam/hooks/test_beam.py
+++ b/tests/providers/apache/beam/hooks/test_beam.py
@@ -312,6 +312,35 @@ class TestBeamHook:
variables=copy.deepcopy(BEAM_VARIABLES_GO),
)
+ @mock.patch(BEAM_STRING.format("BeamCommandRunner"))
+ def test_start_go_pipeline_with_binary(self, mock_runner):
+ hook = BeamHook(runner=DEFAULT_RUNNER)
+ wait_for_done_method = mock_runner.return_value.wait_for_done
+ process_line_callback = MagicMock()
+
+ launcher_binary = "/path/to/launcher-main"
+ worker_binary = "/path/to/worker-main"
+
+ hook.start_go_pipeline_with_binary(
+ variables=BEAM_VARIABLES_GO,
+ launcher_binary=launcher_binary,
+ worker_binary=worker_binary,
+ process_line_callback=process_line_callback,
+ )
+
+ expected_cmd = [
+ launcher_binary,
+ f"--runner={DEFAULT_RUNNER}",
+ "--output=gs://test/output",
+ '--labels={"foo":"bar"}',
+ f"--worker_binary={worker_binary}",
+ ]
+
+ mock_runner.assert_called_once_with(
+ cmd=expected_cmd, process_line_callback=process_line_callback, working_directory=None
+ )
+ wait_for_done_method.assert_called_once_with()
+
class TestBeamRunner:
@mock.patch("airflow.providers.apache.beam.hooks.beam.BeamCommandRunner.log")
diff --git a/tests/providers/apache/beam/operators/test_beam.py b/tests/providers/apache/beam/operators/test_beam.py
index d5e0bfd58d..4cefb756b4 100644
--- a/tests/providers/apache/beam/operators/test_beam.py
+++ b/tests/providers/apache/beam/operators/test_beam.py
@@ -16,8 +16,11 @@
# under the License.
from __future__ import annotations
+import os
from unittest import mock
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, call
+
+import pytest
from airflow.providers.apache.beam.operators.beam import (
BeamRunGoPipelineOperator,
@@ -37,7 +40,9 @@ PY_FILE = "gs://my-bucket/my-object.py"
PY_INTERPRETER = "python3"
PY_OPTIONS = ["-m"]
GO_FILE = "gs://my-bucket/example/main.go"
-DEFAULT_OPTIONS_PYTHON = DEFAULT_OPTIONS_JAVA = {
+LAUNCHER_BINARY = "gs://my-bucket/example/launcher"
+WORKER_BINARY = "gs://my-bucket/example/worker"
+DEFAULT_OPTIONS = {
"project": "test",
"stagingLocation": "gs://test/staging",
}
@@ -56,7 +61,7 @@ class TestBeamRunPythonPipelineOperator:
task_id=TASK_ID,
py_file=PY_FILE,
py_options=PY_OPTIONS,
- default_pipeline_options=DEFAULT_OPTIONS_PYTHON,
+ default_pipeline_options=DEFAULT_OPTIONS,
pipeline_options=ADDITIONAL_OPTIONS,
)
@@ -67,7 +72,7 @@ class TestBeamRunPythonPipelineOperator:
assert self.operator.runner == DEFAULT_RUNNER
assert self.operator.py_options == PY_OPTIONS
assert self.operator.py_interpreter == PY_INTERPRETER
- assert self.operator.default_pipeline_options == DEFAULT_OPTIONS_PYTHON
+ assert self.operator.default_pipeline_options == DEFAULT_OPTIONS
assert self.operator.pipeline_options == EXPECTED_ADDITIONAL_OPTIONS
@mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
@@ -185,7 +190,7 @@ class TestBeamRunJavaPipelineOperator:
task_id=TASK_ID,
jar=JAR_FILE,
job_class=JOB_CLASS,
- default_pipeline_options=DEFAULT_OPTIONS_JAVA,
+ default_pipeline_options=DEFAULT_OPTIONS,
pipeline_options=ADDITIONAL_OPTIONS,
)
@@ -193,7 +198,7 @@ class TestBeamRunJavaPipelineOperator:
"""Test BeamRunJavaPipelineOperator instance is properly
initialized."""
assert self.operator.task_id == TASK_ID
assert self.operator.runner == DEFAULT_RUNNER
- assert self.operator.default_pipeline_options == DEFAULT_OPTIONS_JAVA
+ assert self.operator.default_pipeline_options == DEFAULT_OPTIONS
assert self.operator.job_class == JOB_CLASS
assert self.operator.jar == JAR_FILE
assert self.operator.pipeline_options == ADDITIONAL_OPTIONS
@@ -211,7 +216,7 @@ class TestBeamRunJavaPipelineOperator:
beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)
gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
start_java_hook.assert_called_once_with(
- variables={**DEFAULT_OPTIONS_JAVA, **ADDITIONAL_OPTIONS},
+ variables={**DEFAULT_OPTIONS, **ADDITIONAL_OPTIONS},
jar=gcs_provide_file.return_value.__enter__.return_value.name,
job_class=JOB_CLASS,
process_line_callback=None,
@@ -303,30 +308,96 @@ class TestBeamRunGoPipelineOperator:
self.operator = BeamRunGoPipelineOperator(
task_id=TASK_ID,
go_file=GO_FILE,
- default_pipeline_options=DEFAULT_OPTIONS_PYTHON,
+ default_pipeline_options=DEFAULT_OPTIONS,
pipeline_options=ADDITIONAL_OPTIONS,
)
- def test_init(self):
- """Test BeamRunGoPipelineOperator instance is properly initialized."""
+ def test_init_with_go_file(self):
+ """Test BeamRunGoPipelineOperator instance is properly initialized
with go_file."""
assert self.operator.task_id == TASK_ID
assert self.operator.go_file == GO_FILE
+ assert self.operator.launcher_binary == ""
+ assert self.operator.worker_binary == ""
assert self.operator.runner == DEFAULT_RUNNER
- assert self.operator.default_pipeline_options == DEFAULT_OPTIONS_PYTHON
+ assert self.operator.default_pipeline_options == DEFAULT_OPTIONS
assert self.operator.pipeline_options == EXPECTED_ADDITIONAL_OPTIONS
+ def test_init_with_launcher_binary(self):
+ """Test BeamRunGoPipelineOperator instance is properly initialized
with launcher_binary."""
+ operator = BeamRunGoPipelineOperator(
+ task_id=TASK_ID,
+ launcher_binary=LAUNCHER_BINARY,
+ default_pipeline_options=DEFAULT_OPTIONS,
+ pipeline_options=ADDITIONAL_OPTIONS,
+ )
+
+ assert operator.task_id == TASK_ID
+ assert operator.go_file == ""
+ assert operator.launcher_binary == LAUNCHER_BINARY
+ assert operator.worker_binary == LAUNCHER_BINARY
+ assert operator.runner == DEFAULT_RUNNER
+ assert operator.default_pipeline_options == DEFAULT_OPTIONS
+ assert operator.pipeline_options == EXPECTED_ADDITIONAL_OPTIONS
+
+ def test_init_with_launcher_binary_and_worker_binary(self):
+ """
+ Test BeamRunGoPipelineOperator instance is properly initialized with launcher_binary and
+ worker_binary.
+ """
+ operator = BeamRunGoPipelineOperator(
+ task_id=TASK_ID,
+ launcher_binary=LAUNCHER_BINARY,
+ worker_binary=WORKER_BINARY,
+ default_pipeline_options=DEFAULT_OPTIONS,
+ pipeline_options=ADDITIONAL_OPTIONS,
+ )
+
+ assert operator.task_id == TASK_ID
+ assert operator.go_file == ""
+ assert operator.launcher_binary == LAUNCHER_BINARY
+ assert operator.worker_binary == WORKER_BINARY
+ assert operator.runner == DEFAULT_RUNNER
+ assert operator.default_pipeline_options == DEFAULT_OPTIONS
+ assert operator.pipeline_options == EXPECTED_ADDITIONAL_OPTIONS
+
+ def test_init_with_neither_go_file_nor_launcher_binary_raises(self):
+ """
+ Test BeamRunGoPipelineOperator initialization raises ValueError when neither
+ go_file nor launcher_binary is provided.
+ """
+ with pytest.raises(ValueError, match="Exactly one of `go_file` and `launcher_binary` must be set"):
+ BeamRunGoPipelineOperator(
+ task_id=TASK_ID,
+ default_pipeline_options=DEFAULT_OPTIONS,
+ pipeline_options=ADDITIONAL_OPTIONS,
+ )
+
+ def test_init_with_both_go_file_and_launcher_binary_raises(self):
+ """
+ Test BeamRunGoPipelineOperator initialization raises ValueError when both of
+ go_file and launcher_binary are provided.
+ """
+ with pytest.raises(ValueError, match="Exactly one of `go_file` and `launcher_binary` must be set"):
+ BeamRunGoPipelineOperator(
+ task_id=TASK_ID,
+ go_file=GO_FILE,
+ launcher_binary=LAUNCHER_BINARY,
+ default_pipeline_options=DEFAULT_OPTIONS,
+ pipeline_options=ADDITIONAL_OPTIONS,
+ )
+
@mock.patch(
"tempfile.TemporaryDirectory",
return_value=MagicMock(__enter__=MagicMock(return_value="/tmp/apache-beam-go")),
)
@mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
@mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
- def test_exec_direct_runner(self, gcs_hook, beam_hook_mock, _):
+ def test_exec_direct_runner_with_gcs_go_file(self, gcs_hook, beam_hook_mock, _):
"""Test BeamHook is created and the right args are passed to
start_go_workflow.
"""
start_go_pipeline_method = beam_hook_mock.return_value.start_go_pipeline
- gcs_provide_file_method = gcs_hook.return_value.provide_file
+ gcs_download_method = gcs_hook.return_value.download
self.operator.execute(None)
beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)
expected_options = {
@@ -335,17 +406,75 @@ class TestBeamRunGoPipelineOperator:
"output": "gs://test/output",
"labels": {"foo": "bar", "airflow-version": TEST_VERSION},
}
- gcs_provide_file_method.assert_called_once_with(object_url=GO_FILE, dir="/tmp/apache-beam-go")
+ expected_go_file = "/tmp/apache-beam-go/main.go"
+ gcs_download_method.assert_called_once_with(
+ bucket_name="my-bucket", object_name="example/main.go",
filename=expected_go_file
+ )
start_go_pipeline_method.assert_called_once_with(
variables=expected_options,
- go_file=gcs_provide_file_method.return_value.__enter__.return_value.name,
+ go_file=expected_go_file,
process_line_callback=None,
should_init_module=True,
)
+ @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
+ @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
+ @mock.patch("tempfile.TemporaryDirectory")
+ def test_exec_direct_runner_with_gcs_launcher_binary(
+ self, mock_tmp_dir, mock_beam_hook, mock_gcs_hook, tmp_path
+ ):
+ """
+ Test start_go_pipeline_with_binary is called with an executable launcher binary downloaded from GCS.
+ """
+
+ def tmp_dir_side_effect(prefix: str) -> str:
+ sub_dir = tmp_path / mock_tmp_dir.call_args[1]["prefix"]
+ sub_dir.mkdir()
+ return str(sub_dir)
+
+ mock_tmp_dir.return_value.__enter__.side_effect = tmp_dir_side_effect
+
+ def gcs_download_side_effect(bucket_name: str, object_name: str, filename: str) -> None:
+ open(filename, "wb").close()
+
+ gcs_download_method = mock_gcs_hook.return_value.download
+ gcs_download_method.side_effect = gcs_download_side_effect
+
+ start_go_pipeline_method = mock_beam_hook.return_value.start_go_pipeline_with_binary
+
+ operator = BeamRunGoPipelineOperator(
+ task_id=TASK_ID,
+ launcher_binary="gs://bucket/path/to/main",
+ default_pipeline_options=DEFAULT_OPTIONS,
+ pipeline_options=ADDITIONAL_OPTIONS,
+ )
+ operator.execute({})
+
+ expected_binary = f"{tmp_path}/apache-beam-go/launcher-main"
+ expected_options = {
+ "project": "test",
+ "staging_location": "gs://test/staging",
+ "output": "gs://test/output",
+ "labels": {"foo": "bar", "airflow-version": TEST_VERSION},
+ }
+ mock_beam_hook.assert_called_once_with(runner=DEFAULT_RUNNER)
+ mock_tmp_dir.assert_called_once_with(prefix="apache-beam-go")
+ gcs_download_method.assert_called_once_with(
+ bucket_name="bucket",
+ object_name="path/to/main",
+ filename=expected_binary,
+ )
+ assert os.access(expected_binary, os.X_OK)
+ start_go_pipeline_method.assert_called_once_with(
+ variables=expected_options,
+ launcher_binary=expected_binary,
+ worker_binary=expected_binary,
+ process_line_callback=None,
+ )
+
@mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
@mock.patch("airflow.providers.google.go_module_utils.init_module")
- def test_exec_source_on_local_path(self, init_module, beam_hook_mock):
+ def test_exec_direct_runner_with_local_go_file(self, init_module, beam_hook_mock):
"""
Check that start_go_pipeline is called without initializing the Go
module when source is local.
"""
@@ -365,6 +494,29 @@ class TestBeamRunGoPipelineOperator:
should_init_module=False,
)
+ @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
+ def test_exec_direct_runner_with_local_launcher_binary(self, mock_beam_hook):
+ """
+ Test start_go_pipeline_with_binary is called with a local launcher binary.
+ """
+ start_go_pipeline_method = mock_beam_hook.return_value.start_go_pipeline_with_binary
+
+ operator = BeamRunGoPipelineOperator(
+ task_id=TASK_ID,
+ launcher_binary="/local/path/to/main",
+ )
+ operator.execute({})
+
+ expected_binary = "/local/path/to/main"
+
+ mock_beam_hook.assert_called_once_with(runner=DEFAULT_RUNNER)
+ start_go_pipeline_method.assert_called_once_with(
+ variables={"labels": {"airflow-version": TEST_VERSION}},
+ launcher_binary=expected_binary,
+ worker_binary=expected_binary,
+ process_line_callback=None,
+ )
+
@mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist")
@mock.patch(
"tempfile.TemporaryDirectory",
@@ -373,14 +525,16 @@ class TestBeamRunGoPipelineOperator:
@mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
@mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook")
@mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
- def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, _, persist_link_mock):
+ def test_exec_dataflow_runner_with_go_file(
+ self, gcs_hook, dataflow_hook_mock, beam_hook_mock, _, persist_link_mock
+ ):
"""Test DataflowHook is created and the right args are passed to
start_go_dataflow.
"""
dataflow_config = DataflowConfiguration(impersonation_chain="[email protected]")
self.operator.runner = "DataflowRunner"
self.operator.dataflow_config = dataflow_config
- gcs_provide_file = gcs_hook.return_value.provide_file
+ gcs_download_method = gcs_hook.return_value.download
self.operator.execute(None)
job_name = dataflow_hook_mock.build_dataflow_job_name.return_value
dataflow_hook_mock.assert_called_once_with(
@@ -407,10 +561,13 @@ class TestBeamRunGoPipelineOperator:
expected_options["region"],
self.operator.dataflow_job_id,
)
- gcs_provide_file.assert_called_once_with(object_url=GO_FILE, dir="/tmp/apache-beam-go")
+ expected_go_file = "/tmp/apache-beam-go/main.go"
+ gcs_download_method.assert_called_once_with(
+ bucket_name="my-bucket", object_name="example/main.go",
filename=expected_go_file
+ )
beam_hook_mock.return_value.start_go_pipeline.assert_called_once_with(
variables=expected_options,
- go_file=gcs_provide_file.return_value.__enter__.return_value.name,
+ go_file=expected_go_file,
process_line_callback=mock.ANY,
should_init_module=True,
)
@@ -423,6 +580,104 @@ class TestBeamRunGoPipelineOperator:
)
dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()
+ @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist")
+ @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook")
+ @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
+ @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
+ @mock.patch("tempfile.TemporaryDirectory")
+ def test_exec_dataflow_runner_with_launcher_binary_and_worker_binary(
+ self, mock_tmp_dir, mock_beam_hook, mock_gcs_hook, mock_dataflow_hook, mock_persist_link, tmp_path
+ ):
+ """
+ Test DataflowHook is created and start_go_pipeline_with_binary is called with
+ a launcher binary and a worker binary.
+ """
+
+ def tmp_dir_side_effect(prefix: str) -> str:
+ sub_dir = tmp_path / mock_tmp_dir.call_args[1]["prefix"]
+ sub_dir.mkdir()
+ return str(sub_dir)
+
+ mock_tmp_dir.return_value.__enter__.side_effect = tmp_dir_side_effect
+
+ def gcs_download_side_effect(bucket_name: str, object_name: str, filename: str) -> None:
+ open(filename, "wb").close()
+
+ gcs_download_method = mock_gcs_hook.return_value.download
+ gcs_download_method.side_effect = gcs_download_side_effect
+
+ mock_dataflow_hook.build_dataflow_job_name.return_value = "test-job"
+
+ provide_authorized_gcloud_method = mock_dataflow_hook.return_value.provide_authorized_gcloud
+ start_go_pipeline_method = mock_beam_hook.return_value.start_go_pipeline_with_binary
+ wait_for_done_method = mock_dataflow_hook.return_value.wait_for_done
+
+ dataflow_config = DataflowConfiguration(project_id="test-project")
+
+ operator = BeamRunGoPipelineOperator(
+ task_id=TASK_ID,
+ launcher_binary="gs://bucket/path/to/main1",
+ worker_binary="gs://bucket/path/to/main2",
+ runner="DataflowRunner",
+ default_pipeline_options=DEFAULT_OPTIONS,
+ pipeline_options=ADDITIONAL_OPTIONS,
+ dataflow_config=dataflow_config,
+ )
+ operator.execute({})
+
+ expected_launcher_binary = str(tmp_path / "apache-beam-go/launcher-main1")
+ expected_worker_binary = str(tmp_path / "apache-beam-go/worker-main2")
+ expected_job_name = "test-job"
+ expected_options = {
+ "project": "test-project",
+ "job_name": expected_job_name,
+ "staging_location": "gs://test/staging",
+ "output": "gs://test/output",
+ "labels": {"foo": "bar", "airflow-version": TEST_VERSION},
+ "region": "us-central1",
+ }
+
+ mock_tmp_dir.assert_called_once_with(prefix="apache-beam-go")
+ gcs_download_method.assert_has_calls(
+ [
+ call(bucket_name="bucket", object_name="path/to/main1", filename=expected_launcher_binary),
+ call(bucket_name="bucket", object_name="path/to/main2", filename=expected_worker_binary),
+ ],
+ )
+ assert os.access(expected_launcher_binary, os.X_OK)
+ assert os.access(expected_worker_binary, os.X_OK)
+
+ mock_dataflow_hook.assert_called_once_with(
+ gcp_conn_id=dataflow_config.gcp_conn_id,
+ delegate_to=dataflow_config.delegate_to,
+ poll_sleep=dataflow_config.poll_sleep,
+ impersonation_chain=dataflow_config.impersonation_chain,
+ drain_pipeline=dataflow_config.drain_pipeline,
+ cancel_timeout=dataflow_config.cancel_timeout,
+ wait_until_finished=dataflow_config.wait_until_finished,
+ )
+ provide_authorized_gcloud_method.assert_called_once_with()
+ start_go_pipeline_method.assert_called_once_with(
+ variables=expected_options,
+ launcher_binary=expected_launcher_binary,
+ worker_binary=expected_worker_binary,
+ process_line_callback=mock.ANY,
+ )
+ mock_persist_link.assert_called_once_with(
+ operator,
+ {},
+ dataflow_config.project_id,
+ dataflow_config.location,
+ operator.dataflow_job_id,
+ )
+ wait_for_done_method.assert_called_once_with(
+ job_name=expected_job_name,
+ location=dataflow_config.location,
+ job_id=operator.dataflow_job_id,
+ multiple_jobs=False,
+ project_id=dataflow_config.project_id,
+ )
+
@mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist")
@mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
@mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")