ashb commented on code in PR #25176:
URL: https://github.com/apache/airflow/pull/25176#discussion_r932026131


##########
airflow/models/xcom_arg.py:
##########
@@ -379,13 +388,96 @@ def get_task_map_length(self, run_id: str, *, session: 
"Session") -> Optional[in
     @provide_session
     def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> 
Any:
         value = self.arg.resolve(context, session=session)
-        assert isinstance(value, (Sequence, dict))  # Validation was done when 
XCom was pushed.
+        if not isinstance(value, (Sequence, dict)):
+            raise ValueError(f"XCom map expects sequence or dict, not 
{type(value).__name__}")
         return _MapResult(value, self.callables)
 
 
+class _ZipResult(Sequence):
+    def __init__(self, values: Sequence[Union[Sequence, dict]], *, fillvalue: 
Any = NOTSET) -> None:
+        self.values = values
+        self.fillvalue = fillvalue
+
+    @staticmethod
+    def _get_or_fill(container: Union[Sequence, dict], index: Any, fillvalue: 
Any) -> Any:
+        try:
+            return container[index]
+        except (IndexError, KeyError):
+            return fillvalue
+
+    def __getitem__(self, index: Any) -> Any:
+        if index >= len(self):
+            raise IndexError(index)
+        return tuple(self._get_or_fill(value, index, self.fillvalue) for value 
in self.values)
+
+    def __len__(self) -> int:
+        lengths = (len(v) for v in self.values)
+        if self.fillvalue is NOTSET:
+            return min(lengths)
+        return max(lengths)
+
+
+class ZipXComArg(XComArg):
+    """An XCom reference with ``zip()`` applied.
+
+    This is constructed from multiple XComArg instances, and presents an
+    iterable that "zips" them together like the built-in ``zip()`` (and
+    ``itertools.zip_longest()`` if ``fillvalue`` is provided).
+    """
+
+    def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> 
None:
+        if not args:
+            raise ValueError("At least one input is required")
+        self.args = args
+        self.fillvalue = fillvalue
+
+    def __repr__(self) -> str:
+        args_iter = iter(self.args)
+        first = repr(next(args_iter))
+        rest = ", ".join(repr(arg) for arg in args_iter)
+        if self.fillvalue is NOTSET:
+            return f"{first}.zip({rest})"
+        return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})"
+
+    def _serialize(self) -> Dict[str, Any]:
+        args = [serialize_xcom_arg(arg) for arg in self.args]
+        if self.fillvalue is NOTSET:
+            return {"args": args}
+        return {"args": args, "fillvalue": self.fillvalue}
+
+    @classmethod
+    def _deserialize(cls, data: Dict[str, Any], dag: "DAG") -> XComArg:
+        return cls(
+            [deserialize_xcom_arg(arg, dag) for arg in data["args"]],
+            fillvalue=data.get("fillvalue", NOTSET),
+        )
+
+    def iter_references(self) -> Iterator[Tuple["Operator", str]]:
+        for arg in self.args:
+            yield from arg.iter_references()
+
+    def get_task_map_length(self, run_id: str, *, session: "Session") -> 
Optional[int]:
+        all_lengths = (arg.get_task_map_length(run_id, session=session) for 
arg in self.args)
+        ready_lengths = [length for length in all_lengths if length is not 
None]
+        if len(ready_lengths) != len(self.args):
+            return None  # If any of the referenced XComs is not ready, we are 
not ready either.

Review Comment:
   I'm not sure this is the right behaviour when fillvalue is provided, 
especially given things like https://github.com/apache/airflow/issues/24338



##########
airflow/serialization/serialized_objects.py:
##########
@@ -393,7 +392,7 @@ def _serialize(cls, var: Any) -> Any:  # Unfortunately 
there is no support for r
         elif isinstance(var, Param):
             return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
         elif isinstance(var, XComArg):
-            return cls._encode(cls._serialize_xcomarg(var), type_=DAT.XCOM_REF)
+            return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)

Review Comment:
   This leads to a slightly odd serialization with type "doubled":
   
   ```json
   { "_type": "xcom_ref", "_val": { "type": "", ... }}
   ```
   
   Is this the right thing to do?



##########
airflow/models/xcom_arg.py:
##########
@@ -285,8 +288,13 @@ def iter_references(self) -> Iterator[Tuple["Operator", 
str]]:
 
     def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
         if self.key != XCOM_RETURN_KEY:
-            raise ValueError
-        return MapXComArg(self, [f])
+            raise ValueError("cannot map against non-return XCom")
+        return super().map(f)
+
+    def zip(self, *others: "XComArg", fillvalue: Any = NOTSET) -> "ZipXComArg":
+        if self.key != XCOM_RETURN_KEY:
+            raise ValueError("cannot map against non-return XCom")

Review Comment:
   What is the reason for this limitation btw? (I know we have had it on map 
for a while, but I can't think of anything that would actually break if we 
didn't have it)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to