This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8c4303e1ac Add support for running a Beam Go pipeline with an executable binary (#28764)
8c4303e1ac is described below

commit 8c4303e1ace0774244b556a8d86a19058af2b16d
Author: Johanna Öjeling <[email protected]>
AuthorDate: Wed Jan 18 18:52:05 2023 +0100

    Add support for running a Beam Go pipeline with an executable binary (#28764)
    
    The `BeamRunGoPipelineOperator` currently has a `go_file` parameter, which represents the path to a Go source file with the pipeline code. The operator starts the pipeline with `go run`, i.e. it compiles the code into a temporary binary and executes it.
    
    This PR adds support for the operator to start the pipeline with an already compiled binary, as an alternative to the source file approach. It introduces two new parameters (a usage sketch follows the motivation list below):
    1. `launcher_binary`: path to a binary compiled for the launching platform, i.e. the platform where Airflow is deployed
    2. `worker_binary` (optional): path to a binary compiled for the worker platform if using a remote runner
    
    Some motivations for introducing this feature:
    - It does not require a Go installation on the system where Airflow is run (which is more similar to how the `BeamRunJavaPipelineOperator` works, running a jar)
    - It does not involve the extra steps of initializing a Go module, installing dependencies and compiling the code on every task run, which is what currently happens when the Go source file is downloaded from GCS
    - In the current implementation, only a single Go source file can be downloaded from GCS. This can be limiting if the project comprises multiple files
---
 airflow/providers/apache/beam/hooks/beam.py        |  35 ++-
 airflow/providers/apache/beam/operators/beam.py    | 179 +++++++++++--
 docs/spelling_wordlist.txt                         |   1 +
 tests/providers/apache/beam/hooks/test_beam.py     |  29 ++
 tests/providers/apache/beam/operators/test_beam.py | 295 +++++++++++++++++++--
 5 files changed, 497 insertions(+), 42 deletions(-)

diff --git a/airflow/providers/apache/beam/hooks/beam.py b/airflow/providers/apache/beam/hooks/beam.py
index 28a5abc0c6..c318d17363 100644
--- a/airflow/providers/apache/beam/hooks/beam.py
+++ b/airflow/providers/apache/beam/hooks/beam.py
@@ -19,6 +19,7 @@
 from __future__ import annotations
 
 import contextlib
+import copy
 import json
 import os
 import select
@@ -310,11 +311,10 @@ class BeamHook(BaseHook):
         should_init_module: bool = False,
     ) -> None:
         """
-        Starts Apache Beam Go pipeline.
+        Starts Apache Beam Go pipeline with a source file.
 
         :param variables: Variables passed to the job.
         :param go_file: Path to the Go file with your beam pipeline.
-        :param go_file:
         :param process_line_callback: (optional) Callback that can be used to process each line of
             the stdout and stderr file descriptors.
         :param should_init_module: If False (default), will just execute a `go run` command. If True, will
@@ -346,3 +346,34 @@ class BeamHook(BaseHook):
             process_line_callback=process_line_callback,
             working_directory=working_directory,
         )
+
+    def start_go_pipeline_with_binary(
+        self,
+        variables: dict,
+        launcher_binary: str,
+        worker_binary: str,
+        process_line_callback: Callable[[str], None] | None = None,
+    ) -> None:
+        """
+        Starts Apache Beam Go pipeline with an executable binary.
+
+        :param variables: Variables passed to the job.
+        :param launcher_binary: Path to the binary compiled for the launching platform.
+        :param worker_binary: Path to the binary compiled for the worker platform.
+        :param process_line_callback: (optional) Callback that can be used to process each line of
+            the stdout and stderr file descriptors.
+        """
+        job_variables = copy.deepcopy(variables)
+
+        if "labels" in job_variables:
+            job_variables["labels"] = json.dumps(job_variables["labels"], 
separators=(",", ":"))
+
+        job_variables["worker_binary"] = worker_binary
+
+        command_prefix = [launcher_binary]
+
+        self._start_pipeline(
+            variables=job_variables,
+            command_prefix=command_prefix,
+            process_line_callback=process_line_callback,
+        )
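
As a companion to the hunk above, here is a minimal sketch of calling the new hook method directly; the runner, pipeline variables and binary paths are assumed values for illustration, not part of this change:

    from airflow.providers.apache.beam.hooks.beam import BeamHook

    hook = BeamHook(runner="DirectRunner")  # runner choice is an assumption
    hook.start_go_pipeline_with_binary(
        variables={"output": "/tmp/output", "labels": {"env": "dev"}},  # assumed options; labels are JSON-encoded by the hook
        launcher_binary="/opt/pipelines/launcher-main",  # assumed local path
        worker_binary="/opt/pipelines/worker-main",  # assumed local path
    )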
diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py
index e9395a8dde..b041dbdc24 100644
--- a/airflow/providers/apache/beam/operators/beam.py
+++ b/airflow/providers/apache/beam/operators/beam.py
@@ -19,9 +19,13 @@
 from __future__ import annotations
 
 import copy
+import os
+import stat
 import tempfile
-from abc import ABC, ABCMeta
+from abc import ABC, ABCMeta, abstractmethod
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import ExitStack
+from functools import partial
 from typing import TYPE_CHECKING, Callable, Sequence
 
 from airflow import AirflowException
@@ -31,10 +35,10 @@ from airflow.providers.google.cloud.hooks.dataflow import (
     DataflowHook,
     process_line_and_extract_dataflow_job_id_callback,
 )
-from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
 from airflow.providers.google.cloud.links.dataflow import DataflowJobLink
 from airflow.providers.google.cloud.operators.dataflow import CheckJobRunning, DataflowConfiguration
-from airflow.utils.helpers import convert_camel_to_snake
+from airflow.utils.helpers import convert_camel_to_snake, exactly_one
 from airflow.version import version
 
 if TYPE_CHECKING:
@@ -520,12 +524,27 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator):
         For more detail on Apache Beam have a look at the reference:
         https://beam.apache.org/documentation/
 
-    :param go_file: Reference to the Go Apache Beam pipeline e.g.,
-        /some/local/file/path/to/your/go/pipeline/file.go
+    :param go_file: Reference to the Apache Beam pipeline Go source file,
+        e.g. /local/path/to/main.go or gs://bucket/path/to/main.go.
+        Exactly one of go_file and launcher_binary must be provided.
+
+    :param launcher_binary: Reference to the Apache Beam pipeline Go binary compiled for the launching
+        platform, e.g. /local/path/to/launcher-main or gs://bucket/path/to/launcher-main.
+        Exactly one of go_file and launcher_binary must be provided.
+
+    :param worker_binary: Reference to the Apache Beam pipeline Go binary compiled for the worker platform,
+        e.g. /local/path/to/worker-main or gs://bucket/path/to/worker-main.
+        Needed if the OS or architecture of the workers running the pipeline is different from that
+        of the platform launching the pipeline. For more information, see the Apache Beam documentation
+        for Go cross compilation: https://beam.apache.org/documentation/sdks/go-cross-compilation/.
+        If launcher_binary is not set, providing a worker_binary will have no effect. If launcher_binary is
+        set and worker_binary is not, worker_binary will default to the value of launcher_binary.
     """
 
     template_fields = [
         "go_file",
+        "launcher_binary",
+        "worker_binary",
         "runner",
         "pipeline_options",
         "default_pipeline_options",
@@ -537,7 +556,9 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator):
     def __init__(
         self,
         *,
-        go_file: str,
+        go_file: str = "",
+        launcher_binary: str = "",
+        worker_binary: str = "",
         runner: str = "DirectRunner",
         default_pipeline_options: dict | None = None,
         pipeline_options: dict | None = None,
@@ -563,8 +584,13 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator):
             )
         self.dataflow_support_impersonation = False
 
+        if not exactly_one(go_file, launcher_binary):
+            raise ValueError("Exactly one of `go_file` and `launcher_binary` must be set")
+
         self.go_file = go_file
-        self.should_init_go_module = False
+        self.launcher_binary = launcher_binary
+        self.worker_binary = worker_binary or launcher_binary
+
         self.pipeline_options.setdefault("labels", {}).update(
             {"airflow-version": "v" + version.replace(".", "-").replace("+", 
"-")}
         )
@@ -581,24 +607,24 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator):
         if not self.beam_hook:
             raise AirflowException("Beam hook is not defined.")
 
+        go_artifact: _GoArtifact = (
+            _GoFile(file=self.go_file)
+            if self.go_file
+            else _GoBinary(launcher=self.launcher_binary, worker=self.worker_binary)
+        )
+
         with ExitStack() as exit_stack:
-            if self.go_file.lower().startswith("gs://"):
+            if go_artifact.is_located_on_gcs():
                 gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
-
                tmp_dir = exit_stack.enter_context(tempfile.TemporaryDirectory(prefix="apache-beam-go"))
-                tmp_gcs_file = exit_stack.enter_context(
-                    gcs_hook.provide_file(object_url=self.go_file, dir=tmp_dir)
-                )
-                self.go_file = tmp_gcs_file.name
-                self.should_init_go_module = True
+                go_artifact.download_from_gcs(gcs_hook=gcs_hook, tmp_dir=tmp_dir)
 
             if is_dataflow and self.dataflow_hook:
                 with self.dataflow_hook.provide_authorized_gcloud():
-                    self.beam_hook.start_go_pipeline(
+                    go_artifact.start_pipeline(
+                        beam_hook=self.beam_hook,
                         variables=snake_case_pipeline_options,
-                        go_file=self.go_file,
                         process_line_callback=process_line_callback,
-                        should_init_module=self.should_init_go_module,
                     )
 
                 DataflowJobLink.persist(
@@ -618,11 +644,10 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator):
                     )
                 return {"dataflow_job_id": self.dataflow_job_id}
             else:
-                self.beam_hook.start_go_pipeline(
+                go_artifact.start_pipeline(
+                    beam_hook=self.beam_hook,
                     variables=snake_case_pipeline_options,
-                    go_file=self.go_file,
                     process_line_callback=process_line_callback,
-                    should_init_module=self.should_init_go_module,
                 )
 
     def on_kill(self) -> None:
@@ -632,3 +657,117 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator):
                 job_id=self.dataflow_job_id,
                 project_id=self.dataflow_config.project_id,
             )
+
+
+class _GoArtifact(ABC):
+    @abstractmethod
+    def is_located_on_gcs(self) -> bool:
+        ...
+
+    @abstractmethod
+    def download_from_gcs(self, gcs_hook: GCSHook, tmp_dir: str) -> None:
+        ...
+
+    @abstractmethod
+    def start_pipeline(
+        self,
+        beam_hook: BeamHook,
+        variables: dict,
+        process_line_callback: Callable[[str], None] | None = None,
+    ) -> None:
+        ...
+
+
+class _GoFile(_GoArtifact):
+    def __init__(self, file: str) -> None:
+        self.file = file
+        self.should_init_go_module = False
+
+    def is_located_on_gcs(self) -> bool:
+        return _object_is_located_on_gcs(self.file)
+
+    def download_from_gcs(self, gcs_hook: GCSHook, tmp_dir: str) -> None:
+        self.file = _download_object_from_gcs(gcs_hook=gcs_hook, uri=self.file, tmp_dir=tmp_dir)
+        self.should_init_go_module = True
+
+    def start_pipeline(
+        self,
+        beam_hook: BeamHook,
+        variables: dict,
+        process_line_callback: Callable[[str], None] | None = None,
+    ) -> None:
+        beam_hook.start_go_pipeline(
+            variables=variables,
+            go_file=self.file,
+            process_line_callback=process_line_callback,
+            should_init_module=self.should_init_go_module,
+        )
+
+
+class _GoBinary(_GoArtifact):
+    def __init__(self, launcher: str, worker: str) -> None:
+        self.launcher = launcher
+        self.worker = worker
+
+    def is_located_on_gcs(self) -> bool:
+        return any(_object_is_located_on_gcs(path) for path in (self.launcher, self.worker))
+
+    def download_from_gcs(self, gcs_hook: GCSHook, tmp_dir: str) -> None:
+        binaries_are_equal = self.launcher == self.worker
+
+        binaries_to_download = []
+
+        if _object_is_located_on_gcs(self.launcher):
+            binaries_to_download.append("launcher")
+
+        if not binaries_are_equal and _object_is_located_on_gcs(self.worker):
+            binaries_to_download.append("worker")
+
+        download_fn = partial(_download_object_from_gcs, gcs_hook=gcs_hook, tmp_dir=tmp_dir)
+
+        with ThreadPoolExecutor(max_workers=len(binaries_to_download)) as executor:
+            futures = {
+                executor.submit(download_fn, uri=getattr(self, binary), tmp_prefix=f"{binary}-"): binary
+                for binary in binaries_to_download
+            }
+
+            for future in as_completed(futures):
+                binary = futures[future]
+                tmp_path = future.result()
+                _make_executable(tmp_path)
+                setattr(self, binary, tmp_path)
+
+        if binaries_are_equal:
+            self.worker = self.launcher
+
+    def start_pipeline(
+        self,
+        beam_hook: BeamHook,
+        variables: dict,
+        process_line_callback: Callable[[str], None] | None = None,
+    ) -> None:
+        beam_hook.start_go_pipeline_with_binary(
+            variables=variables,
+            launcher_binary=self.launcher,
+            worker_binary=self.worker,
+            process_line_callback=process_line_callback,
+        )
+
+
+def _object_is_located_on_gcs(path: str) -> bool:
+    return path.lower().startswith("gs://")
+
+
+def _download_object_from_gcs(gcs_hook: GCSHook, uri: str, tmp_dir: str, tmp_prefix: str = "") -> str:
+    tmp_name = f"{tmp_prefix}{os.path.basename(uri)}"
+    tmp_path = os.path.join(tmp_dir, tmp_name)
+
+    bucket, prefix = _parse_gcs_url(uri)
+    gcs_hook.download(bucket_name=bucket, object_name=prefix, filename=tmp_path)
+
+    return tmp_path
+
+
+def _make_executable(path: str) -> None:
+    st = os.stat(path)
+    os.chmod(path, st.st_mode | stat.S_IEXEC)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 68b19e1192..28a34b76dc 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1269,6 +1269,7 @@ schedulable
 schedulername
 schemas
 sdk
+sdks
 searchpath
 SearchResultGenerator
 SecretManagerClient
diff --git a/tests/providers/apache/beam/hooks/test_beam.py b/tests/providers/apache/beam/hooks/test_beam.py
index f509c18b0a..80cf26687d 100644
--- a/tests/providers/apache/beam/hooks/test_beam.py
+++ b/tests/providers/apache/beam/hooks/test_beam.py
@@ -312,6 +312,35 @@ class TestBeamHook:
                 variables=copy.deepcopy(BEAM_VARIABLES_GO),
             )
 
+    @mock.patch(BEAM_STRING.format("BeamCommandRunner"))
+    def test_start_go_pipeline_with_binary(self, mock_runner):
+        hook = BeamHook(runner=DEFAULT_RUNNER)
+        wait_for_done_method = mock_runner.return_value.wait_for_done
+        process_line_callback = MagicMock()
+
+        launcher_binary = "/path/to/launcher-main"
+        worker_binary = "/path/to/worker-main"
+
+        hook.start_go_pipeline_with_binary(
+            variables=BEAM_VARIABLES_GO,
+            launcher_binary=launcher_binary,
+            worker_binary=worker_binary,
+            process_line_callback=process_line_callback,
+        )
+
+        expected_cmd = [
+            launcher_binary,
+            f"--runner={DEFAULT_RUNNER}",
+            "--output=gs://test/output",
+            '--labels={"foo":"bar"}',
+            f"--worker_binary={worker_binary}",
+        ]
+
+        mock_runner.assert_called_once_with(
+            cmd=expected_cmd, process_line_callback=process_line_callback, working_directory=None
+        )
+        wait_for_done_method.assert_called_once_with()
+
 
 class TestBeamRunner:
     @mock.patch("airflow.providers.apache.beam.hooks.beam.BeamCommandRunner.log")
diff --git a/tests/providers/apache/beam/operators/test_beam.py b/tests/providers/apache/beam/operators/test_beam.py
index d5e0bfd58d..4cefb756b4 100644
--- a/tests/providers/apache/beam/operators/test_beam.py
+++ b/tests/providers/apache/beam/operators/test_beam.py
@@ -16,8 +16,11 @@
 # under the License.
 from __future__ import annotations
 
+import os
 from unittest import mock
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, call
+
+import pytest
 
 from airflow.providers.apache.beam.operators.beam import (
     BeamRunGoPipelineOperator,
@@ -37,7 +40,9 @@ PY_FILE = "gs://my-bucket/my-object.py"
 PY_INTERPRETER = "python3"
 PY_OPTIONS = ["-m"]
 GO_FILE = "gs://my-bucket/example/main.go"
-DEFAULT_OPTIONS_PYTHON = DEFAULT_OPTIONS_JAVA = {
+LAUNCHER_BINARY = "gs://my-bucket/example/launcher"
+WORKER_BINARY = "gs://my-bucket/example/worker"
+DEFAULT_OPTIONS = {
     "project": "test",
     "stagingLocation": "gs://test/staging",
 }
@@ -56,7 +61,7 @@ class TestBeamRunPythonPipelineOperator:
             task_id=TASK_ID,
             py_file=PY_FILE,
             py_options=PY_OPTIONS,
-            default_pipeline_options=DEFAULT_OPTIONS_PYTHON,
+            default_pipeline_options=DEFAULT_OPTIONS,
             pipeline_options=ADDITIONAL_OPTIONS,
         )
 
@@ -67,7 +72,7 @@ class TestBeamRunPythonPipelineOperator:
         assert self.operator.runner == DEFAULT_RUNNER
         assert self.operator.py_options == PY_OPTIONS
         assert self.operator.py_interpreter == PY_INTERPRETER
-        assert self.operator.default_pipeline_options == DEFAULT_OPTIONS_PYTHON
+        assert self.operator.default_pipeline_options == DEFAULT_OPTIONS
         assert self.operator.pipeline_options == EXPECTED_ADDITIONAL_OPTIONS
 
     @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
@@ -185,7 +190,7 @@ class TestBeamRunJavaPipelineOperator:
             task_id=TASK_ID,
             jar=JAR_FILE,
             job_class=JOB_CLASS,
-            default_pipeline_options=DEFAULT_OPTIONS_JAVA,
+            default_pipeline_options=DEFAULT_OPTIONS,
             pipeline_options=ADDITIONAL_OPTIONS,
         )
 
@@ -193,7 +198,7 @@ class TestBeamRunJavaPipelineOperator:
         """Test BeamRunJavaPipelineOperator instance is properly 
initialized."""
         assert self.operator.task_id == TASK_ID
         assert self.operator.runner == DEFAULT_RUNNER
-        assert self.operator.default_pipeline_options == DEFAULT_OPTIONS_JAVA
+        assert self.operator.default_pipeline_options == DEFAULT_OPTIONS
         assert self.operator.job_class == JOB_CLASS
         assert self.operator.jar == JAR_FILE
         assert self.operator.pipeline_options == ADDITIONAL_OPTIONS
@@ -211,7 +216,7 @@ class TestBeamRunJavaPipelineOperator:
         beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)
         gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
         start_java_hook.assert_called_once_with(
-            variables={**DEFAULT_OPTIONS_JAVA, **ADDITIONAL_OPTIONS},
+            variables={**DEFAULT_OPTIONS, **ADDITIONAL_OPTIONS},
             jar=gcs_provide_file.return_value.__enter__.return_value.name,
             job_class=JOB_CLASS,
             process_line_callback=None,
@@ -303,30 +308,96 @@ class TestBeamRunGoPipelineOperator:
         self.operator = BeamRunGoPipelineOperator(
             task_id=TASK_ID,
             go_file=GO_FILE,
-            default_pipeline_options=DEFAULT_OPTIONS_PYTHON,
+            default_pipeline_options=DEFAULT_OPTIONS,
             pipeline_options=ADDITIONAL_OPTIONS,
         )
 
-    def test_init(self):
-        """Test BeamRunGoPipelineOperator instance is properly initialized."""
+    def test_init_with_go_file(self):
+        """Test BeamRunGoPipelineOperator instance is properly initialized 
with go_file."""
         assert self.operator.task_id == TASK_ID
         assert self.operator.go_file == GO_FILE
+        assert self.operator.launcher_binary == ""
+        assert self.operator.worker_binary == ""
         assert self.operator.runner == DEFAULT_RUNNER
-        assert self.operator.default_pipeline_options == DEFAULT_OPTIONS_PYTHON
+        assert self.operator.default_pipeline_options == DEFAULT_OPTIONS
         assert self.operator.pipeline_options == EXPECTED_ADDITIONAL_OPTIONS
 
+    def test_init_with_launcher_binary(self):
+        """Test BeamRunGoPipelineOperator instance is properly initialized 
with launcher_binary."""
+        operator = BeamRunGoPipelineOperator(
+            task_id=TASK_ID,
+            launcher_binary=LAUNCHER_BINARY,
+            default_pipeline_options=DEFAULT_OPTIONS,
+            pipeline_options=ADDITIONAL_OPTIONS,
+        )
+
+        assert operator.task_id == TASK_ID
+        assert operator.go_file == ""
+        assert operator.launcher_binary == LAUNCHER_BINARY
+        assert operator.worker_binary == LAUNCHER_BINARY
+        assert operator.runner == DEFAULT_RUNNER
+        assert operator.default_pipeline_options == DEFAULT_OPTIONS
+        assert operator.pipeline_options == EXPECTED_ADDITIONAL_OPTIONS
+
+    def test_init_with_launcher_binary_and_worker_binary(self):
+        """
+        Test BeamRunGoPipelineOperator instance is properly initialized with launcher_binary and
+        worker_binary.
+        """
+        operator = BeamRunGoPipelineOperator(
+            task_id=TASK_ID,
+            launcher_binary=LAUNCHER_BINARY,
+            worker_binary=WORKER_BINARY,
+            default_pipeline_options=DEFAULT_OPTIONS,
+            pipeline_options=ADDITIONAL_OPTIONS,
+        )
+
+        assert operator.task_id == TASK_ID
+        assert operator.go_file == ""
+        assert operator.launcher_binary == LAUNCHER_BINARY
+        assert operator.worker_binary == WORKER_BINARY
+        assert operator.runner == DEFAULT_RUNNER
+        assert operator.default_pipeline_options == DEFAULT_OPTIONS
+        assert operator.pipeline_options == EXPECTED_ADDITIONAL_OPTIONS
+
+    def test_init_with_neither_go_file_nor_launcher_binary_raises(self):
+        """
+        Test BeamRunGoPipelineOperator initialization raises ValueError when neither
+        go_file nor launcher_binary is provided.
+        """
+        with pytest.raises(ValueError, match="Exactly one of `go_file` and `launcher_binary` must be set"):
+            BeamRunGoPipelineOperator(
+                task_id=TASK_ID,
+                default_pipeline_options=DEFAULT_OPTIONS,
+                pipeline_options=ADDITIONAL_OPTIONS,
+            )
+
+    def test_init_with_both_go_file_and_launcher_binary_raises(self):
+        """
+        Test BeamRunGoPipelineOperator initialization raises ValueError when both
+        go_file and launcher_binary are provided.
+        """
+        with pytest.raises(ValueError, match="Exactly one of `go_file` and `launcher_binary` must be set"):
+            BeamRunGoPipelineOperator(
+                task_id=TASK_ID,
+                go_file=GO_FILE,
+                launcher_binary=LAUNCHER_BINARY,
+                default_pipeline_options=DEFAULT_OPTIONS,
+                pipeline_options=ADDITIONAL_OPTIONS,
+            )
+
     @mock.patch(
         "tempfile.TemporaryDirectory",
         
return_value=MagicMock(__enter__=MagicMock(return_value="/tmp/apache-beam-go")),
     )
     @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
     @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
-    def test_exec_direct_runner(self, gcs_hook, beam_hook_mock, _):
+    def test_exec_direct_runner_with_gcs_go_file(self, gcs_hook, beam_hook_mock, _):
         """Test BeamHook is created and the right args are passed to
         start_go_workflow.
         """
        start_go_pipeline_method = beam_hook_mock.return_value.start_go_pipeline
-        gcs_provide_file_method = gcs_hook.return_value.provide_file
+        gcs_download_method = gcs_hook.return_value.download
         self.operator.execute(None)
         beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)
         expected_options = {
@@ -335,17 +406,75 @@ class TestBeamRunGoPipelineOperator:
             "output": "gs://test/output",
             "labels": {"foo": "bar", "airflow-version": TEST_VERSION},
         }
-        gcs_provide_file_method.assert_called_once_with(object_url=GO_FILE, dir="/tmp/apache-beam-go")
+        expected_go_file = "/tmp/apache-beam-go/main.go"
+        gcs_download_method.assert_called_once_with(
+            bucket_name="my-bucket", object_name="example/main.go", 
filename=expected_go_file
+        )
         start_go_pipeline_method.assert_called_once_with(
             variables=expected_options,
-            go_file=gcs_provide_file_method.return_value.__enter__.return_value.name,
+            go_file=expected_go_file,
             process_line_callback=None,
             should_init_module=True,
         )
 
+    @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
+    @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
+    @mock.patch("tempfile.TemporaryDirectory")
+    def test_exec_direct_runner_with_gcs_launcher_binary(
+        self, mock_tmp_dir, mock_beam_hook, mock_gcs_hook, tmp_path
+    ):
+        """
+        Test start_go_pipeline_with_binary is called with an executable launcher binary downloaded from GCS.
+        """
+
+        def tmp_dir_side_effect(prefix: str) -> str:
+            sub_dir = tmp_path / mock_tmp_dir.call_args[1]["prefix"]
+            sub_dir.mkdir()
+            return str(sub_dir)
+
+        mock_tmp_dir.return_value.__enter__.side_effect = tmp_dir_side_effect
+
+        def gcs_download_side_effect(bucket_name: str, object_name: str, filename: str) -> None:
+            open(filename, "wb").close()
+
+        gcs_download_method = mock_gcs_hook.return_value.download
+        gcs_download_method.side_effect = gcs_download_side_effect
+
+        start_go_pipeline_method = mock_beam_hook.return_value.start_go_pipeline_with_binary
+
+        operator = BeamRunGoPipelineOperator(
+            task_id=TASK_ID,
+            launcher_binary="gs://bucket/path/to/main",
+            default_pipeline_options=DEFAULT_OPTIONS,
+            pipeline_options=ADDITIONAL_OPTIONS,
+        )
+        operator.execute({})
+
+        expected_binary = f"{tmp_path}/apache-beam-go/launcher-main"
+        expected_options = {
+            "project": "test",
+            "staging_location": "gs://test/staging",
+            "output": "gs://test/output",
+            "labels": {"foo": "bar", "airflow-version": TEST_VERSION},
+        }
+        mock_beam_hook.assert_called_once_with(runner=DEFAULT_RUNNER)
+        mock_tmp_dir.assert_called_once_with(prefix="apache-beam-go")
+        gcs_download_method.assert_called_once_with(
+            bucket_name="bucket",
+            object_name="path/to/main",
+            filename=expected_binary,
+        )
+        assert os.access(expected_binary, os.X_OK)
+        start_go_pipeline_method.assert_called_once_with(
+            variables=expected_options,
+            launcher_binary=expected_binary,
+            worker_binary=expected_binary,
+            process_line_callback=None,
+        )
+
     @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
     @mock.patch("airflow.providers.google.go_module_utils.init_module")
-    def test_exec_source_on_local_path(self, init_module, beam_hook_mock):
+    def test_exec_direct_runner_with_local_go_file(self, init_module, beam_hook_mock):
         """
         Check that start_go_pipeline is called without initializing the Go module when source is locale.
         """
@@ -365,6 +494,29 @@ class TestBeamRunGoPipelineOperator:
             should_init_module=False,
         )
 
+    @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
+    def test_exec_direct_runner_with_local_launcher_binary(self, mock_beam_hook):
+        """
+        Test start_go_pipeline_with_binary is called with a local launcher binary.
+        """
+        start_go_pipeline_method = mock_beam_hook.return_value.start_go_pipeline_with_binary
+
+        operator = BeamRunGoPipelineOperator(
+            task_id=TASK_ID,
+            launcher_binary="/local/path/to/main",
+        )
+        operator.execute({})
+
+        expected_binary = "/local/path/to/main"
+
+        mock_beam_hook.assert_called_once_with(runner=DEFAULT_RUNNER)
+        start_go_pipeline_method.assert_called_once_with(
+            variables={"labels": {"airflow-version": TEST_VERSION}},
+            launcher_binary=expected_binary,
+            worker_binary=expected_binary,
+            process_line_callback=None,
+        )
+
    @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist")
     @mock.patch(
         "tempfile.TemporaryDirectory",
@@ -373,14 +525,16 @@ class TestBeamRunGoPipelineOperator:
     @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
     @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook")
     @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
-    def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, _, persist_link_mock):
+    def test_exec_dataflow_runner_with_go_file(
+        self, gcs_hook, dataflow_hook_mock, beam_hook_mock, _, persist_link_mock
+    ):
         """Test DataflowHook is created and the right args are passed to
         start_go_dataflow.
         """
        dataflow_config = DataflowConfiguration(impersonation_chain="[email protected]")
         self.operator.runner = "DataflowRunner"
         self.operator.dataflow_config = dataflow_config
-        gcs_provide_file = gcs_hook.return_value.provide_file
+        gcs_download_method = gcs_hook.return_value.download
         self.operator.execute(None)
         job_name = dataflow_hook_mock.build_dataflow_job_name.return_value
         dataflow_hook_mock.assert_called_once_with(
@@ -407,10 +561,13 @@ class TestBeamRunGoPipelineOperator:
             expected_options["region"],
             self.operator.dataflow_job_id,
         )
-        gcs_provide_file.assert_called_once_with(object_url=GO_FILE, dir="/tmp/apache-beam-go")
+        expected_go_file = "/tmp/apache-beam-go/main.go"
+        gcs_download_method.assert_called_once_with(
+            bucket_name="my-bucket", object_name="example/main.go", 
filename=expected_go_file
+        )
         beam_hook_mock.return_value.start_go_pipeline.assert_called_once_with(
             variables=expected_options,
-            go_file=gcs_provide_file.return_value.__enter__.return_value.name,
+            go_file=expected_go_file,
             process_line_callback=mock.ANY,
             should_init_module=True,
         )
@@ -423,6 +580,104 @@ class TestBeamRunGoPipelineOperator:
         )
        dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()
 
+    @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist")
+    @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook")
+    @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
+    @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
+    @mock.patch("tempfile.TemporaryDirectory")
+    def test_exec_dataflow_runner_with_launcher_binary_and_worker_binary(
+        self, mock_tmp_dir, mock_beam_hook, mock_gcs_hook, mock_dataflow_hook, mock_persist_link, tmp_path
+    ):
+        """
+        Test DataflowHook is created and start_go_pipeline_with_binary is called with
+        a launcher binary and a worker binary.
+        """
+
+        def tmp_dir_side_effect(prefix: str) -> str:
+            sub_dir = tmp_path / mock_tmp_dir.call_args[1]["prefix"]
+            sub_dir.mkdir()
+            return str(sub_dir)
+
+        mock_tmp_dir.return_value.__enter__.side_effect = tmp_dir_side_effect
+
+        def gcs_download_side_effect(bucket_name: str, object_name: str, filename: str) -> None:
+            open(filename, "wb").close()
+
+        gcs_download_method = mock_gcs_hook.return_value.download
+        gcs_download_method.side_effect = gcs_download_side_effect
+
+        mock_dataflow_hook.build_dataflow_job_name.return_value = "test-job"
+
+        provide_authorized_gcloud_method = mock_dataflow_hook.return_value.provide_authorized_gcloud
+        start_go_pipeline_method = mock_beam_hook.return_value.start_go_pipeline_with_binary
+        wait_for_done_method = mock_dataflow_hook.return_value.wait_for_done
+
+        dataflow_config = DataflowConfiguration(project_id="test-project")
+
+        operator = BeamRunGoPipelineOperator(
+            task_id=TASK_ID,
+            launcher_binary="gs://bucket/path/to/main1",
+            worker_binary="gs://bucket/path/to/main2",
+            runner="DataflowRunner",
+            default_pipeline_options=DEFAULT_OPTIONS,
+            pipeline_options=ADDITIONAL_OPTIONS,
+            dataflow_config=dataflow_config,
+        )
+        operator.execute({})
+
+        expected_launcher_binary = str(tmp_path / "apache-beam-go/launcher-main1")
+        expected_worker_binary = str(tmp_path / "apache-beam-go/worker-main2")
+        expected_job_name = "test-job"
+        expected_options = {
+            "project": "test-project",
+            "job_name": expected_job_name,
+            "staging_location": "gs://test/staging",
+            "output": "gs://test/output",
+            "labels": {"foo": "bar", "airflow-version": TEST_VERSION},
+            "region": "us-central1",
+        }
+
+        mock_tmp_dir.assert_called_once_with(prefix="apache-beam-go")
+        gcs_download_method.assert_has_calls(
+            [
+                call(bucket_name="bucket", object_name="path/to/main1", filename=expected_launcher_binary),
+                call(bucket_name="bucket", object_name="path/to/main2", filename=expected_worker_binary),
+            ],
+        )
+        assert os.access(expected_launcher_binary, os.X_OK)
+        assert os.access(expected_worker_binary, os.X_OK)
+
+        mock_dataflow_hook.assert_called_once_with(
+            gcp_conn_id=dataflow_config.gcp_conn_id,
+            delegate_to=dataflow_config.delegate_to,
+            poll_sleep=dataflow_config.poll_sleep,
+            impersonation_chain=dataflow_config.impersonation_chain,
+            drain_pipeline=dataflow_config.drain_pipeline,
+            cancel_timeout=dataflow_config.cancel_timeout,
+            wait_until_finished=dataflow_config.wait_until_finished,
+        )
+        provide_authorized_gcloud_method.assert_called_once_with()
+        start_go_pipeline_method.assert_called_once_with(
+            variables=expected_options,
+            launcher_binary=expected_launcher_binary,
+            worker_binary=expected_worker_binary,
+            process_line_callback=mock.ANY,
+        )
+        mock_persist_link.assert_called_once_with(
+            operator,
+            {},
+            dataflow_config.project_id,
+            dataflow_config.location,
+            operator.dataflow_job_id,
+        )
+        wait_for_done_method.assert_called_once_with(
+            job_name=expected_job_name,
+            location=dataflow_config.location,
+            job_id=operator.dataflow_job_id,
+            multiple_jobs=False,
+            project_id=dataflow_config.project_id,
+        )
+
    @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist")
     @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
     @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")

