turbaszek commented on a change in pull request #8962: URL: https://github.com/apache/airflow/pull/8962#discussion_r429151128
########## File path: airflow/operators/python.py ########## @@ -145,6 +147,141 @@ def execute_callable(self): return self.python_callable(*self.op_args, **self.op_kwargs) +class _PythonFunctionalOperator(BaseOperator): + """ + Wraps a Python callable and captures args/kwargs when called for execution. + + :param python_callable: A reference to an object that is callable + :type python_callable: python callable + :param multiple_outputs: if set, function return value will be + unrolled to multiple XCom values. List/Tuples will unroll to xcom values + with index as key. Dict will unroll to xcom values with keys as keys. + Defaults to False. + :type multiple_outputs: bool + """ + + template_fields = ('_op_args', '_op_kwargs') + ui_color = '#ffefeb' + + # since we won't mutate the arguments, we should just do the shallow copy + # there are some cases we can't deepcopy the objects(e.g protobuf). + shallow_copy_attrs = ('python_callable',) + + @apply_defaults + def __init__( + self, + python_callable: Callable, + multiple_outputs: bool = False, + *args, + **kwargs + ) -> None: + # Check if we need to generate a new task_id + task_id = kwargs.get('task_id', None) + dag = kwargs.get('dag', None) or DagContext.get_current_dag() + if task_id and dag and task_id in dag.task_ids: + prefix = task_id.rsplit("__", 1)[0] + task_id = sorted( + filter(lambda x: x.startswith(prefix), dag.task_ids), + reverse=True + )[0] + num = int(task_id[-1] if '__' in task_id else '0') + 1 + kwargs['task_id'] = f'{prefix}__{num}' + + if not kwargs.get('do_xcom_push', True) and not multiple_outputs: + raise AirflowException('@task needs to have either do_xcom_push=True or ' + 'multiple_outputs=True.') + if not callable(python_callable): + raise AirflowException('`python_callable` param must be callable') + self._fail_if_method(python_callable) + super().__init__(*args, **kwargs) + self.python_callable = python_callable + self.multiple_outputs = multiple_outputs + self._kwargs = kwargs + 
self._op_args: List[Any] = [] + self._called = False + self._op_kwargs: Dict[str, Any] = {} + + @staticmethod + def _fail_if_method(python_callable): + if 'self' in signature(python_callable).parameters.keys(): + raise AirflowException('@task does not support methods') + + def __call__(self, *args, **kwargs): + # If args/kwargs are set, then operator has been called. Raise exception + if self._called: Review comment: Do I correctly understand that this will not work? ``` python @task def update_user(user_id: str): ... with DAG(...): # Fetch list of users ... # Execute task for each user for user_id in users_list: update_user(user_id) ``` ########## File path: airflow/operators/python.py ########## @@ -145,6 +148,142 @@ def execute_callable(self): return self.python_callable(*self.op_args, **self.op_kwargs) +class _PythonFunctionalOperator(BaseOperator): + """ + Wraps a Python callable and captures args/kwargs when called for execution. + + :param python_callable: A reference to an object that is callable + :type python_callable: python callable + :param multiple_outputs: if set, function return value will be + unrolled to multiple XCom values. List/Tuples will unroll to xcom values + with index as key. Dict will unroll to xcom values with keys as keys. + Defaults to False. + :type multiple_outputs: bool + """ + + template_fields = ('_op_args', '_op_kwargs') + ui_color = '#ffefeb' + + # since we won't mutate the arguments, we should just do the shallow copy + # there are some cases we can't deepcopy the objects(e.g protobuf). 
+ shallow_copy_attrs = ('python_callable',) + + @apply_defaults + def __init__( + self, + python_callable: Callable, + multiple_outputs: bool = False, + *args, + **kwargs + ) -> None: + dag = kwargs.get('dag', None) or DagContext.get_current_dag() + kwargs['task_id'] = self._get_unique_task_id(kwargs['task_id'], dag) + self._validate_python_callable(python_callable) + super().__init__(*args, **kwargs) + self.python_callable = python_callable + self.multiple_outputs = multiple_outputs + self._kwargs = kwargs + self._op_args: List[Any] = [] + self._called = False + self._op_kwargs: Dict[str, Any] = {} + + @staticmethod + def _get_unique_task_id(task_id, dag): + if not dag or task_id not in dag.task_ids: + return task_id + core = re.split(r'__\d+$', task_id)[0] + suffixes = sorted( + [int(re.split(r'^.+__', task_id)[1]) + for task_id in dag.task_ids + if re.match(rf'^{core}__\d+$', task_id)] + ) + if not suffixes: + return f'{core}__1' + return f'{core}__{suffixes[-1] + 1}' + + @staticmethod + def _validate_python_callable(python_callable): + if not callable(python_callable): + raise AirflowException('`python_callable` param must be callable') + if 'self' in signature(python_callable).parameters.keys(): + raise AirflowException('@task does not support methods') + + def __call__(self, *args, **kwargs): + # If args/kwargs are set, then operator has been called. Raise exception + if self._called: + raise AirflowException('@task decorated functions can only be called once. If you need to reuse ' + 'it several times in a DAG, use the `copy` method.') + + # If we have no DAG, reinitialize class to capture DAGContext and DAG default args. 
+ if not self.has_dag(): + self.__init__(python_callable=self.python_callable, + multiple_outputs=self.multiple_outputs, + **self._kwargs) + + # Capture args/kwargs + self._op_args = args + self._op_kwargs = kwargs + self._called = True + return XComArg(self) + + def copy(self, task_id: Optional[str] = None, **kwargs): + """ + Create a copy of the task, allow to overwrite ctor kwargs if needed. + + If alias is created a new DAGContext, apply defaults and set new DAG as the operator DAG. + + :param task_id: Task id for the new operator + :type task_id: Optional[str] + """ + if task_id: + self._kwargs['task_id'] = task_id + return _PythonFunctionalOperator( + python_callable=self.python_callable, + multiple_outputs=self.multiple_outputs, + **{**kwargs, **self._kwargs} + ) + + def execute(self, context: Dict): + return_value = self.python_callable(*self._op_args, **self._op_kwargs) + self.log.info("Done. Returned value was: %s", return_value) + if not self.multiple_outputs: + return return_value + if isinstance(return_value, dict): + for key, value in return_value.items(): + self.xcom_push(context, str(key), value) + elif isinstance(return_value, (list, tuple)): + for key, value in enumerate(return_value): + self.xcom_push(context, str(key), value) + return return_value + + +def task(python_callable: Optional[Callable] = None, **kwargs): + """ + Python operator decorator. Wraps a function into an Airflow operator. + Accepts kwargs for operator kwarg. Will try to wrap operator into DAG at declaration or + on function invocation. Use alias to reuse function in the DAG. + + :param python_callable: Function to decorate + :type python_callable: Optional[Callable] + :param multiple_outputs: if set, function return value will be + unrolled to multiple XCom values. List/Tuples will unroll to xcom values + with index as key. Dict will unroll to xcom values with keys as keys. + Defaults to False. 
+ :type multiple_outputs: bool + + """ + def wrapper(f): + """Python wrapper to generate PythonFunctionalOperator out of simple python functions. Review comment: ```suggestion """ Python wrapper to generate PythonFunctionalOperator out of simple python functions. ``` ########## File path: airflow/operators/python.py ########## @@ -145,6 +147,141 @@ def execute_callable(self): return self.python_callable(*self.op_args, **self.op_kwargs) +class _PythonFunctionalOperator(BaseOperator): + """ + Wraps a Python callable and captures args/kwargs when called for execution. + + :param python_callable: A reference to an object that is callable + :type python_callable: python callable + :param multiple_outputs: if set, function return value will be + unrolled to multiple XCom values. List/Tuples will unroll to xcom values + with index as key. Dict will unroll to xcom values with keys as keys. + Defaults to False. + :type multiple_outputs: bool + """ + + template_fields = ('_op_args', '_op_kwargs') + ui_color = '#ffefeb' + + # since we won't mutate the arguments, we should just do the shallow copy + # there are some cases we can't deepcopy the objects(e.g protobuf). 
+ shallow_copy_attrs = ('python_callable',) + + @apply_defaults + def __init__( + self, + python_callable: Callable, + multiple_outputs: bool = False, + *args, + **kwargs + ) -> None: + # Check if we need to generate a new task_id + task_id = kwargs.get('task_id', None) + dag = kwargs.get('dag', None) or DagContext.get_current_dag() + if task_id and dag and task_id in dag.task_ids: + prefix = task_id.rsplit("__", 1)[0] + task_id = sorted( + filter(lambda x: x.startswith(prefix), dag.task_ids), + reverse=True + )[0] + num = int(task_id[-1] if '__' in task_id else '0') + 1 + kwargs['task_id'] = f'{prefix}__{num}' + + if not kwargs.get('do_xcom_push', True) and not multiple_outputs: + raise AirflowException('@task needs to have either do_xcom_push=True or ' + 'multiple_outputs=True.') + if not callable(python_callable): + raise AirflowException('`python_callable` param must be callable') + self._fail_if_method(python_callable) + super().__init__(*args, **kwargs) + self.python_callable = python_callable + self.multiple_outputs = multiple_outputs + self._kwargs = kwargs + self._op_args: List[Any] = [] + self._called = False + self._op_kwargs: Dict[str, Any] = {} + + @staticmethod + def _fail_if_method(python_callable): + if 'self' in signature(python_callable).parameters.keys(): + raise AirflowException('@task does not support methods') + + def __call__(self, *args, **kwargs): + # If args/kwargs are set, then operator has been called. Raise exception + if self._called: + raise AirflowException('@task decorated functions can only be called once. If you need to reuse ' + 'it several times in a DAG, use the `copy` method.') + + # If we have no DAG, reinitialize class to capture DAGContext and DAG default args. 
+ if not self.has_dag(): + self.__init__(python_callable=self.python_callable, + multiple_outputs=self.multiple_outputs, + **self._kwargs) + + # Capture args/kwargs + self._op_args = args + self._op_kwargs = kwargs + self._called = True + return XComArg(self) + + def copy(self, task_id: Optional[str] = None, **kwargs): + """ + Create a copy of the task, allow to overwrite ctor kwargs if needed. + + If alias is created a new DAGContext, apply defaults and set new DAG as the operator DAG. + + :param task_id: Task id for the new operator + :type task_id: Optional[str] + """ + if task_id: + self._kwargs['task_id'] = task_id + return _PythonFunctionalOperator( + python_callable=self.python_callable, + multiple_outputs=self.multiple_outputs, + **{**kwargs, **self._kwargs} + ) + + def execute(self, context: Dict): + return_value = self.python_callable(*self._op_args, **self._op_kwargs) + self.log.info("Done. Returned value was: %s", return_value) + if not self.multiple_outputs: + return return_value + if isinstance(return_value, dict): + for key, value in return_value.items(): + self.xcom_push(context, str(key), value) + elif isinstance(return_value, (list, tuple)): + for key, value in enumerate(return_value): + self.xcom_push(context, str(key), value) + return return_value + + +def task(python_callable: Optional[Callable] = None, **kwargs): + """ + Python operator decorator. Wraps a function into an Airflow operator. + Accepts kwargs for operator kwarg. Will try to wrap operator into DAG at declaration or + on function invocation. Use alias to reuse function in the DAG. + + :param python_callable: Function to decorate + :type python_callable: Optional[Callable] + :param multiple_outputs: if set, function return value will be Review comment: Personally I would prefer to use typehints to indicate multiple outputs rather than a flag. It will solve the issue and add more information to task definitions. 
Of course, typehints are optional but we can require them to make multiple outputs work. Here's a similar thing from PySpark: https://databricks.com/blog/2020/05/20/new-pandas-udfs-and-python-type-hints-in-the-upcoming-release-of-apache-spark-3-0.html ########## File path: airflow/operators/python.py ########## @@ -145,6 +148,142 @@ def execute_callable(self): return self.python_callable(*self.op_args, **self.op_kwargs) +class _PythonFunctionalOperator(BaseOperator): + """ + Wraps a Python callable and captures args/kwargs when called for execution. + + :param python_callable: A reference to an object that is callable + :type python_callable: python callable + :param multiple_outputs: if set, function return value will be + unrolled to multiple XCom values. List/Tuples will unroll to xcom values + with index as key. Dict will unroll to xcom values with keys as keys. + Defaults to False. + :type multiple_outputs: bool + """ + + template_fields = ('_op_args', '_op_kwargs') + ui_color = '#ffefeb' + + # since we won't mutate the arguments, we should just do the shallow copy + # there are some cases we can't deepcopy the objects(e.g protobuf). 
+ shallow_copy_attrs = ('python_callable',) + + @apply_defaults + def __init__( + self, + python_callable: Callable, + multiple_outputs: bool = False, + *args, + **kwargs + ) -> None: + dag = kwargs.get('dag', None) or DagContext.get_current_dag() + kwargs['task_id'] = self._get_unique_task_id(kwargs['task_id'], dag) + self._validate_python_callable(python_callable) + super().__init__(*args, **kwargs) + self.python_callable = python_callable + self.multiple_outputs = multiple_outputs + self._kwargs = kwargs + self._op_args: List[Any] = [] + self._called = False + self._op_kwargs: Dict[str, Any] = {} + + @staticmethod + def _get_unique_task_id(task_id, dag): Review comment: ```suggestion def _get_unique_task_id(task_id: str, dag: DAG) -> str: ``` ########## File path: airflow/ti_deps/deps/trigger_rule_dep.py ########## @@ -18,10 +18,10 @@ from collections import Counter -import airflow from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.session import provide_session from airflow.utils.state import State +from airflow.utils.trigger_rule import TriggerRule as TR Review comment: Is this a related change? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org