dimberman commented on a change in pull request #14761:
URL: https://github.com/apache/airflow/pull/14761#discussion_r603538158
##########
File path: airflow/decorators/base.py
##########
@@ -63,73 +120,49 @@ def __init__(
op_args: Tuple[Any],
op_kwargs: Dict[str, Any],
multiple_outputs: bool = False,
+ kwargs_to_upstream: dict = None,
**kwargs,
) -> None:
- kwargs['task_id'] = self._get_unique_task_id(task_id,
kwargs.get('dag'), kwargs.get('task_group'))
- super().__init__(**kwargs)
+ kwargs['task_id'] = get_unique_task_id(task_id, kwargs.get('dag'),
kwargs.get('task_group'))
self.python_callable = python_callable
+ kwargs_to_upstream = kwargs_to_upstream or {}
# Check that arguments can be binded
signature(python_callable).bind(*op_args, **op_kwargs)
self.multiple_outputs = multiple_outputs
self.op_args = op_args
self.op_kwargs = op_kwargs
+ super().__init__(**kwargs_to_upstream, **kwargs)
- @staticmethod
- def _get_unique_task_id(
- task_id: str, dag: Optional[DAG] = None, task_group:
Optional[TaskGroup] = None
- ) -> str:
- """
- Generate unique task id given a DAG (or if run in a DAG context)
- Ids are generated by appending a unique number to the end of
- the original task id.
-
- Example:
- task_id
- task_id__1
- task_id__2
- ...
- task_id__20
- """
- dag = dag or DagContext.get_current_dag()
- if not dag:
- return task_id
-
- # We need to check if we are in the context of TaskGroup as the
task_id may
- # already be altered
- task_group = task_group or TaskGroupContext.get_current_task_group(dag)
- tg_task_id = task_group.child_id(task_id) if task_group else task_id
-
- if tg_task_id not in dag.task_ids:
- return task_id
- core = re.split(r'__\d+$', task_id)[0]
- suffixes = sorted(
- [
- int(re.split(r'^.+__', task_id)[1])
- for task_id in dag.task_ids
- if re.match(rf'^{core}__\d+$', task_id)
- ]
- )
- if not suffixes:
- return f'{core}__1'
- return f'{core}__{suffixes[-1] + 1}'
-
- @staticmethod
- def validate_python_callable(python_callable):
- """
- Validate that python callable can be wrapped by operator.
- Raises exception if invalid.
+ def execute(self, context: Dict):
+ return_value = super().execute(context)
+ self._handle_output(return_value=return_value, context=context,
xcom_push=self.xcom_push)
+ return return_value
- :param python_callable: Python object to be validated
- :raises: TypeError, AirflowException
+ def _handle_output(self, return_value: Any, context: Dict, xcom_push:
Callable):
"""
- if not callable(python_callable):
- raise TypeError('`python_callable` param must be callable')
- if 'self' in signature(python_callable).parameters.keys():
- raise AirflowException('@task does not support methods')
+ Handles logic for whether a decorator needs to push a single return
value or multiple return values.
- def execute(self, context: Dict):
- raise NotImplementedError()
+ :param return_value:
+ :param context:
+ :param xcom_push:
+ """
+ if not self.multiple_outputs:
+ return return_value
+ if isinstance(return_value, dict):
+ for key in return_value.keys():
+ if not isinstance(key, str):
+ raise AirflowException(
+ 'Returned dictionary keys must be strings when using '
+ f'multiple_outputs, found {key} ({type(key)}) instead'
+ )
Review comment:
@turbaszek isn't this part of the TaskFlow API docs?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org