bugraoz93 commented on code in PR #53216:
URL: https://github.com/apache/airflow/pull/53216#discussion_r2204741857
##########
airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py:
##########
@@ -17,29 +17,284 @@
from __future__ import annotations
+import contextlib
from collections import Counter
from collections.abc import Iterable
+from typing import Any
+from uuid import UUID
import structlog
+from sqlalchemy import select
+from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.parameters import state_priority
+from airflow.api_fastapi.core_api.datamodels.ui.grid import
GridTaskInstanceSummary
+from airflow.api_fastapi.core_api.datamodels.ui.structure import
StructureDataResponse
+from airflow.models.baseoperator import BaseOperator as DBBaseOperator
+from airflow.models.dag import DAG
+from airflow.models.dag_version import DagVersion
+from airflow.models.expandinput import NotFullyPopulated
from airflow.models.taskmap import TaskMap
+from airflow.sdk import BaseOperator
+from airflow.sdk.definitions._internal.abstractoperator import NotMapped
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
-from airflow.serialization.serialized_objects import SerializedBaseOperator
-from airflow.utils.task_group import get_task_group_children_getter
+from airflow.serialization.serialized_objects import SerializedBaseOperator,
SerializedDAG
+from airflow.utils.state import TaskInstanceState
+from airflow.utils.task_group import get_task_group_children_getter,
task_group_to_dict_grid
log = structlog.get_logger(logger_name=__name__)
+def get_task_group_map(dag: DAG) -> dict[str, dict[str, Any]]:
Review Comment:
This as well
##########
airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py:
##########
@@ -17,29 +17,284 @@
from __future__ import annotations
+import contextlib
from collections import Counter
from collections.abc import Iterable
+from typing import Any
+from uuid import UUID
import structlog
+from sqlalchemy import select
+from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.parameters import state_priority
+from airflow.api_fastapi.core_api.datamodels.ui.grid import
GridTaskInstanceSummary
+from airflow.api_fastapi.core_api.datamodels.ui.structure import
StructureDataResponse
+from airflow.models.baseoperator import BaseOperator as DBBaseOperator
+from airflow.models.dag import DAG
+from airflow.models.dag_version import DagVersion
+from airflow.models.expandinput import NotFullyPopulated
from airflow.models.taskmap import TaskMap
+from airflow.sdk import BaseOperator
+from airflow.sdk.definitions._internal.abstractoperator import NotMapped
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
-from airflow.serialization.serialized_objects import SerializedBaseOperator
-from airflow.utils.task_group import get_task_group_children_getter
+from airflow.serialization.serialized_objects import SerializedBaseOperator,
SerializedDAG
+from airflow.utils.state import TaskInstanceState
+from airflow.utils.task_group import get_task_group_children_getter,
task_group_to_dict_grid
log = structlog.get_logger(logger_name=__name__)
+def get_task_group_map(dag: DAG) -> dict[str, dict[str, Any]]:
+ """
+ Get the Task Group Map for the DAG.
+
+ :param dag: DAG
+
+ :return: Task Group Map
+ """
+ task_nodes: dict[str, dict[str, Any]] = {}
+
+ def _is_task_node_mapped_task_group(task_node: BaseOperator |
MappedTaskGroup | TaskMap | None) -> bool:
+ """Check if the Task Node is a Mapped Task Group."""
+ return type(task_node) is MappedTaskGroup
+
+ def _append_child_task_count_to_parent(
+ child_task_count: int | MappedTaskGroup | TaskMap | MappedOperator |
None,
+ parent_node: BaseOperator | MappedTaskGroup | TaskMap | None,
+ ):
+ """
+ Append the Child Task Count to the Parent.
+
+ This method should only be used for Mapped Models.
+ """
+ if isinstance(parent_node, TaskGroup):
+ # Remove the regular task counted in parent_node
+ task_nodes[parent_node.node_id]["task_count"].append(-1)
+ # Add the mapped task to the parent_node
+
task_nodes[parent_node.node_id]["task_count"].append(child_task_count)
+
+ def _fill_task_group_map(
+ task_node: BaseOperator | MappedTaskGroup | TaskMap | None,
+ parent_node: BaseOperator | MappedTaskGroup | TaskMap | None,
+ ) -> None:
+ """Recursively fill the Task Group Map."""
+ if task_node is None:
+ return
+
+ if isinstance(task_node, MappedOperator):
+ task_nodes[task_node.node_id] = {
+ "is_group": False,
+ "parent_id": parent_node.node_id if parent_node else None,
+ "task_count": [task_node],
+ }
+ # Add the Task Count to the Parent Node because parent node is a
Task Group
+ _append_child_task_count_to_parent(child_task_count=task_node,
parent_node=parent_node)
+ return
+
+ if isinstance(task_node, TaskGroup):
+ task_count = task_node if
_is_task_node_mapped_task_group(task_node) else len(task_node.children)
+ task_nodes[task_node.node_id] = {
+ "is_group": True,
+ "parent_id": parent_node.node_id if parent_node else None,
+ "task_count": [task_count],
+ }
+ for child in get_task_group_children_getter()(task_node):
+ _fill_task_group_map(task_node=child, parent_node=task_node)
+ return
+
+ if isinstance(task_node, BaseOperator):
+ task_nodes[task_node.task_id] = {
+ "is_group": False,
+ "parent_id": parent_node.node_id if parent_node else None,
+ "task_count": task_nodes[parent_node.node_id]["task_count"]
+ if _is_task_node_mapped_task_group(parent_node) and parent_node
+ else [1],
+ }
+ # No Need to Add the Task Count to the Parent Node, these are
already counted in Add the Parent
+ return
+
+ for node in [child for child in
get_task_group_children_getter()(dag.task_group)]:
+ _fill_task_group_map(task_node=node, parent_node=None)
+
+ return task_nodes
+
+
+def get_child_task_map(parent_task_id: str, task_node_map: dict[str, dict[str,
Any]]):
+ """Get the Child Task Map for the Parent Task ID."""
+ return [task_id for task_id, task_map in task_node_map.items() if
task_map["parent_id"] == parent_task_id]
+
+
+def _count_tis(node: int | MappedTaskGroup | MappedOperator, run_id: str,
session: SessionDep) -> int:
+ if not isinstance(node, (MappedTaskGroup, MappedOperator)):
+ return node
+ with contextlib.suppress(NotFullyPopulated, NotMapped):
+ return DBBaseOperator.get_mapped_ti_count(node, run_id=run_id,
session=session)
+ # If the downstream is not actually mapped, or we don't have information to
+ # determine the length yet, simply return 1 to represent the stand-in ti.
+ return 1
+
+
+def fill_task_instance_summaries(
Review Comment:
This is still from the old method.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]