casassg commented on a change in pull request #8962: URL: https://github.com/apache/airflow/pull/8962#discussion_r431452493
########## File path: airflow/operators/python.py ########## @@ -145,6 +147,141 @@ def execute_callable(self): return self.python_callable(*self.op_args, **self.op_kwargs) +class _PythonFunctionalOperator(BaseOperator): + """ + Wraps a Python callable and captures args/kwargs when called for execution. + + :param python_callable: A reference to an object that is callable + :type python_callable: python callable + :param multiple_outputs: if set, function return value will be + unrolled to multiple XCom values. List/Tuples will unroll to xcom values + with index as key. Dict will unroll to xcom values with keys as keys. + Defaults to False. + :type multiple_outputs: bool + """ + + template_fields = ('_op_args', '_op_kwargs') + ui_color = '#ffefeb' + + # since we won't mutate the arguments, we should just do the shallow copy + # there are some cases we can't deepcopy the objects(e.g protobuf). + shallow_copy_attrs = ('python_callable',) + + @apply_defaults + def __init__( + self, + python_callable: Callable, + multiple_outputs: bool = False, + *args, + **kwargs + ) -> None: + # Check if we need to generate a new task_id + task_id = kwargs.get('task_id', None) + dag = kwargs.get('dag', None) or DagContext.get_current_dag() + if task_id and dag and task_id in dag.task_ids: + prefix = task_id.rsplit("__", 1)[0] + task_id = sorted( + filter(lambda x: x.startswith(prefix), dag.task_ids), + reverse=True + )[0] + num = int(task_id[-1] if '__' in task_id else '0') + 1 + kwargs['task_id'] = f'{prefix}__{num}' + + if not kwargs.get('do_xcom_push', True) and not multiple_outputs: + raise AirflowException('@task needs to have either do_xcom_push=True or ' + 'multiple_outputs=True.') + if not callable(python_callable): + raise AirflowException('`python_callable` param must be callable') + self._fail_if_method(python_callable) + super().__init__(*args, **kwargs) + self.python_callable = python_callable + self.multiple_outputs = multiple_outputs + self._kwargs = kwargs + 
self._op_args: List[Any] = [] + self._called = False + self._op_kwargs: Dict[str, Any] = {} + + @staticmethod + def _fail_if_method(python_callable): + if 'self' in signature(python_callable).parameters.keys(): + raise AirflowException('@task does not support methods') + + def __call__(self, *args, **kwargs): + # If args/kwargs are set, then operator has been called. Raise exception + if self._called: Review comment: My main worry is: what, then, is `update_user`? What you are describing here is using `update_user` as an operator factory. It has its value, but it also feels too magic to me atm. If `update_user` is a factory, then you can't change the operator instance at all or use it to set non-data dependencies. We could capture the task_id kwarg and generate a new operator, but then what is `update_user`: the first operator, or the latest one? What does `update_user` represent? You can either do (1) `update_user(i) for i in range(20)` or (2) `update_user >> other_operation`, but not both. I prefer to support the 2nd option, as it aligns more closely with what Airflow already does with operators. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org