This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch v2-9-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 3487fccfe27d959f6ed1b407cfb613635f88a839
Author: Tamara Janina Fingerlin <[email protected]>
AuthorDate: Mon Apr 15 13:05:31 2024 +0200

    Bugfix: Move rendering of `map_index_template` so it renders for failed tasks as long as it was defined before the point of failure (#38902)
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
    (cherry picked from commit 456ec48d12be02ca7266f021a16e01abb5d4c5a3)
---
 airflow/models/taskinstance.py      | 25 ++++++++++++++++++-------
 tests/models/test_mappedoperator.py | 29 +++++++++++++++++++++++++++++
 2 files changed, 47 insertions(+), 7 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 42849f88f1..88a495e9eb 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2658,19 +2658,30 @@ class TaskInstance(Base, LoggingMixin):
                 previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session
             )
 
-            # Execute the task
+            def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None:
+                """Render named map index if the DAG author defined map_index_template at the task level."""
+                if jinja_env is None or (template := context.get("map_index_template")) is None:
+                    return None
+                rendered_map_index = jinja_env.from_string(template).render(context)
+                log.debug("Map index rendered as %s", rendered_map_index)
+                return rendered_map_index
+
+            # Execute the task.
             with set_current_context(context):
-                result = self._execute_task(context, task_orig)
+                try:
+                    result = self._execute_task(context, task_orig)
+                except Exception:
+                    # If the task failed, swallow rendering error so it doesn't mask the main error.
+                    with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
+                        self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env)
+                    raise
+                else:  # If the task succeeded, render normally to let rendering error bubble up.
+                    self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env)
 
             # Run post_execute callback
             # Is never MappedOperator at this point
             self.task.post_execute(context=context, result=result)  # type: ignore[union-attr]
 
-            # DAG authors define map_index_template at the task level
-            if jinja_env is not None and (template := context.get("map_index_template")) is not None:
-                rendered_map_index = self.rendered_map_index = jinja_env.from_string(template).render(context)
-                self.log.info("Map index rendered as %s", rendered_map_index)
-
         Stats.incr(f"operator_successes_{self.task.task_type}", tags=self.stats_tags)
         # Same metric with tagging
         Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type})
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 64304cf306..0351fab05d 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -635,6 +635,33 @@ def _create_mapped_with_name_template_taskflow(*, task_id, map_names, template):
     return task1.expand(map_name=map_names)
 
 
+def _create_named_map_index_renders_on_failure_classic(*, task_id, map_names, template):
+    class HasMapName(BaseOperator):
+        def __init__(self, *, map_name: str, **kwargs):
+            super().__init__(**kwargs)
+            self.map_name = map_name
+
+        def execute(self, context):
+            context["map_name"] = self.map_name
+            raise AirflowSkipException("Imagine this task failed!")
+
+    return HasMapName.partial(task_id=task_id, map_index_template=template).expand(
+        map_name=map_names,
+    )
+
+
+def _create_named_map_index_renders_on_failure_taskflow(*, task_id, map_names, template):
+    from airflow.operators.python import get_current_context
+
+    @task(task_id=task_id, map_index_template=template)
+    def task1(map_name):
+        context = get_current_context()
+        context["map_name"] = map_name
+        raise AirflowSkipException("Imagine this task failed!")
+
+    return task1.expand(map_name=map_names)
+
+
 @pytest.mark.parametrize(
     "template, expected_rendered_names",
     [
@@ -649,6 +676,8 @@ def _create_mapped_with_name_template_taskflow(*, task_id, map_names, template):
     [
         pytest.param(_create_mapped_with_name_template_classic, id="classic"),
         pytest.param(_create_mapped_with_name_template_taskflow, id="taskflow"),
+        pytest.param(_create_named_map_index_renders_on_failure_classic, id="classic-failure"),
+        pytest.param(_create_named_map_index_renders_on_failure_taskflow, id="taskflow-failure"),
     ],
 )
 def test_expand_mapped_task_instance_with_named_index(

Reply via email to