anishgirianish commented on code in PR #61975:
URL: https://github.com/apache/airflow/pull/61975#discussion_r2811202215
##########
airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -921,6 +926,98 @@ def _iter_breadcrumbs() -> Iterator[dict[str, Any]]:
return TaskBreadcrumbsResponse(breadcrumbs=_iter_breadcrumbs())
+def _populate_task_group_map_index_context(
+ context: TIRunContext,
+ dag_id: str,
+ task_id: str,
+ map_index: int,
+ run_id: str,
+ session: SessionDep,
+ dag_bag: DagBagDep,
+) -> None:
+ """Populate task group map_index_template and expanded args on the
TIRunContext."""
+ try:
+ dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+ except HTTPException:
+ return
+
+ task = dag.task_dict.get(task_id)
+ if not task:
+ return
+
+ for mtg in task.iter_mapped_task_groups():
+ if not mtg.map_index_template:
+ continue
+
+ context.task_group_map_index_template = mtg.map_index_template
+ context.task_group_expanded_args = _resolve_task_group_expand_args(
+ mtg._expand_input, map_index, run_id, session
+ )
+ break
+
+
+def _resolve_task_group_expand_args(
+ expand_input: Any,
+ map_index: int,
+ run_id: str,
+ session: SessionDep,
+) -> dict[str, Any] | None:
+ """Resolve the expand_input for a specific map_index to get the expanded
arguments."""
+ from airflow.models.expandinput import SchedulerDictOfListsExpandInput,
SchedulerListOfDictsExpandInput
+ from airflow.serialization.definitions.xcom_arg import SchedulerXComArg
+
+ if isinstance(expand_input, SchedulerDictOfListsExpandInput):
+ resolved: dict[str, Any] = {}
+ for key, value in expand_input.value.items():
+ if isinstance(value, SchedulerXComArg):
+ xcom_result = _resolve_xcom_arg_value(value, run_id, session)
+ if isinstance(xcom_result, list) and map_index <
len(xcom_result):
+ resolved[key] = xcom_result[map_index]
+ elif isinstance(value, (list, tuple)):
+ if map_index < len(value):
+ resolved[key] = value[map_index]
+ return resolved if resolved else None
+
+ if isinstance(expand_input, SchedulerListOfDictsExpandInput):
+ if isinstance(expand_input.value, (list, tuple)):
+ if map_index < len(expand_input.value):
+ item = expand_input.value[map_index]
+ if isinstance(item, dict):
+ return item
+ elif isinstance(expand_input.value, SchedulerXComArg):
+ xcom_result = _resolve_xcom_arg_value(expand_input.value, run_id,
session)
+ if isinstance(xcom_result, list) and map_index < len(xcom_result):
+ item = xcom_result[map_index]
+ if isinstance(item, dict):
+ return item
+
+ return None
+
+
+def _resolve_xcom_arg_value(xcom_arg: Any, run_id: str, session: SessionDep)
-> Any:
+ """Resolve a SchedulerXComArg to its actual value via XCom query."""
+ refs = list(xcom_arg.iter_references())
+ if not refs:
+ return None
+ operator, key = refs[0]
+
+ xcom_value = session.scalar(
+ select(XComModel.value).where(
+ XComModel.dag_id == operator.dag_id,
+ XComModel.task_id == operator.task_id,
+ XComModel.run_id == run_id,
+ XComModel.key == key,
+ XComModel.map_index == -1,
Review Comment:
Great question! From what I can tell, the upstream task producing the expand
input is unmapped; it returns a list that drives the mapping. If it were
mapped, it would go through expand_kwargs instead, so it wouldn't be a
SchedulerXComArg here. But I'd love to hear if you've seen a case where this
breaks.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]