ashb commented on a change in pull request #19965:
URL: https://github.com/apache/airflow/pull/19965#discussion_r771512469
##########
File path: airflow/models/baseoperator.py
##########
@@ -1659,6 +1629,115 @@ def defer(
"""
raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)
+ def map(self, **kwargs) -> "MappedOperator":
+ return MappedOperator(
+ operator_class=type(self),
+ operator=self,
+ task_id=self.task_id,
+ task_group=getattr(self, 'task_group', None),
+ dag=getattr(self, '_dag', None),
+ start_date=self.start_date,
+ end_date=self.end_date,
+ partial_kwargs=self.__init_kwargs,
+ mapped_kwargs=kwargs,
+ )
+
+
+def _validate_kwarg_names_for_mapping(cls: Type[BaseOperator], func_name: str, value: Dict[str, Any]):
+ if isinstance(str, cls):
+ # Serialized version -- would have been validated at parse time
+ return
+
+ # use a dict so order of args is same as code order
+ unknown_args = value.copy()
+ for clazz in cls.mro():
+ # Mypy doesn't like doing `clazz.__init__`, Error is: Cannot access "__init__" directly
+ init = clazz.__init__ # type: ignore
+
+ if not hasattr(init, '_BaseOperatorMeta__param_names'):
+ continue
+
+ for name in init._BaseOperatorMeta__param_names:
+ unknown_args.pop(name, None)
+
+ if not unknown_args:
+ # If we have no args left to check: stop looking at the MRO chain
+ return
+
+ if len(unknown_args) == 1:
+ raise TypeError(
+ f'{cls.__name__}.{func_name} got unexpected keyword argument {unknown_args.popitem()[0]!r}'
+ )
+ else:
+ names = ", ".join(repr(n) for n in unknown_args)
+ raise TypeError(f'{cls.__name__}.{func_name} got unexpected keyword arguments {names}')
+
+
+@attr.define(kw_only=True)
+class MappedOperator(DAGNode):
+ """Object representing a mapped operator in a DAG"""
+
+ operator_class: Type[BaseOperator] = attr.ib(repr=lambda c: c.__name__)
+ task_id: str
+ partial_kwargs: Dict[str, Any]
+ mapped_kwargs: Dict[str, Any] = attr.ib(
+ validator=lambda self, _, v: _validate_kwarg_names_for_mapping(self.operator_class, "map", v)
+ )
+ operator: Optional[BaseOperator] = None
+ dag: Optional["DAG"] = None
+ upstream_task_ids: Set[str] = attr.ib(factory=set, repr=False)
+ downstream_task_ids: Set[str] = attr.ib(factory=set, repr=False)
+
+ task_group: Optional["TaskGroup"] = attr.ib(repr=False)
+ # BaseOperator-like interface -- needed so we can add ourselves to the dag.tasks
+ start_date: Optional[pendulum.DateTime] = attr.ib(repr=False, default=None)
+ end_date: Optional[pendulum.DateTime] = attr.ib(repr=False, default=None)
+
+ def __attrs_post_init__(self):
+ if self.dag and self.operator:
+ # When BaseOperator() was called with a DAG, it would have been added straight away, but now we
+ # are mapped, we want to _remove_ that task (`self.operator`) from the dag
+ self.dag._remove_task(self.task_id)
Review comment:
It's done inside this function.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]