This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 08f6e2e5d2 Fix misconfiguration of redis client with ssl (#36561)
08f6e2e5d2 is described below
commit 08f6e2e5d21382494597e6cac66725bc85729656
Author: shohamy7 <[email protected]>
AuthorDate: Thu Jan 4 01:02:38 2024 +0200
Fix misconfiguration of redis client with ssl (#36561)
---
airflow/providers/redis/hooks/redis.py | 15 +++++++++-
tests/providers/redis/hooks/test_redis.py | 47 +++++++++++++++++++++++++++++--
2 files changed, 59 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/redis/hooks/redis.py b/airflow/providers/redis/hooks/redis.py
index c149071ac4..ef447c3fe1 100644
--- a/airflow/providers/redis/hooks/redis.py
+++ b/airflow/providers/redis/hooks/redis.py
@@ -18,8 +18,11 @@
"""RedisHook module."""
from __future__ import annotations
+import warnings
+
from redis import Redis
+from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
@@ -68,11 +71,21 @@ class RedisHook(BaseHook):
"ssl_cert_reqs",
"ssl_ca_certs",
"ssl_keyfile",
- "ssl_cert_file",
+ "ssl_certfile",
"ssl_check_hostname",
]
ssl_args = {name: val for name, val in conn.extra_dejson.items() if
name in ssl_arg_names}
+ # This logic is for backward compatibility only
+ if "ssl_cert_file" in conn.extra_dejson and "ssl_certfile" not in
conn.extra_dejson:
+ warnings.warn(
+ "Extra parameter `ssl_cert_file` deprecated and will be
removed "
+ "in a future release. Please use `ssl_certfile` instead.",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ ssl_args["ssl_certfile"] = conn.extra_dejson.get("ssl_cert_file")
+
if not self.redis:
self.log.debug(
'Initializing redis object for conn_id "%s" on %s:%s:%s',
diff --git a/tests/providers/redis/hooks/test_redis.py b/tests/providers/redis/hooks/test_redis.py
index daab3daddb..de352a9a20 100644
--- a/tests/providers/redis/hooks/test_redis.py
+++ b/tests/providers/redis/hooks/test_redis.py
@@ -21,6 +21,7 @@ from unittest import mock
import pytest
+from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import Connection
from airflow.providers.redis.hooks.redis import RedisHook
@@ -28,6 +29,11 @@ pytestmark = pytest.mark.db_test
class TestRedisHook:
+ deprecation_message = (
+ "Extra parameter `ssl_cert_file` deprecated and will be removed "
+ "in a future release. Please use `ssl_certfile` instead."
+ )
+
def test_get_conn(self):
hook = RedisHook(redis_conn_id="redis_default")
assert hook.redis is None
@@ -52,7 +58,7 @@ class TestRedisHook:
"ssl_cert_reqs": "required",
"ssl_ca_certs": "/path/to/custom/ca-cert",
"ssl_keyfile": "/path/to/key-file",
- "ssl_cert_file": "/path/to/cert-file",
+ "ssl_certfile": "/path/to/cert-file",
"ssl_check_hostname": true
}""",
),
@@ -72,7 +78,44 @@ class TestRedisHook:
ssl_cert_reqs=connection.extra_dejson["ssl_cert_reqs"],
ssl_ca_certs=connection.extra_dejson["ssl_ca_certs"],
ssl_keyfile=connection.extra_dejson["ssl_keyfile"],
- ssl_cert_file=connection.extra_dejson["ssl_cert_file"],
+ ssl_certfile=connection.extra_dejson["ssl_certfile"],
+ ssl_check_hostname=connection.extra_dejson["ssl_check_hostname"],
+ )
+
+ @mock.patch("airflow.providers.redis.hooks.redis.Redis")
+ @mock.patch(
+ "airflow.providers.redis.hooks.redis.RedisHook.get_connection",
+ return_value=Connection(
+ password="password",
+ host="remote_host",
+ port=1234,
+ extra="""{
+ "db": 2,
+ "ssl": true,
+ "ssl_cert_reqs": "required",
+ "ssl_ca_certs": "/path/to/custom/ca-cert",
+ "ssl_keyfile": "/path/to/key-file",
+ "ssl_cert_file": "/path/to/cert-file",
+ "ssl_check_hostname": true
+ }""",
+ ),
+ )
+ def test_get_conn_with_deprecated_extra_config(self, mock_get_connection,
mock_redis):
+ connection = mock_get_connection.return_value
+ hook = RedisHook()
+
+ with pytest.warns(AirflowProviderDeprecationWarning,
match=self.deprecation_message):
+ hook.get_conn()
+ mock_redis.assert_called_once_with(
+ host=connection.host,
+ password=connection.password,
+ port=connection.port,
+ db=connection.extra_dejson["db"],
+ ssl=connection.extra_dejson["ssl"],
+ ssl_cert_reqs=connection.extra_dejson["ssl_cert_reqs"],
+ ssl_ca_certs=connection.extra_dejson["ssl_ca_certs"],
+ ssl_keyfile=connection.extra_dejson["ssl_keyfile"],
+ ssl_certfile=connection.extra_dejson["ssl_cert_file"],
ssl_check_hostname=connection.extra_dejson["ssl_check_hostname"],
)