This is an automated email from the ASF dual-hosted git repository. utkarsharma pushed a commit to branch v2-9-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 67cb21ac1801d0cf4bbab64085f3a0136190c006 Author: Sanket <[email protected]> AuthorDate: Sun Jun 16 23:45:58 2024 +0530 Exponential Backoff Not Functioning in BaseSensorOperator Reschedule Mode (#39823) * Fix: Increment try_number/poke_count in BaseSensorOperator for correct exponential backoff in reschedule mode. * exponential backoff handling for reschedule mode --------- Co-authored-by: sanket2000 <[email protected]> (cherry picked from commit 841b28cccbb32b682333dc9d27b1b5f04fc495ab) --- airflow/sensors/base.py | 31 ++++++++++++++++++++ newsfragments/39823.bugfix.rst | 1 + tests/sensors/test_base.py | 65 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+) diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py index a980484672..0f8ce5c9f8 100644 --- a/airflow/sensors/base.py +++ b/airflow/sensors/base.py @@ -312,6 +312,37 @@ class BaseSensorOperator(BaseOperator, SkipMixin): if not self.exponential_backoff: return self.poke_interval + if self.reschedule: + # Calculate elapsed time since the sensor started + elapsed_time = run_duration() + + # Initialize variables for the simulation + cumulative_time: float = 0.0 + estimated_poke_count: int = 0 + + while cumulative_time <= elapsed_time: + estimated_poke_count += 1 + # Calculate min_backoff for the current try number + min_backoff = max(int(self.poke_interval * (2 ** (estimated_poke_count - 2))), 1) + + # Calculate the jitter + run_hash = int( + hashlib.sha1( + f"{self.dag_id}#{self.task_id}#{started_at}#{estimated_poke_count}".encode() + ).hexdigest(), + 16, + ) + modded_hash = min_backoff + run_hash % min_backoff + + # Calculate the jitter, which is used to prevent multiple sensors simultaneously poking + interval_with_jitter = min(modded_hash, timedelta.max.total_seconds() - 1) + + # Add the interval to the cumulative time + cumulative_time += interval_with_jitter + + # Now we have an estimated_poke_count based on the elapsed time + poke_count = estimated_poke_count or poke_count + # The value of min_backoff should always be greater than or equal to 1. min_backoff = max(int(self.poke_interval * (2 ** (try_number - 2))), 1) diff --git a/newsfragments/39823.bugfix.rst b/newsfragments/39823.bugfix.rst new file mode 100644 index 0000000000..7a774258a4 --- /dev/null +++ b/newsfragments/39823.bugfix.rst @@ -0,0 +1 @@ +Fixed ``BaseSensorOperator`` with exponential backoff and reschedule mode by estimating try number based on ``run_duration``; previously, sensors had a fixed reschedule interval. diff --git a/tests/sensors/test_base.py b/tests/sensors/test_base.py index e3afcd3564..c4406ad643 100644 --- a/tests/sensors/test_base.py +++ b/tests/sensors/test_base.py @@ -414,6 +414,71 @@ class TestBaseSensor: if ti.task_id == DUMMY_OP: assert ti.state == State.NONE + def test_ok_with_reschedule_and_exponential_backoff( + self, make_sensor, time_machine, task_reschedules_for_ti, session + ): + sensor, dr = make_sensor( + return_value=None, + poke_interval=10, + timeout=36000, + mode="reschedule", + exponential_backoff=True, + ) + + def _get_tis(): + tis = dr.get_task_instances(session=session) + assert len(tis) == 2 + yield next(x for x in tis if x.task_id == SENSOR_OP) + yield next(x for x in tis if x.task_id == DUMMY_OP) + + false_count = 10 + sensor.poke = Mock(side_effect=[False] * false_count + [True]) + + task_start_date = timezone.utcnow() + + time_machine.move_to(task_start_date, tick=False) + curr_date = task_start_date + + def run_duration(): + return (timezone.utcnow() - task_start_date).total_seconds() + + new_interval = 0 + + sensor_ti, dummy_ti = _get_tis() + assert dummy_ti.state == State.NONE + assert sensor_ti.state == State.NONE + + # ordinarily the scheduler does this + sensor_ti.state = State.SCHEDULED + sensor_ti.try_number += 1 # first TI run + session.commit() + + # loop poke returns false + for _poke_count in range(1, false_count + 1): + curr_date = curr_date + timedelta(seconds=new_interval) + time_machine.coordinates.shift(new_interval) + self._run(sensor) + sensor_ti, dummy_ti = _get_tis() + assert sensor_ti.state == State.UP_FOR_RESCHEDULE + # verify another row in task_reschedule table + task_reschedules = task_reschedules_for_ti(sensor_ti) + assert len(task_reschedules) == _poke_count + old_interval = new_interval + new_interval = sensor._get_next_poke_interval(task_start_date, run_duration, _poke_count) + assert old_interval < new_interval # actual test + assert task_reschedules[-1].start_date == curr_date + assert task_reschedules[-1].reschedule_date == curr_date + timedelta(seconds=new_interval) + assert dummy_ti.state == State.NONE + + # last poke returns True and task succeeds + curr_date = curr_date + timedelta(seconds=new_interval) + time_machine.coordinates.shift(new_interval) + self._run(sensor) + + sensor_ti, dummy_ti = _get_tis() + assert sensor_ti.state == State.SUCCESS + assert dummy_ti.state == State.NONE + @pytest.mark.parametrize("mode", ["poke", "reschedule"]) def test_should_include_ready_to_reschedule_dep(self, mode): sensor = DummySensor(task_id="a", return_value=True, mode=mode)
