This is an automated email from the ASF dual-hosted git repository. jedcunningham pushed a commit to branch v2-9-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 3487fccfe27d959f6ed1b407cfb613635f88a839 Author: Tamara Janina Fingerlin <[email protected]> AuthorDate: Mon Apr 15 13:05:31 2024 +0200 Bugfix: Move rendering of `map_index_template` so it renders for failed tasks as long as it was defined before the point of failure (#38902) Co-authored-by: Tzu-ping Chung <[email protected]> (cherry picked from commit 456ec48d12be02ca7266f021a16e01abb5d4c5a3) --- airflow/models/taskinstance.py | 25 ++++++++++++++++++------- tests/models/test_mappedoperator.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 42849f88f1..88a495e9eb 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2658,19 +2658,30 @@ class TaskInstance(Base, LoggingMixin): previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session ) - # Execute the task + def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None: + """Render named map index if the DAG author defined map_index_template at the task level.""" + if jinja_env is None or (template := context.get("map_index_template")) is None: + return None + rendered_map_index = jinja_env.from_string(template).render(context) + log.debug("Map index rendered as %s", rendered_map_index) + return rendered_map_index + + # Execute the task. with set_current_context(context): - result = self._execute_task(context, task_orig) + try: + result = self._execute_task(context, task_orig) + except Exception: + # If the task failed, swallow rendering error so it doesn't mask the main error. + with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError): + self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env) + raise + else: # If the task succeeded, render normally to let rendering error bubble up. + self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env) # Run post_execute callback # Is never MappedOperator at this point self.task.post_execute(context=context, result=result) # type: ignore[union-attr] - # DAG authors define map_index_template at the task level - if jinja_env is not None and (template := context.get("map_index_template")) is not None: - rendered_map_index = self.rendered_map_index = jinja_env.from_string(template).render(context) - self.log.info("Map index rendered as %s", rendered_map_index) - Stats.incr(f"operator_successes_{self.task.task_type}", tags=self.stats_tags) # Same metric with tagging Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type}) diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 64304cf306..0351fab05d 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -635,6 +635,33 @@ def _create_mapped_with_name_template_taskflow(*, task_id, map_names, template): return task1.expand(map_name=map_names) +def _create_named_map_index_renders_on_failure_classic(*, task_id, map_names, template): + class HasMapName(BaseOperator): + def __init__(self, *, map_name: str, **kwargs): + super().__init__(**kwargs) + self.map_name = map_name + + def execute(self, context): + context["map_name"] = self.map_name + raise AirflowSkipException("Imagine this task failed!") + + return HasMapName.partial(task_id=task_id, map_index_template=template).expand( + map_name=map_names, + ) + + +def _create_named_map_index_renders_on_failure_taskflow(*, task_id, map_names, template): + from airflow.operators.python import get_current_context + + @task(task_id=task_id, map_index_template=template) + def task1(map_name): + context = get_current_context() + context["map_name"] = map_name + raise AirflowSkipException("Imagine this task failed!") + + return task1.expand(map_name=map_names) + + @pytest.mark.parametrize( "template, expected_rendered_names", [ @@ -649,6 +676,8 @@ def _create_mapped_with_name_template_taskflow(*, task_id, map_names, template): [ pytest.param(_create_mapped_with_name_template_classic, id="classic"), pytest.param(_create_mapped_with_name_template_taskflow, id="taskflow"), + pytest.param(_create_named_map_index_renders_on_failure_classic, id="classic-failure"), + pytest.param(_create_named_map_index_renders_on_failure_taskflow, id="taskflow-failure"), ], ) def test_expand_mapped_task_instance_with_named_index(
