This is an automated email from the ASF dual-hosted git repository.

utkarsharma pushed a commit to branch v2-9-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 67cb21ac1801d0cf4bbab64085f3a0136190c006
Author: Sanket <[email protected]>
AuthorDate: Sun Jun 16 23:45:58 2024 +0530

    Exponential Backoff Not Functioning in BaseSensorOperator Reschedule Mode (#39823)
    
    * Fix: Increment try_number/poke_count in BaseSensorOperator for correct exponential backoff in reschedule mode.
    
    * exponential backoff handling for reschedule mode
    
    ---------
    
    Co-authored-by: sanket2000 <[email protected]>
    (cherry picked from commit 841b28cccbb32b682333dc9d27b1b5f04fc495ab)
---
 airflow/sensors/base.py        | 31 ++++++++++++++++++++
 newsfragments/39823.bugfix.rst |  1 +
 tests/sensors/test_base.py     | 65 ++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 97 insertions(+)

diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py
index a980484672..0f8ce5c9f8 100644
--- a/airflow/sensors/base.py
+++ b/airflow/sensors/base.py
@@ -312,6 +312,37 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
         if not self.exponential_backoff:
             return self.poke_interval
 
+        if self.reschedule:
+            # Calculate elapsed time since the sensor started
+            elapsed_time = run_duration()
+
+            # Initialize variables for the simulation
+            cumulative_time: float = 0.0
+            estimated_poke_count: int = 0
+
+            while cumulative_time <= elapsed_time:
+                estimated_poke_count += 1
+                # Calculate min_backoff for the current try number
+                min_backoff = max(int(self.poke_interval * (2 ** 
(estimated_poke_count - 2))), 1)
+
+                # Calculate the jitter
+                run_hash = int(
+                    hashlib.sha1(
+                        
f"{self.dag_id}#{self.task_id}#{started_at}#{estimated_poke_count}".encode()
+                    ).hexdigest(),
+                    16,
+                )
+                modded_hash = min_backoff + run_hash % min_backoff
+
+                # Calculate the jitter, which is used to prevent multiple 
sensors simultaneously poking
+                interval_with_jitter = min(modded_hash, 
timedelta.max.total_seconds() - 1)
+
+                # Add the interval to the cumulative time
+                cumulative_time += interval_with_jitter
+
+            # Now we have an estimated_poke_count based on the elapsed time
+            poke_count = estimated_poke_count or poke_count
+
         # The value of min_backoff should always be greater than or equal to 1.
         min_backoff = max(int(self.poke_interval * (2 ** (try_number - 2))), 1)
 
diff --git a/newsfragments/39823.bugfix.rst b/newsfragments/39823.bugfix.rst
new file mode 100644
index 0000000000..7a774258a4
--- /dev/null
+++ b/newsfragments/39823.bugfix.rst
@@ -0,0 +1 @@
+Fixed exponential backoff for ``BaseSensorOperator`` in reschedule mode by estimating the poke count (try number) from ``run_duration``; previously, such sensors were always rescheduled with a fixed interval.
diff --git a/tests/sensors/test_base.py b/tests/sensors/test_base.py
index e3afcd3564..c4406ad643 100644
--- a/tests/sensors/test_base.py
+++ b/tests/sensors/test_base.py
@@ -414,6 +414,71 @@ class TestBaseSensor:
             if ti.task_id == DUMMY_OP:
                 assert ti.state == State.NONE
 
+    def test_ok_with_reschedule_and_exponential_backoff(
+        self, make_sensor, time_machine, task_reschedules_for_ti, session
+    ):
+        sensor, dr = make_sensor(
+            return_value=None,
+            poke_interval=10,
+            timeout=36000,
+            mode="reschedule",
+            exponential_backoff=True,
+        )
+
+        def _get_tis():
+            tis = dr.get_task_instances(session=session)
+            assert len(tis) == 2
+            yield next(x for x in tis if x.task_id == SENSOR_OP)
+            yield next(x for x in tis if x.task_id == DUMMY_OP)
+
+        false_count = 10
+        sensor.poke = Mock(side_effect=[False] * false_count + [True])
+
+        task_start_date = timezone.utcnow()
+
+        time_machine.move_to(task_start_date, tick=False)
+        curr_date = task_start_date
+
+        def run_duration():
+            return (timezone.utcnow() - task_start_date).total_seconds()
+
+        new_interval = 0
+
+        sensor_ti, dummy_ti = _get_tis()
+        assert dummy_ti.state == State.NONE
+        assert sensor_ti.state == State.NONE
+
+        # ordinarily the scheduler does this
+        sensor_ti.state = State.SCHEDULED
+        sensor_ti.try_number += 1  # first TI run
+        session.commit()
+
+        # the first `false_count` pokes return False, so every run must reschedule
+        for _poke_count in range(1, false_count + 1):
+            curr_date = curr_date + timedelta(seconds=new_interval)
+            time_machine.coordinates.shift(new_interval)
+            self._run(sensor)
+            sensor_ti, dummy_ti = _get_tis()
+            assert sensor_ti.state == State.UP_FOR_RESCHEDULE
+            # verify another row in task_reschedule table
+            task_reschedules = task_reschedules_for_ti(sensor_ti)
+            assert len(task_reschedules) == _poke_count
+            old_interval = new_interval
+            new_interval = sensor._get_next_poke_interval(task_start_date, run_duration, _poke_count)
+            assert old_interval < new_interval  # backoff must grow on every reschedule
+            assert task_reschedules[-1].start_date == curr_date
+            assert task_reschedules[-1].reschedule_date == curr_date + timedelta(seconds=new_interval)
+            assert dummy_ti.state == State.NONE
+
+        # last poke returns True and task succeeds
+        curr_date = curr_date + timedelta(seconds=new_interval)
+        time_machine.coordinates.shift(new_interval)
+        self._run(sensor)
+
+        sensor_ti, dummy_ti = _get_tis()
+        assert sensor_ti.state == State.SUCCESS
+        assert dummy_ti.state == State.NONE
+
     @pytest.mark.parametrize("mode", ["poke", "reschedule"])
     def test_should_include_ready_to_reschedule_dep(self, mode):
         sensor = DummySensor(task_id="a", return_value=True, mode=mode)

Reply via email to