dabla commented on code in PR #68299:
URL: https://github.com/apache/airflow/pull/68299#discussion_r3388719038


##########
task-sdk/src/airflow/sdk/execution_time/task_runner.py:
##########
@@ -529,6 +529,123 @@ def xcom_pull(
 
         return xcoms
 
+    async def axcom_pull(
+        self,
+        task_ids: str | Iterable[str] | None = None,
+        dag_id: str | None = None,
+        key: str = BaseXCom.XCOM_RETURN_KEY,
+        include_prior_dates: bool = False,
+        *,
+        map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET,
+        default: Any = None,
+        run_id: str | None = None,
+    ) -> Any:
+        """
+        Pull XComs either from the API server (BaseXCom) or from the custom 
XCOM backend if configured.
+
+        The pull can be filtered optionally by certain criterion.
+
+        :param key: A key for the XCom. If provided, only XComs with matching
+            keys will be returned. The default key is ``'return_value'``, also
+            available as constant ``XCOM_RETURN_KEY``. This key is 
automatically
+            given to XComs returned by tasks (as opposed to being pushed
+            manually).
+        :param task_ids: Only XComs from tasks with matching ids will be
+            pulled. If *None* (default), the task_id of the calling task is 
used.
+        :param dag_id: If provided, only pulls XComs from this Dag. If *None*
+            (default), the Dag of the calling task is used.
+        :param map_indexes: If provided, only pull XComs with matching indexes.
+            If *None* (default), this is inferred from the task(s) being pulled
+            (see below for details).
+        :param include_prior_dates: If False, only XComs from the current
+            logical_date are returned. If *True*, XComs from previous dates
+            are returned as well.
+        :param run_id: If provided, only pulls XComs from a DagRun w/a 
matching run_id.
+            If *None* (default), the run_id of the calling task is used.
+
+        When pulling one single task (``task_id`` is *None* or a str) without
+        specifying ``map_indexes``, the return value is a single XCom entry
+        (map_indexes is set to map_index of the calling task instance).
+
+        When pulling task is mapped the specified ``map_index`` is used, so by 
default
+        pulling on mapped task will result in no matching XComs if the task 
instance
+        of the method call is not mapped. Otherwise, the map_index of the 
calling task
+        instance is used. Setting ``map_indexes`` to *None* will pull XCom as 
it would
+        from a non mapped task.
+
+        In either case, ``default`` (*None* if not specified) is returned if no
+        matching XComs are found.
+
+        When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` 
is
+        a non-str iterable), a list of matching XComs is returned. Elements in
+        the list is ordered by item ordering in ``task_id`` and ``map_index``.
+        """
+        if dag_id is None:
+            dag_id = self.dag_id
+        if run_id is None:
+            run_id = self.run_id
+
+        single_task_requested = isinstance(task_ids, (str, type(None)))
+        single_map_index_requested = isinstance(map_indexes, (int, type(None)))
+
+        if task_ids is None:
+            # default to the current task if not provided
+            task_ids = [self.task_id]
+        elif isinstance(task_ids, str):
+            task_ids = [task_ids]
+
+        # If map_indexes is not specified, pull xcoms from all map indexes for 
each task
+        if not is_arg_set(map_indexes):
+            xcoms: list[Any] = []
+            for t_id in task_ids:
+                values = await XCom.aget_all(
+                    run_id=run_id,
+                    key=key,
+                    task_id=t_id,
+                    dag_id=dag_id,
+                    include_prior_dates=include_prior_dates,
+                )
+
+                if values is None:
+                    xcoms.append(None)
+                else:
+                    xcoms.extend(values)
+            # For single task pulling from unmapped task, return single value
+            if single_task_requested and len(xcoms) == 1:
+                return xcoms[0]
+            return xcoms
+
+        # Original logic when map_indexes is explicitly specified
+        map_indexes_iterable: Iterable[int | None] = []
+        if isinstance(map_indexes, int) or map_indexes is None:
+            map_indexes_iterable = [map_indexes]
+        elif isinstance(map_indexes, Iterable):
+            map_indexes_iterable = map_indexes
+        else:
+            raise TypeError(
+                f"Invalid type for map_indexes: expected int, iterable of 
ints, or None, got {type(map_indexes)}"
+            )

Review Comment:
   I already tried that; it's not easy to achieve, but I can take another look.
The problem is that it doesn't make the code more readable. But yes, I'm aware
that this is not DRY.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to