This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-3-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 4a8d13997ff95c7135e726290aa74ba798967a30 Author: Ephraim Anierobi <[email protected]> AuthorDate: Thu Aug 11 08:30:33 2022 +0100 Fix mapped sensor with reschedule mode (#25594) There are two issues with mapped sensor with `reschedule` mode. First, the reschedule table is being populated with a default map_index of -1 even when the map_index is not -1. Secondly, MappedOperator does not have the `ReadyToReschedule` dependency. This PR is an attempt to fix this Co-authored-by: Tzu-ping Chung <[email protected]> (cherry picked from commit 5f3733ea310b53a0a90c660dc94dd6e1ad5755b7) --- airflow/models/taskinstance.py | 1 + airflow/models/taskreschedule.py | 1 + airflow/serialization/serialized_objects.py | 2 +- airflow/ti_deps/deps/ready_to_reschedule.py | 11 +- tests/models/test_taskinstance.py | 217 +++++++++++++++++++++ tests/serialization/test_dag_serialization.py | 28 +++ tests/ti_deps/deps/test_ready_to_reschedule_dep.py | 81 +++++++- 7 files changed, 337 insertions(+), 4 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index c369e84e47..4b7f9ebf7b 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1865,6 +1865,7 @@ class TaskInstance(Base, LoggingMixin): actual_start_date, self.end_date, reschedule_exception.reschedule_date, + self.map_index, ) ) diff --git a/airflow/models/taskreschedule.py b/airflow/models/taskreschedule.py index 132554d8d1..13978fe186 100644 --- a/airflow/models/taskreschedule.py +++ b/airflow/models/taskreschedule.py @@ -112,6 +112,7 @@ class TaskReschedule(Base): TR.dag_id == task_instance.dag_id, TR.task_id == task_instance.task_id, TR.run_id == task_instance.run_id, + TR.map_index == task_instance.map_index, TR.try_number == try_number, ) if descending: diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index dd3dc4404e..5786f9231c 100644 --- a/airflow/serialization/serialized_objects.py +++ 
b/airflow/serialization/serialized_objects.py @@ -590,7 +590,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): @classmethod def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]: - serialized_op = cls._serialize_node(op, include_deps=op.deps is MappedOperator.deps_for(BaseOperator)) + serialized_op = cls._serialize_node(op, include_deps=op.deps != MappedOperator.deps_for(BaseOperator)) # Handle mapped_kwargs and mapped_op_kwargs. serialized_op[op._expansion_kwargs_attr] = cls._serialize(op._get_expansion_kwargs()) diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py b/airflow/ti_deps/deps/ready_to_reschedule.py index 9086822cea..88219bb401 100644 --- a/airflow/ti_deps/deps/ready_to_reschedule.py +++ b/airflow/ti_deps/deps/ready_to_reschedule.py @@ -40,7 +40,11 @@ class ReadyToRescheduleDep(BaseTIDep): considered as passed. This dependency fails if the latest reschedule request's reschedule date is still in future. """ - if not getattr(ti.task, "reschedule", False): + is_mapped = ti.task.is_mapped + if not is_mapped and not getattr(ti.task, "reschedule", False): + # Mapped sensors don't have the reschedule property (it can only + # be calculated after unmapping), so we don't check them here. + # They are handled below by checking TaskReschedule instead. 
yield self._passing_status(reason="Task is not in reschedule mode.") return @@ -62,6 +66,11 @@ class ReadyToRescheduleDep(BaseTIDep): .first() ) if not task_reschedule: + # Because mapped sensors don't have the reschedule property, here's the last resort + # and we need a slightly different passing reason + if is_mapped: + yield self._passing_status(reason="The task is mapped and not in reschedule mode") + return yield self._passing_status(reason="There is no reschedule request for this task instance.") return diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 1db5542904..74ef87489f 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -811,6 +811,174 @@ class TestTaskInstance: done, fail = True, False run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0) + def test_mapped_reschedule_handling(self, dag_maker): + """ + Test that mapped task reschedules are handled properly + """ + # Return values of the python sensor callable, modified during tests + done = False + fail = False + + def func(): + if fail: + raise AirflowException() + return done + + with dag_maker(dag_id='test_reschedule_handling') as dag: + + task = PythonSensor.partial( + task_id='test_reschedule_handling_sensor', + mode='reschedule', + python_callable=func, + retries=1, + retry_delay=datetime.timedelta(seconds=0), + ).expand(poke_interval=[0]) + + ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] + + ti.task = task + assert ti._try_number == 0 + assert ti.try_number == 1 + + def run_ti_and_assert( + run_date, + expected_start_date, + expected_end_date, + expected_duration, + expected_state, + expected_try_number, + expected_task_reschedule_count, + ): + ti.refresh_from_task(task) + with freeze_time(run_date): + try: + ti.run() + except AirflowException: + if not fail: + raise + ti.refresh_from_db() + assert ti.state == expected_state + assert ti._try_number == expected_try_number + 
assert ti.try_number == expected_try_number + 1 + assert ti.start_date == expected_start_date + assert ti.end_date == expected_end_date + assert ti.duration == expected_duration + trs = TaskReschedule.find_for_task_instance(ti) + assert len(trs) == expected_task_reschedule_count + + date1 = timezone.utcnow() + date2 = date1 + datetime.timedelta(minutes=1) + date3 = date2 + datetime.timedelta(minutes=1) + date4 = date3 + datetime.timedelta(minutes=1) + + # Run with multiple reschedules. + # During reschedule the try number remains the same, but each reschedule is recorded. + # The start date is expected to remain the initial date, hence the duration increases. + # When finished the try number is incremented and there is no reschedule expected + # for this try. + + done, fail = False, False + run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 0, 1) + + done, fail = False, False + run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RESCHEDULE, 0, 2) + + done, fail = False, False + run_ti_and_assert(date3, date1, date3, 120, State.UP_FOR_RESCHEDULE, 0, 3) + + done, fail = True, False + run_ti_and_assert(date4, date1, date4, 180, State.SUCCESS, 1, 0) + + # Clear the task instance. + dag.clear() + ti.refresh_from_db() + assert ti.state == State.NONE + assert ti._try_number == 1 + + # Run again after clearing with reschedules and a retry. + # The retry increments the try number, and for that try no reschedule is expected. + # After the retry the start date is reset, hence the duration is also reset. 
+ + done, fail = False, False + run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 1, 1) + + done, fail = False, True + run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 2, 0) + + done, fail = False, False + run_ti_and_assert(date3, date3, date3, 0, State.UP_FOR_RESCHEDULE, 2, 1) + + done, fail = True, False + run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0) + + @pytest.mark.usefixtures('test_pool') + def test_mapped_task_reschedule_handling_clear_reschedules(self, dag_maker): + """ + Test that mapped task reschedules clearing are handled properly + """ + # Return values of the python sensor callable, modified during tests + done = False + fail = False + + def func(): + if fail: + raise AirflowException() + return done + + with dag_maker(dag_id='test_reschedule_handling') as dag: + task = PythonSensor.partial( + task_id='test_reschedule_handling_sensor', + mode='reschedule', + python_callable=func, + retries=1, + retry_delay=datetime.timedelta(seconds=0), + pool='test_pool', + ).expand(poke_interval=[0]) + ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] + ti.task = task + assert ti._try_number == 0 + assert ti.try_number == 1 + + def run_ti_and_assert( + run_date, + expected_start_date, + expected_end_date, + expected_duration, + expected_state, + expected_try_number, + expected_task_reschedule_count, + ): + ti.refresh_from_task(task) + with freeze_time(run_date): + try: + ti.run() + except AirflowException: + if not fail: + raise + ti.refresh_from_db() + assert ti.state == expected_state + assert ti._try_number == expected_try_number + assert ti.try_number == expected_try_number + 1 + assert ti.start_date == expected_start_date + assert ti.end_date == expected_end_date + assert ti.duration == expected_duration + trs = TaskReschedule.find_for_task_instance(ti) + assert len(trs) == expected_task_reschedule_count + + date1 = timezone.utcnow() + + done, fail = False, False + 
run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 0, 1) + + # Clear the task instance. + dag.clear() + ti.refresh_from_db() + assert ti.state == State.NONE + assert ti._try_number == 0 + # Check that reschedules for ti have also been cleared. + trs = TaskReschedule.find_for_task_instance(ti) + assert not trs + @pytest.mark.usefixtures('test_pool') def test_reschedule_handling_clear_reschedules(self, dag_maker): """ @@ -2412,6 +2580,55 @@ def test_sensor_timeout(mode, retries, dag_maker): assert ti.state == State.FAILED +@pytest.mark.parametrize("mode", ["poke", "reschedule"]) +@pytest.mark.parametrize("retries", [0, 1]) +def test_mapped_sensor_timeout(mode, retries, dag_maker): + """ + Test that AirflowSensorTimeout does not cause mapped sensor to retry. + """ + + def timeout(): + raise AirflowSensorTimeout + + mock_on_failure = mock.MagicMock() + with dag_maker(dag_id=f'test_sensor_timeout_{mode}_{retries}'): + PythonSensor.partial( + task_id='test_raise_sensor_timeout', + python_callable=timeout, + on_failure_callback=mock_on_failure, + retries=retries, + ).expand(mode=[mode]) + ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] + + with pytest.raises(AirflowSensorTimeout): + ti.run() + + assert mock_on_failure.called + assert ti.state == State.FAILED + + +@pytest.mark.parametrize("mode", ["poke", "reschedule"]) +@pytest.mark.parametrize("retries", [0, 1]) +def test_mapped_sensor_works(mode, retries, dag_maker): + """ + Test that mapped sensors reaches success state. 
+ """ + + def timeout(ti): + return 1 + + with dag_maker(dag_id=f'test_sensor_timeout_{mode}_{retries}'): + PythonSensor.partial( + task_id='test_raise_sensor_timeout', + python_callable=timeout, + retries=retries, + ).expand(mode=[mode]) + ti = dag_maker.create_dagrun().task_instances[0] + + ti.run() + assert ti.state == State.SUCCESS + + class TestTaskInstanceRecordTaskMapXComPush: """Test TI.xcom_push() correctly records return values for task-mapping.""" diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 7d6a43e933..6f5b9b49eb 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -43,6 +43,7 @@ from airflow.models.param import Param, ParamsDict from airflow.models.xcom import XCOM_RETURN_KEY, XCom from airflow.operators.bash import BashOperator from airflow.security import permissions +from airflow.sensors.bash import BashSensor from airflow.serialization.json_schema import load_dag_schema_dict from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG from airflow.ti_deps.deps.base_ti_dep import BaseTIDep @@ -1228,6 +1229,7 @@ class TestStringifiedDAGs: task1 >> task2 serialize_op = SerializedBaseOperator.serialize_operator(dag.task_dict["task1"]) + deps = serialize_op["deps"] assert deps == [ 'airflow.ti_deps.deps.not_in_retry_period_dep.NotInRetryPeriodDep', @@ -1465,6 +1467,21 @@ class TestStringifiedDAGs: assert serialized_op.reschedule == (mode == "reschedule") assert op.deps == serialized_op.deps + @pytest.mark.parametrize("mode", ["poke", "reschedule"]) + def test_serialize_mapped_sensor_has_reschedule_dep(self, mode): + from airflow.sensors.base import BaseSensorOperator + + class DummySensor(BaseSensorOperator): + def poke(self, context: Context): + return False + + op = DummySensor.partial(task_id='dummy', mode=mode).expand(poke_interval=[23]) + + blob = 
SerializedBaseOperator.serialize_mapped_operator(op) + assert "deps" in blob + + assert 'airflow.ti_deps.deps.ready_to_reschedule.ReadyToRescheduleDep' in blob['deps'] + @pytest.mark.parametrize( "passed_success_callback, expected_value", [ @@ -1778,6 +1795,17 @@ def test_mapped_operator_deserialized_unmap(): assert deserialize(serialize(mapped)).unmap() == deserialize(serialize(normal)) +def test_sensor_expand_deserialized_unmap(): + """Unmap a deserialized mapped sensor should be similar to deserializing a non-mapped sensor""" + normal = BashSensor(task_id='a', bash_command=[1, 2], mode='reschedule') + mapped = BashSensor.partial(task_id='a', mode='reschedule').expand(bash_command=[1, 2]) + + serialize = SerializedBaseOperator._serialize + + deserialize = SerializedBaseOperator.deserialize_operator + assert deserialize(serialize(mapped)).unmap(None) == deserialize(serialize(normal)) + + def test_task_resources_serde(): """ Test task resources serialization/deserialization. diff --git a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py index 99416bbbc8..2ab8c539dc 100644 --- a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py +++ b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py @@ -31,12 +31,30 @@ from airflow.utils.timezone import utcnow class TestNotInReschedulePeriodDep(unittest.TestCase): def _get_task_instance(self, state): dag = DAG('test_dag') - task = Mock(dag=dag, reschedule=True) + task = Mock(dag=dag, reschedule=True, is_mapped=False) ti = TaskInstance(task=task, state=state, run_id=None) return ti def _get_task_reschedule(self, reschedule_date): - task = Mock(dag_id='test_dag', task_id='test_task') + task = Mock(dag_id='test_dag', task_id='test_task', is_mapped=False) + reschedule = TaskReschedule( + task=task, + run_id=None, + try_number=None, + start_date=reschedule_date, + end_date=reschedule_date, + reschedule_date=reschedule_date, + ) + return reschedule + + def 
_get_mapped_task_instance(self, state): + dag = DAG('test_dag') + task = Mock(dag=dag, reschedule=True, is_mapped=True) + ti = TaskInstance(task=task, state=state, run_id=None) + return ti + + def _get_mapped_task_reschedule(self, reschedule_date): + task = Mock(dag_id='test_dag', task_id='test_task', is_mapped=True) reschedule = TaskReschedule( task=task, run_id=None, @@ -103,3 +121,62 @@ class TestNotInReschedulePeriodDep(unittest.TestCase): ][-1] ti = self._get_task_instance(State.UP_FOR_RESCHEDULE) assert not ReadyToRescheduleDep().is_met(ti=ti) + + def test_mapped_task_should_pass_if_ignore_in_reschedule_period_is_set(self): + ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE) + dep_context = DepContext(ignore_in_reschedule_period=True) + assert ReadyToRescheduleDep().is_met(ti=ti, dep_context=dep_context) + + @patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance') + def test_mapped_task_should_pass_if_not_reschedule_mode(self, mock_query_for_task_instance): + mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = [] + ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE) + del ti.task.reschedule + assert ReadyToRescheduleDep().is_met(ti=ti) + + def test_mapped_task_should_pass_if_not_in_none_state(self): + ti = self._get_mapped_task_instance(State.UP_FOR_RETRY) + assert ReadyToRescheduleDep().is_met(ti=ti) + + @patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance') + def test_mapped_should_pass_if_no_reschedule_record_exists(self, mock_query_for_task_instance): + mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = [] + ti = self._get_mapped_task_instance(State.NONE) + assert ReadyToRescheduleDep().is_met(ti=ti) + + @patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance') + def test_mapped_should_pass_after_reschedule_date_one(self, mock_query_for_task_instance): + 
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = ( + self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=1)) + ) + ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE) + assert ReadyToRescheduleDep().is_met(ti=ti) + + @patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance') + def test_mapped_task_should_pass_after_reschedule_date_multiple(self, mock_query_for_task_instance): + mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = [ + self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=21)), + self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=11)), + self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=1)), + ][-1] + ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE) + assert ReadyToRescheduleDep().is_met(ti=ti) + + @patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance') + def test_mapped_task_should_fail_before_reschedule_date_one(self, mock_query_for_task_instance): + mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = ( + self._get_mapped_task_reschedule(utcnow() + timedelta(minutes=1)) + ) + + ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE) + assert not ReadyToRescheduleDep().is_met(ti=ti) + + @patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance') + def test_mapped_task_should_fail_before_reschedule_date_multiple(self, mock_query_for_task_instance): + mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = [ + self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=19)), + self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=9)), + self._get_mapped_task_reschedule(utcnow() + timedelta(minutes=1)), + ][-1] + ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE) + assert not ReadyToRescheduleDep().is_met(ti=ti)
