This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 456ec48d12 Bugfix: Move rendering of `map_index_template` so it
renders for failed tasks as long as it was defined before the point of failure
(#38902)
456ec48d12 is described below
commit 456ec48d12be02ca7266f021a16e01abb5d4c5a3
Author: Tamara Janina Fingerlin <[email protected]>
AuthorDate: Mon Apr 15 13:05:31 2024 +0200
Bugfix: Move rendering of `map_index_template` so it renders for failed
tasks as long as it was defined before the point of failure (#38902)
Co-authored-by: Tzu-ping Chung <[email protected]>
---
airflow/models/taskinstance.py | 25 ++++++++++++++++++-------
tests/models/test_mappedoperator.py | 29 +++++++++++++++++++++++++++++
2 files changed, 47 insertions(+), 7 deletions(-)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 2b8b935d78..8f9d71cfe7 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2731,18 +2731,29 @@ class TaskInstance(Base, LoggingMixin):
previous_state=TaskInstanceState.QUEUED, task_instance=self,
session=session
)
- # Execute the task
+ def _render_map_index(context: Context, *, jinja_env:
jinja2.Environment | None) -> str | None:
+ """Render named map index if the DAG author defined
map_index_template at the task level."""
+ if jinja_env is None or (template :=
context.get("map_index_template")) is None:
+ return None
+ rendered_map_index =
jinja_env.from_string(template).render(context)
+ log.debug("Map index rendered as %s", rendered_map_index)
+ return rendered_map_index
+
+ # Execute the task.
with set_current_context(context):
- result = self._execute_task(context, task_orig)
+ try:
+ result = self._execute_task(context, task_orig)
+ except Exception:
+ # If the task failed, swallow rendering error so it
doesn't mask the main error.
+ with contextlib.suppress(jinja2.TemplateSyntaxError,
jinja2.UndefinedError):
+ self.rendered_map_index = _render_map_index(context,
jinja_env=jinja_env)
+ raise
+ else: # If the task succeeded, render normally to let
rendering error bubble up.
+ self.rendered_map_index = _render_map_index(context,
jinja_env=jinja_env)
# Run post_execute callback
self.task.post_execute(context=context, result=result)
- # DAG authors define map_index_template at the task level
- if jinja_env is not None and (template :=
context.get("map_index_template")) is not None:
- rendered_map_index = self.rendered_map_index =
jinja_env.from_string(template).render(context)
- self.log.info("Map index rendered as %s", rendered_map_index)
-
Stats.incr(f"operator_successes_{self.task.task_type}",
tags=self.stats_tags)
# Same metric with tagging
Stats.incr("operator_successes", tags={**self.stats_tags, "task_type":
self.task.task_type})
diff --git a/tests/models/test_mappedoperator.py
b/tests/models/test_mappedoperator.py
index e80a629794..9f31652424 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -631,6 +631,33 @@ def _create_mapped_with_name_template_taskflow(*, task_id,
map_names, template):
return task1.expand(map_name=map_names)
+def _create_named_map_index_renders_on_failure_classic(*, task_id, map_names,
template):
+ class HasMapName(BaseOperator):
+ def __init__(self, *, map_name: str, **kwargs):
+ super().__init__(**kwargs)
+ self.map_name = map_name
+
+ def execute(self, context):
+ context["map_name"] = self.map_name
+ raise AirflowSkipException("Imagine this task failed!")
+
+ return HasMapName.partial(task_id=task_id,
map_index_template=template).expand(
+ map_name=map_names,
+ )
+
+
+def _create_named_map_index_renders_on_failure_taskflow(*, task_id, map_names,
template):
+ from airflow.operators.python import get_current_context
+
+ @task(task_id=task_id, map_index_template=template)
+ def task1(map_name):
+ context = get_current_context()
+ context["map_name"] = map_name
+ raise AirflowSkipException("Imagine this task failed!")
+
+ return task1.expand(map_name=map_names)
+
+
@pytest.mark.parametrize(
"template, expected_rendered_names",
[
@@ -645,6 +672,8 @@ def _create_mapped_with_name_template_taskflow(*, task_id,
map_names, template):
[
pytest.param(_create_mapped_with_name_template_classic, id="classic"),
pytest.param(_create_mapped_with_name_template_taskflow,
id="taskflow"),
+ pytest.param(_create_named_map_index_renders_on_failure_classic,
id="classic-failure"),
+ pytest.param(_create_named_map_index_renders_on_failure_taskflow,
id="taskflow-failure"),
],
)
def test_expand_mapped_task_instance_with_named_index(