This is an automated email from the ASF dual-hosted git repository.

taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0e1a789119 Improve `DockerOperator` to support multiple Docker hosts (#38466)
0e1a789119 is described below

commit 0e1a78911944482780ee92037dd68e204818f8a9
Author: oboki <[email protected]>
AuthorDate: Fri Apr 5 05:06:19 2024 +0900

    Improve `DockerOperator` to support multiple Docker hosts (#38466)
    
    * Allowing `docker_url` in `DockerOperator` and `base_url` in `DockerHook` to accept a list of str.
    * Iterating over multiple host URLs to create and attempt connection.
---
 airflow/providers/docker/hooks/docker.py     | 52 ++++++++++++++++++----------
 airflow/providers/docker/operators/docker.py |  4 +--
 tests/providers/docker/hooks/test_docker.py  |  9 +++++
 3 files changed, 45 insertions(+), 20 deletions(-)

diff --git a/airflow/providers/docker/hooks/docker.py b/airflow/providers/docker/hooks/docker.py
index 9e91b723af..fa1377bedd 100644
--- a/airflow/providers/docker/hooks/docker.py
+++ b/airflow/providers/docker/hooks/docker.py
@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any
 
 from docker import APIClient, TLSConfig
 from docker.constants import DEFAULT_TIMEOUT_SECONDS
-from docker.errors import APIError
+from docker.errors import APIError, DockerException
 
 from airflow.exceptions import AirflowException, AirflowNotFoundException
 from airflow.hooks.base import BaseHook
@@ -45,7 +45,7 @@ class DockerHook(BaseHook):
 
     :param docker_conn_id: :ref:`Docker connection id <howto/connection:docker>` where stored credentials
          to Docker Registry. If set to ``None`` or empty then hook does not login to Container Registry.
-    :param base_url: URL to the Docker server.
+    :param base_url: URL or list of URLs to the Docker server.
     :param version: The version of the API to use. Use ``auto`` or ``None`` for automatically detect
         the server's version.
     :param tls: Is connection required TLS, for enable pass ``True`` for use with default options,
@@ -61,7 +61,7 @@ class DockerHook(BaseHook):
     def __init__(
         self,
         docker_conn_id: str | None = default_conn_name,
-        base_url: str | None = None,
+        base_url: str | list[str] | None = None,
         version: str | None = None,
         tls: TLSConfig | bool | None = None,
         timeout: int = DEFAULT_TIMEOUT_SECONDS,
@@ -69,12 +69,10 @@ class DockerHook(BaseHook):
         super().__init__()
         if not base_url:
             raise AirflowException("URL to the Docker server not provided.")
-        elif tls:
-            if base_url.startswith("tcp://"):
-                base_url = base_url.replace("tcp://", "https://")
-                self.log.debug("Change `base_url` schema from 'tcp://' to 'https://'.")
-            if not base_url.startswith("https://"):
-                self.log.warning("When `tls` specified then `base_url` expected 'https://' schema.")
+        if isinstance(base_url, str):
+            base_url = [base_url]
+        if tls:
+            base_url = list(map(self._redact_tls_schema, base_url))
 
         self.docker_conn_id = docker_conn_id
         self.__base_url = base_url
@@ -142,15 +140,25 @@ class DockerHook(BaseHook):
     @cached_property
     def api_client(self) -> APIClient:
         """Create connection to docker host and return ``docker.APIClient`` (cached)."""
-        client = APIClient(
-            base_url=self.__base_url, version=self.__version, tls=self.__tls, timeout=self.__timeout
-        )
-        if self.docker_conn_id:
-            # Obtain connection and try to login to Container Registry only if ``docker_conn_id`` set.
-            self.__login(client, self.get_connection(self.docker_conn_id))
-
-        self._client_created = True
-        return client
+        for url in self.__base_url:
+            try:
+                client = APIClient(
+                    base_url=url, version=self.__version, tls=self.__tls, timeout=self.__timeout
+                )
+                if not client.ping():
+                    msg = f"Failed to ping host {url}."
+                    raise AirflowException(msg)
+                if self.docker_conn_id:
+                    # Obtain connection and try to login to Container Registry only if ``docker_conn_id`` set.
+                    self.__login(client, self.get_connection(self.docker_conn_id))
+            except APIError:
+                raise
+            except DockerException as e:
+                self.log.error("Failed to establish connection to Docker host %s: %s", url, e)
+            else:
+                self._client_created = True
+                return client
+        raise AirflowException("Failed to establish connection to any given Docker hosts.")
 
     @property
     def client_created(self) -> bool:
@@ -224,3 +232,11 @@ class DockerHook(BaseHook):
                 )
             },
         }
+
+    def _redact_tls_schema(self, url: str) -> str:
+        if url.startswith("tcp://"):
+            url = url.replace("tcp://", "https://")
+            self.log.debug("Change `base_url` schema from 'tcp://' to 'https://'.")
+        if not url.startswith("https://"):
+            self.log.warning("When `tls` specified then `base_url` expected 'https://' schema.")
+        return url
diff --git a/airflow/providers/docker/operators/docker.py b/airflow/providers/docker/operators/docker.py
index 2714477791..f02a730e54 100644
--- a/airflow/providers/docker/operators/docker.py
+++ b/airflow/providers/docker/operators/docker.py
@@ -95,7 +95,7 @@ class DockerOperator(BaseOperator):
     :param cpus: Number of CPUs to assign to the container.
         This value gets multiplied with 1024. See
         https://docs.docker.com/engine/reference/run/#cpu-share-constraint
-    :param docker_url: URL of the host running the docker daemon.
+    :param docker_url: URL or list of URLs of the host(s) running the docker daemon.
         Default is the value of the ``DOCKER_HOST`` environment variable or unix://var/run/docker.sock
         if it is unset.
     :param environment: Environment variables to set in the container. (templated)
@@ -201,7 +201,7 @@ class DockerOperator(BaseOperator):
         command: str | list[str] | None = None,
         container_name: str | None = None,
         cpus: float = 1.0,
-        docker_url: str | None = None,
+        docker_url: str | list[str] | None = None,
         environment: dict | None = None,
         private_environment: dict | None = None,
         env_file: str | None = None,
diff --git a/tests/providers/docker/hooks/test_docker.py b/tests/providers/docker/hooks/test_docker.py
index 6f64e32d79..3e1774efa8 100644
--- a/tests/providers/docker/hooks/test_docker.py
+++ b/tests/providers/docker/hooks/test_docker.py
@@ -290,3 +290,12 @@ def test_construct_tls_config(assert_hostname, ssl_version):
             mock_tls_config.assert_called_once_with(
                 **expected_call_args, assert_hostname=assert_hostname, ssl_version=ssl_version
             )
             )
+
+
[email protected](
+    "base_url", [["tcp://foo.bar.spam.egg", "unix:///foo/bar/spam.egg", "unix:///var/run/docker.sock"]]
+)
+def test_connect_to_valid_host(base_url):
+    """Test connect to valid host from a given list of hosts."""
+    hook = DockerHook(base_url=base_url, docker_conn_id=None)
+    assert hook.api_client.base_url == "http+docker://localhost"

Reply via email to