Dev-iL commented on code in PR #61975:
URL: https://github.com/apache/airflow/pull/61975#discussion_r2810964569
##########
airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -921,6 +926,98 @@ def _iter_breadcrumbs() -> Iterator[dict[str, Any]]:
return TaskBreadcrumbsResponse(breadcrumbs=_iter_breadcrumbs())
+def _populate_task_group_map_index_context(
+    context: TIRunContext,
+    dag_id: str,
+    task_id: str,
+    map_index: int,
+    run_id: str,
+    session: SessionDep,
+    dag_bag: DagBagDep,
+) -> None:
+    """
+    Populate task group map_index_template and expanded args on the TIRunContext.
+
+    Looks up ``task_id`` in the DAG returned by ``get_latest_version_of_dag``. For the first
+    group yielded by ``task.iter_mapped_task_groups()`` that has a ``map_index_template``,
+    sets ``context.task_group_map_index_template`` and ``context.task_group_expanded_args``.
+    Leaves ``context`` unchanged when the DAG lookup raises ``HTTPException``, the task is
+    not in ``dag.task_dict``, or no mapped task group defines a template.
+    """
+    try:
+        # NOTE(review): this resolves the *latest* DAG version; if the task instance runs an
+        # older version, its task group template may differ -- confirm this is intended.
+        dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+    except HTTPException:
+        # A failed DAG lookup is swallowed: the context is simply left without task group data.
+        return
+
+    task = dag.task_dict.get(task_id)
+    if not task:
+        return
+
+    for mtg in task.iter_mapped_task_groups():
+        if not mtg.map_index_template:
+            continue
+
+        # Only the first templated mapped task group is used (see ``break`` below).
+        context.task_group_map_index_template = mtg.map_index_template
+        context.task_group_expanded_args = _resolve_task_group_expand_args(
+            mtg._expand_input, map_index, run_id, session
+        )
+        break
+
+
+def _resolve_task_group_expand_args(
+    expand_input: Any,
+    map_index: int,
+    run_id: str,
+    session: SessionDep,
+) -> dict[str, Any] | None:
+    """
+    Resolve the expand_input for a specific map_index to get the expanded arguments.
+
+    For a ``SchedulerDictOfListsExpandInput``, each list value (or ``SchedulerXComArg`` whose
+    resolved XCom is a list) is indexed with ``map_index`` and collected under its key.
+    For a ``SchedulerListOfDictsExpandInput``, the dict at ``map_index`` is returned.
+    Returns ``None`` when nothing could be resolved.
+    """
+    from airflow.models.expandinput import SchedulerDictOfListsExpandInput, SchedulerListOfDictsExpandInput
+    from airflow.serialization.definitions.xcom_arg import SchedulerXComArg
+
+    if isinstance(expand_input, SchedulerDictOfListsExpandInput):
+        # NOTE(review): every argument is indexed with the same map_index. If expanding over
+        # several arguments produces their cross product (Airflow dynamic task mapping), the
+        # index must be decomposed per argument instead -- verify with a two-argument expand().
+        resolved: dict[str, Any] = {}
+        for key, value in expand_input.value.items():
+            if isinstance(value, SchedulerXComArg):
+                xcom_result = _resolve_xcom_arg_value(value, run_id, session)
+                # Only list XCom results are indexed; any other type leaves the key unresolved.
+                if isinstance(xcom_result, list) and map_index < len(xcom_result):
+                    resolved[key] = xcom_result[map_index]
+            elif isinstance(value, (list, tuple)):
+                if map_index < len(value):
+                    resolved[key] = value[map_index]
+        # NOTE(review): a negative map_index passes the ``< len(...)`` checks above and
+        # indexes from the end of each list -- confirm callers never pass one.
+        return resolved if resolved else None
+
+    if isinstance(expand_input, SchedulerListOfDictsExpandInput):
+        # The input may be a literal list/tuple of dicts or an XComArg resolving to a list.
+        if isinstance(expand_input.value, (list, tuple)):
+            if map_index < len(expand_input.value):
+                item = expand_input.value[map_index]
+                if isinstance(item, dict):
+                    return item
+        elif isinstance(expand_input.value, SchedulerXComArg):
+            xcom_result = _resolve_xcom_arg_value(expand_input.value, run_id, session)
+            if isinstance(xcom_result, list) and map_index < len(xcom_result):
+                item = xcom_result[map_index]
+                if isinstance(item, dict):
+                    return item
+
+    return None
Review Comment:
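How about this? It folds the duplicated list/XComArg resolution into a single nested helper and dispatches on the expand-input type with `match`/`case`: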
```python
def _resolve_task_group_expand_args(
expand_input: Any,
map_index: int,
run_id: str,
session: SessionDep,
) -> dict[str, Any] | None:
"""Resolve the expand_input for a specific map_index to get the expanded
arguments."""
from airflow.models.expandinput import SchedulerDictOfListsExpandInput,
SchedulerListOfDictsExpandInput
from airflow.serialization.definitions.xcom_arg import SchedulerXComArg
def _resolve_at_index(value: Any) -> Any | None:
"""Resolve a single value (list/tuple or XComArg) at the given
map_index."""
match value:
case SchedulerXComArg():
value = _resolve_xcom_arg_value(value, run_id, session)
case list() | tuple():
pass
case _:
return None
if isinstance(value, (list, tuple)) and map_index < len(value):
return value[map_index]
return None
match expand_input:
case SchedulerDictOfListsExpandInput(value=mapping):
resolved = {}
for key, val in mapping.items():
if (item := _resolve_at_index(val)) is not None:
resolved[key] = item
return resolved or None
case SchedulerListOfDictsExpandInput(value=val):
match _resolve_at_index(val):
case dict() as item:
return item
return None
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]