This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 29a59de237 Fix oversubscription of Redis pubsub sensor (#33139)
29a59de237 is described below

commit 29a59de237ccd42a3a5c20b10fc4c92b82ff4475
Author: Jarek Potiuk <[email protected]>
AuthorDate: Sat Aug 5 12:28:37 2023 +0200

    Fix oversubscription of Redis pubsub sensor (#33139)
    
    The fix in #32984 moved the Redis hook initialization into a
    cached property, but the `subscribe` call was moved into
    `poke` - which caused `subscribe` to be called on every poke
    in the sensor's regular "poke" mode (i.e. not "reschedule" mode).
    
    This change fixes it by subscribing only once - when the cached
    property is first initialized.
    
    Fixes: #33138
---
 airflow/providers/redis/sensors/redis_pub_sub.py    |  5 +++--
 tests/providers/redis/sensors/test_redis_pub_sub.py | 21 +++++++++++++++++++++
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/redis/sensors/redis_pub_sub.py b/airflow/providers/redis/sensors/redis_pub_sub.py
index 4501758a61..67885fd1c1 100644
--- a/airflow/providers/redis/sensors/redis_pub_sub.py
+++ b/airflow/providers/redis/sensors/redis_pub_sub.py
@@ -45,7 +45,9 @@ class RedisPubSubSensor(BaseSensorOperator):
 
     @cached_property
     def pubsub(self):
-        return RedisHook(redis_conn_id=self.redis_conn_id).get_conn().pubsub()
+        hook = RedisHook(redis_conn_id=self.redis_conn_id).get_conn().pubsub()
+        hook.subscribe(self.channels)
+        return hook
 
     def poke(self, context: Context) -> bool:
         """
@@ -57,7 +59,6 @@ class RedisPubSubSensor(BaseSensorOperator):
         :return: ``True`` if message (with type 'message') is available or 
``False`` if not
         """
         self.log.info("RedisPubSubSensor checking for message on channels: 
%s", self.channels)
-        self.pubsub.subscribe(self.channels)
         message = self.pubsub.get_message()
         self.log.info("Message %s from channel %s", message, self.channels)
 
diff --git a/tests/providers/redis/sensors/test_redis_pub_sub.py b/tests/providers/redis/sensors/test_redis_pub_sub.py
index dae08797ba..f773e3bd51 100644
--- a/tests/providers/redis/sensors/test_redis_pub_sub.py
+++ b/tests/providers/redis/sensors/test_redis_pub_sub.py
@@ -72,3 +72,24 @@ class TestRedisPubSubSensor:
 
         context_calls = []
         assert self.mock_context["ti"].method_calls == context_calls, "context 
calls should be same"
+
+    @patch("airflow.providers.redis.hooks.redis.RedisHook.get_conn")
+    def test_poke_subscribe_called_only_once(self, mock_redis_conn):
+        sensor = RedisPubSubSensor(
+            task_id="test_task", dag=self.dag, channels="test", 
redis_conn_id="redis_default"
+        )
+
+        mock_redis_conn().pubsub().get_message.return_value = {
+            "type": "subscribe",
+            "channel": b"test",
+            "data": b"d1",
+        }
+
+        result = sensor.poke(self.mock_context)
+        assert not result
+
+        context_calls = []
+        assert self.mock_context["ti"].method_calls == context_calls, "context 
calls should be same"
+        result = sensor.poke(self.mock_context)
+
+        assert mock_redis_conn().pubsub().subscribe.call_count == 1

Reply via email to