This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 08f6e2e5d2 Fix misconfiguration of redis client with ssl (#36561)
08f6e2e5d2 is described below

commit 08f6e2e5d21382494597e6cac66725bc85729656
Author: shohamy7 <[email protected]>
AuthorDate: Thu Jan 4 01:02:38 2024 +0200

    Fix misconfiguration of redis client with ssl (#36561)
---
 airflow/providers/redis/hooks/redis.py    | 15 +++++++++-
 tests/providers/redis/hooks/test_redis.py | 47 +++++++++++++++++++++++++++++--
 2 files changed, 59 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/redis/hooks/redis.py 
b/airflow/providers/redis/hooks/redis.py
index c149071ac4..ef447c3fe1 100644
--- a/airflow/providers/redis/hooks/redis.py
+++ b/airflow/providers/redis/hooks/redis.py
@@ -18,8 +18,11 @@
 """RedisHook module."""
 from __future__ import annotations
 
+import warnings
+
 from redis import Redis
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.hooks.base import BaseHook
 
 
@@ -68,11 +71,21 @@ class RedisHook(BaseHook):
             "ssl_cert_reqs",
             "ssl_ca_certs",
             "ssl_keyfile",
-            "ssl_cert_file",
+            "ssl_certfile",
             "ssl_check_hostname",
         ]
         ssl_args = {name: val for name, val in conn.extra_dejson.items() if 
name in ssl_arg_names}
 
+        # This logic is for backward compatibility only
+        if "ssl_cert_file" in conn.extra_dejson and "ssl_certfile" not in 
conn.extra_dejson:
+            warnings.warn(
+                "Extra parameter `ssl_cert_file` deprecated and will be 
removed "
+                "in a future release. Please use `ssl_certfile` instead.",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+            ssl_args["ssl_certfile"] = conn.extra_dejson.get("ssl_cert_file")
+
         if not self.redis:
             self.log.debug(
                 'Initializing redis object for conn_id "%s" on %s:%s:%s',
diff --git a/tests/providers/redis/hooks/test_redis.py 
b/tests/providers/redis/hooks/test_redis.py
index daab3daddb..de352a9a20 100644
--- a/tests/providers/redis/hooks/test_redis.py
+++ b/tests/providers/redis/hooks/test_redis.py
@@ -21,6 +21,7 @@ from unittest import mock
 
 import pytest
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.models import Connection
 from airflow.providers.redis.hooks.redis import RedisHook
 
@@ -28,6 +29,11 @@ pytestmark = pytest.mark.db_test
 
 
 class TestRedisHook:
+    deprecation_message = (
+        "Extra parameter `ssl_cert_file` deprecated and will be removed "
+        "in a future release. Please use `ssl_certfile` instead."
+    )
+
     def test_get_conn(self):
         hook = RedisHook(redis_conn_id="redis_default")
         assert hook.redis is None
@@ -52,7 +58,7 @@ class TestRedisHook:
                         "ssl_cert_reqs": "required",
                         "ssl_ca_certs": "/path/to/custom/ca-cert",
                         "ssl_keyfile": "/path/to/key-file",
-                        "ssl_cert_file": "/path/to/cert-file",
+                        "ssl_certfile": "/path/to/cert-file",
                         "ssl_check_hostname": true
                     }""",
         ),
@@ -72,7 +78,44 @@ class TestRedisHook:
             ssl_cert_reqs=connection.extra_dejson["ssl_cert_reqs"],
             ssl_ca_certs=connection.extra_dejson["ssl_ca_certs"],
             ssl_keyfile=connection.extra_dejson["ssl_keyfile"],
-            ssl_cert_file=connection.extra_dejson["ssl_cert_file"],
+            ssl_certfile=connection.extra_dejson["ssl_certfile"],
+            ssl_check_hostname=connection.extra_dejson["ssl_check_hostname"],
+        )
+
+    @mock.patch("airflow.providers.redis.hooks.redis.Redis")
+    @mock.patch(
+        "airflow.providers.redis.hooks.redis.RedisHook.get_connection",
+        return_value=Connection(
+            password="password",
+            host="remote_host",
+            port=1234,
+            extra="""{
+                        "db": 2,
+                        "ssl": true,
+                        "ssl_cert_reqs": "required",
+                        "ssl_ca_certs": "/path/to/custom/ca-cert",
+                        "ssl_keyfile": "/path/to/key-file",
+                        "ssl_cert_file": "/path/to/cert-file",
+                        "ssl_check_hostname": true
+                    }""",
+        ),
+    )
+    def test_get_conn_with_deprecated_extra_config(self, mock_get_connection, 
mock_redis):
+        connection = mock_get_connection.return_value
+        hook = RedisHook()
+
+        with pytest.warns(AirflowProviderDeprecationWarning, 
match=self.deprecation_message):
+            hook.get_conn()
+        mock_redis.assert_called_once_with(
+            host=connection.host,
+            password=connection.password,
+            port=connection.port,
+            db=connection.extra_dejson["db"],
+            ssl=connection.extra_dejson["ssl"],
+            ssl_cert_reqs=connection.extra_dejson["ssl_cert_reqs"],
+            ssl_ca_certs=connection.extra_dejson["ssl_ca_certs"],
+            ssl_keyfile=connection.extra_dejson["ssl_keyfile"],
+            ssl_certfile=connection.extra_dejson["ssl_cert_file"],
             ssl_check_hostname=connection.extra_dejson["ssl_check_hostname"],
         )
 

Reply via email to