uranusjr commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r809721162



##########
File path: airflow/decorators/base.py
##########
@@ -369,31 +371,43 @@ class DecoratedMappedOperator(MappedOperator):
     multiple_outputs: bool
     python_callable: Callable
 
-    # We can't save these in partial_kwargs because op_args and op_kwargs need
-    # to be present in mapped_kwargs, and MappedOperator prevents duplication.
-    partial_op_kwargs: Dict[str, Any]
+    # We can't save these in mapped_kwargs because op_kwargs need to be present
+    # in partial_kwargs, and MappedOperator prevents duplication.
+    mapped_op_kwargs: Dict[str, "MapArgument"]
 
     @classmethod
     @cache
     def get_serialized_fields(cls):
-        # The magic argument-less super() does not work well with @cache
-        # (actually lru_cache in general), so we use the explicit form instead.
+        # The magic super() doesn't work here, so we use the explicit form.
+        # Not using super(..., cls) to work around pyupgrade bug.
         sup = super(DecoratedMappedOperator, DecoratedMappedOperator)
-        return sup.get_serialized_fields() | {"partial_op_kwargs"}
+        return sup.get_serialized_fields() | {"mapped_op_kwargs"}
 
-    def _create_unmapped_operator(
-        self,
-        *,
-        mapped_kwargs: Dict[str, Any],
-        partial_kwargs: Dict[str, Any],
-        real: bool,
-    ) -> "BaseOperator":
+    def __attrs_post_init__(self):
+        # The magic super() doesn't work here, so we use the explicit form.
+        # Not using super(..., self) to work around pyupgrade bug.
+        super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self)
+        XComArg.apply_upstream_relationship(self, self.mapped_op_kwargs)
+
+    def _get_expansion_kwargs(self) -> Dict[str, "MapArgument"]:
+        """The kwargs to calculate expansion length against.
+
+        Different from classic operators, a decorated (taskflow) operator's
+        ``map()`` contributes to the ``op_kwargs`` operator argument (not the
+        operator arguments themselves), and should therefore expand against it.
+        """
+        return self.mapped_op_kwargs
+
+    def _create_unmapped_operator(self, *, mapped_kwargs: Dict[str, Any], real: bool) -> "BaseOperator":
         assert not isinstance(self.operator_class, str)
-        mapped_kwargs = mapped_kwargs.copy()
-        del mapped_kwargs["op_kwargs"]
+        partial_kwargs = self.partial_kwargs.copy()
+        if real:
+            mapped_op_kwargs: Dict[str, Any] = self.mapped_op_kwargs
+        else:
+            mapped_op_kwargs = {k: unittest.mock.MagicMock(name=k) for k in self.mapped_op_kwargs}

Review comment:
       I extracted this into a function with a three-paragraph docstring.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to