This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 23faab5 [AIRFLOW-8057] [AIP-31] Add @task decorator (#8962)
23faab5 is described below
commit 23faab59ebd8a9d1ad8dc7ff7169e06346ed9c24
Author: Gerard Casas Saez <[email protected]>
AuthorDate: Tue Jun 23 15:58:30 2020 -0600
[AIRFLOW-8057] [AIP-31] Add @task decorator (#8962)
Closes #8057. Closes #8056.
---
airflow/decorators.py | 18 ++
airflow/models/dag.py | 5 +
airflow/operators/python.py | 145 +++++++++++++++-
airflow/ti_deps/deps/trigger_rule_dep.py | 5 +-
docs/concepts.rst | 96 +++++++++++
tests/operators/test_python.py | 281 +++++++++++++++++++++++++++++++
6 files changed, 545 insertions(+), 5 deletions(-)
diff --git a/airflow/decorators.py b/airflow/decorators.py
new file mode 100644
index 0000000..eb593af
--- /dev/null
+++ b/airflow/decorators.py
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.operators.python import task # noqa # pylint:
disable=unused-import
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 174dfa5..ccdb60a 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1318,6 +1318,11 @@ class DAG(BaseDag, LoggingMixin):
for t in self.roots:
get_downstream(t)
+ @property
+ def task(self):
+ from airflow.operators.python import task
+ return functools.partial(task, dag=self)
+
def add_task(self, task):
"""
Add a task to the DAG
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index 0fac9f0..7107e17 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -16,21 +16,25 @@
# specific language governing permissions and limitations
# under the License.
+import functools
import inspect
import os
import pickle
+import re
import sys
import types
from inspect import signature
from itertools import islice
from tempfile import TemporaryDirectory
from textwrap import dedent
-from typing import Callable, Dict, Iterable, List, Optional
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import dill
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, SkipMixin
+from airflow.models.dag import DAG, DagContext
+from airflow.models.xcom_arg import XComArg
from airflow.utils.decorators import apply_defaults
from airflow.utils.process_utils import execute_in_subprocess
from airflow.utils.python_virtualenv import prepare_virtualenv
@@ -145,6 +149,145 @@ class PythonOperator(BaseOperator):
return self.python_callable(*self.op_args, **self.op_kwargs)
+class _PythonFunctionalOperator(BaseOperator):
+ """
+ Wraps a Python callable and captures args/kwargs when called for execution.
+
+ :param python_callable: A reference to an object that is callable
+ :type python_callable: python callable
+ :param op_kwargs: a dictionary of keyword arguments that will get unpacked
+ in your function (templated)
+ :type op_kwargs: dict
+ :param op_args: a list of positional arguments that will get unpacked when
+ calling your callable (templated)
+ :type op_args: list
+ :param multiple_outputs: if set, function return value will be
+ unrolled to multiple XCom values. Dict will unroll to xcom values with
keys as keys.
+ Defaults to False.
+ :type multiple_outputs: bool
+ """
+
+ template_fields = ('op_args', 'op_kwargs')
+ ui_color = PythonOperator.ui_color
+
+ # since we won't mutate the arguments, we should just do the shallow copy
+ # there are some cases we can't deepcopy the objects (e.g protobuf).
+ shallow_copy_attrs = ('python_callable',)
+
+ @apply_defaults
+ def __init__(
+ self,
+ python_callable: Callable,
+ task_id: str,
+ op_args: Tuple[Any],
+ op_kwargs: Dict[str, Any],
+ multiple_outputs: bool = False,
+ **kwargs
+ ) -> None:
+ kwargs['task_id'] = self._get_unique_task_id(task_id,
kwargs.get('dag', None))
+ super().__init__(**kwargs)
+ self.python_callable = python_callable
+
+ # Check that arguments can be bound
+ signature(python_callable).bind(*op_args, **op_kwargs)
+ self.multiple_outputs = multiple_outputs
+ self.op_args = op_args
+ self.op_kwargs = op_kwargs
+
+ @staticmethod
+ def _get_unique_task_id(task_id: str, dag: Optional[DAG] = None) -> str:
+ """
+ Generate unique task id given a DAG (or if run in a DAG context)
+ Ids are generated by appending a unique number to the end of
+ the original task id.
+
+ Example:
+ task_id
+ task_id__1
+ task_id__2
+ ...
+ task_id__20
+ """
+ dag = dag or DagContext.get_current_dag()
+ if not dag or task_id not in dag.task_ids:
+ return task_id
+ core = re.split(r'__\d+$', task_id)[0]
+ suffixes = sorted(
+ [int(re.split(r'^.+__', task_id)[1])
+ for task_id in dag.task_ids
+ if re.match(rf'^{core}__\d+$', task_id)]
+ )
+ if not suffixes:
+ return f'{core}__1'
+ return f'{core}__{suffixes[-1] + 1}'
+
+ @staticmethod
+ def validate_python_callable(python_callable):
+ """
+ Validate that python callable can be wrapped by operator.
+ Raises exception if invalid.
+
+ :param python_callable: Python object to be validated
+ :raises: TypeError, AirflowException
+ """
+ if not callable(python_callable):
+ raise TypeError('`python_callable` param must be callable')
+ if 'self' in signature(python_callable).parameters.keys():
+ raise AirflowException('@task does not support methods')
+
+ def execute(self, context: Dict):
+ return_value = self.python_callable(*self.op_args, **self.op_kwargs)
+ self.log.debug("Done. Returned value was: %s", return_value)
+ if not self.multiple_outputs:
+ return return_value
+ if isinstance(return_value, dict):
+ for key in return_value.keys():
+ if not isinstance(key, str):
+ raise AirflowException('Returned dictionary keys must be
strings when using '
+ f'multiple_outputs, found {key}
({type(key)}) instead')
+ for key, value in return_value.items():
+ self.xcom_push(context, key, value)
+ else:
+ raise AirflowException(f'Returned output was type
{type(return_value)} expected dictionary '
+ 'for multiple_outputs')
+ return return_value
+
+
+def task(python_callable: Optional[Callable] = None, multiple_outputs: bool =
False, **kwargs):
+ """
+ Python operator decorator. Wraps a function into an Airflow operator.
+ Accepts kwargs for operator kwargs. Can be reused in a single DAG.
+
+ :param python_callable: Function to decorate
+ :type python_callable: Optional[Callable]
+ :param multiple_outputs: if set, function return value will be
+ unrolled to multiple XCom values. List/Tuples will unroll to xcom
values
+ with index as key. Dict will unroll to xcom values with keys as XCom
keys.
+ Defaults to False.
+ :type multiple_outputs: bool
+
+ """
+ def wrapper(f):
+ """
+ Python wrapper to generate PythonFunctionalOperator out of simple
python functions.
+ Used for Airflow functional interface
+ """
+ _PythonFunctionalOperator.validate_python_callable(f)
+ kwargs.setdefault('task_id', f.__name__)
+
+ @functools.wraps(f)
+ def factory(*args, **f_kwargs):
+ op = _PythonFunctionalOperator(python_callable=f, op_args=args,
op_kwargs=f_kwargs,
+ multiple_outputs=multiple_outputs,
**kwargs)
+ return XComArg(op)
+ return factory
+ if callable(python_callable):
+ return wrapper(python_callable)
+ elif python_callable is not None:
+ raise AirflowException('No args allowed while using @task, use kwargs
instead')
+ return wrapper
+
+
class BranchPythonOperator(PythonOperator, SkipMixin):
"""
Allows a workflow to "branch" or follow a path following the execution
diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py
b/airflow/ti_deps/deps/trigger_rule_dep.py
index a816dbf..d9e914c 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -18,10 +18,10 @@
from collections import Counter
-import airflow
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.session import provide_session
from airflow.utils.state import State
+from airflow.utils.trigger_rule import TriggerRule as TR
class TriggerRuleDep(BaseTIDep):
@@ -50,7 +50,6 @@ class TriggerRuleDep(BaseTIDep):
@provide_session
def _get_dep_statuses(self, ti, session, dep_context):
- TR = airflow.utils.trigger_rule.TriggerRule
# Checking that all upstream dependencies have succeeded
if not ti.task.upstream_list:
yield self._passing_status(
@@ -111,8 +110,6 @@ class TriggerRuleDep(BaseTIDep):
:type session: sqlalchemy.orm.session.Session
"""
- TR = airflow.utils.trigger_rule.TriggerRule
-
task = ti.task
upstream = len(task.upstream_task_ids)
trigger_rule = task.trigger_rule
diff --git a/docs/concepts.rst b/docs/concepts.rst
index e4aca8d..4695a83 100644
--- a/docs/concepts.rst
+++ b/docs/concepts.rst
@@ -116,6 +116,46 @@ DAGs can be used as context managers to automatically
assign new operators to th
op.dag is dag # True
+.. _concepts:functional_dags:
+
+Functional DAGs
+---------------
+
+DAGs can be defined using functional abstractions. Outputs and inputs are sent
between tasks using
+:ref:`XCom values <concepts:xcom>`. In addition, you can wrap functions as
tasks using the
+:ref:`task decorator <concepts:task_decorator>`. Airflow will also
automatically add dependencies between
+tasks to ensure that XCom messages are available when operators are executed.
+
+Example DAG with functional abstraction
+
+.. code-block:: python
+
+ with DAG(
+ 'send_server_ip', default_args=default_args, schedule_interval=None
+ ) as dag:
+
+ # Using default connection as it's set to httpbin.org by default
+ get_ip = SimpleHttpOperator(
+ task_id='get_ip', endpoint='get', method='GET', xcom_push=True
+ )
+
+ @dag.task(multiple_outputs=True)
+ def prepare_email(raw_json: str) -> Dict[str, str]:
+ external_ip = json.loads(raw_json)['origin']
+ return {
+ 'subject':f'Server connected from {external_ip}',
+ 'body': f'Seems like today your server executing Airflow is connected
from the external IP {external_ip}<br>'
+ }
+
+ email_info = prepare_email(get_ip.output)
+
+ send_email = EmailOperator(
+ task_id='send_email',
+ to='[email protected]',
+ subject=email_info['subject'],
+ html_content=email_info['body']
+ )
+
.. _concepts:dagruns:
DAG Runs
@@ -173,6 +213,62 @@ Each task is a node in our DAG, and there is a dependency
from task_1 to task_2:
We can say that task_1 is *upstream* of task_2, and conversely task_2 is
*downstream* of task_1.
When a DAG Run is created, task_1 will start running and task_2 waits for
task_1 to complete successfully before it may start.
+.. _concepts:task_decorator:
+
+Python task decorator
+---------------------
+
+Airflow ``task`` decorator converts any Python function to an Airflow operator.
+The decorated function can be called once to set the arguments and keyword
arguments for operator execution.
+
+
+.. code-block:: python
+
+ with DAG('my_dag', start_date=datetime(2020, 5, 15)) as dag:
+ @dag.task
+ def hello_world():
+ print('hello world!')
+
+
+ # Also...
+ from airflow.decorators import task
+
+
+ @task
+ def hello_name(name: str):
+ print(f'hello {name}!')
+
+
+ hello_name('Airflow users')
+
+Task decorator captures returned values and sends them to the :ref:`XCom
backend <concepts:xcom>`. By default, returned
+value is saved as a single XCom value. You can set ``multiple_outputs`` keyword
argument to ``True`` to unroll dictionaries,
+lists or tuples into separate XCom values. This can be used with regular
operators to create
+:ref:`functional DAGs <concepts:functional_dags>`.
+
+Calling a decorated function returns an ``XComArg`` instance. You can use it
to set templated fields on downstream
+operators.
+
+You can call a decorated function more than once in a DAG. The decorated
function will automatically generate
+a unique ``task_id`` for each generated operator.
+
+.. code-block:: python
+
+ with DAG('my_dag', start_date=datetime(2020, 5, 15)) as dag:
+
+ @dag.task
+ def update_user(user_id: int):
+ ...
+
+ # Avoid generating this list dynamically to keep DAG topology stable
between DAG runs
+ for user_id in user_ids:
+ update_user(user_id)
+
+ # This will generate an operator for each user_id
+
+Task ids are generated by appending a number at the end of the original task
id. For the above example, the DAG will have
+the following task ids: ``[update_user, update_user__1, update_user__2, ...
update_user__n]``.
+
Task Instances
==============
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index 8a31dd8..83000e3 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -28,13 +28,16 @@ from subprocess import CalledProcessError
from typing import List
import funcsigs
+import pytest
from airflow.exceptions import AirflowException
from airflow.models import DAG, DagRun, TaskInstance as TI
from airflow.models.taskinstance import clear_task_instances
+from airflow.models.xcom_arg import XComArg
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python import (
BranchPythonOperator, PythonOperator, PythonVirtualenvOperator,
ShortCircuitOperator,
+ task as task_decorator,
)
from airflow.utils import timezone
from airflow.utils.session import create_session
@@ -312,6 +315,284 @@ class TestPythonOperator(TestPythonBase):
python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+class TestAirflowTaskDecorator(TestPythonBase):
+
+ def test_python_operator_python_callable_is_callable(self):
+ """Tests that @task will only instantiate if
+ the python_callable argument is callable."""
+ not_callable = {}
+ with pytest.raises(AirflowException):
+ task_decorator(not_callable, dag=self.dag)
+
+ def test_fails_bad_signature(self):
+ """Tests that @task will fail if signature is not binding."""
+ @task_decorator
+ def add_number(num: int) -> int:
+ return num + 2
+ with pytest.raises(TypeError):
+ add_number(2, 3) # pylint: disable=too-many-function-args
+ with pytest.raises(TypeError):
+ add_number() # pylint: disable=no-value-for-parameter
+ add_number('test') # pylint: disable=no-value-for-parameter
+
+ def test_fail_method(self):
+ """Tests that @task will fail if signature is not binding."""
+
+ with pytest.raises(AirflowException):
+ class Test:
+ num = 2
+
+ @task_decorator
+ def add_number(self, num: int) -> int:
+ return self.num + num
+ Test().add_number(2)
+
+ def test_fail_multiple_outputs_key_type(self):
+ @task_decorator(multiple_outputs=True)
+ def add_number(num: int):
+ return {2: num}
+ with self.dag:
+ ret = add_number(2)
+ self.dag.create_dagrun(
+ run_id=DagRunType.MANUAL.value,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ state=State.RUNNING
+ )
+
+ with pytest.raises(AirflowException):
+ # pylint: disable=maybe-no-member
+ ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ def test_fail_multiple_outputs_no_dict(self):
+ @task_decorator(multiple_outputs=True)
+ def add_number(num: int):
+ return num
+
+ with self.dag:
+ ret = add_number(2)
+ self.dag.create_dagrun(
+ run_id=DagRunType.MANUAL.value,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ state=State.RUNNING
+ )
+
+ with pytest.raises(AirflowException):
+ # pylint: disable=maybe-no-member
+ ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ def test_python_callable_arguments_are_templatized(self):
+ """Test @task op_args are templatized"""
+ recorded_calls = []
+
+ # Create a named tuple and ensure it is still preserved
+ # after the rendering is done
+ Named = namedtuple('Named', ['var1', 'var2'])
+ named_tuple = Named('{{ ds }}', 'unchanged')
+
+ task = task_decorator(
+ # a Mock instance cannot be used as a callable function or test
fails with a
+ # TypeError: Object of type Mock is not JSON serializable
+ build_recording_function(recorded_calls),
+ dag=self.dag)
+ ret = task(4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.",
named_tuple)
+
+ self.dag.create_dagrun(
+ run_id=DagRunType.MANUAL.value,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ state=State.RUNNING
+ )
+ ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) #
pylint: disable=maybe-no-member
+
+ ds_templated = DEFAULT_DATE.date().isoformat()
+ assert len(recorded_calls) == 1
+ self._assert_calls_equal(
+ recorded_calls[0],
+ Call(4,
+ date(2019, 1, 1),
+ "dag {} ran on {}.".format(self.dag.dag_id, ds_templated),
+ Named(ds_templated, 'unchanged'))
+ )
+
+ def test_python_callable_keyword_arguments_are_templatized(self):
+ """Test PythonOperator op_kwargs are templatized"""
+ recorded_calls = []
+
+ task = task_decorator(
+ # a Mock instance cannot be used as a callable function or test
fails with a
+ # TypeError: Object of type Mock is not JSON serializable
+ build_recording_function(recorded_calls),
+ dag=self.dag
+ )
+ ret = task(an_int=4, a_date=date(2019, 1, 1), a_templated_string="dag
{{dag.dag_id}} ran on {{ds}}.")
+ self.dag.create_dagrun(
+ run_id=DagRunType.MANUAL.value,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ state=State.RUNNING
+ )
+ ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) #
pylint: disable=maybe-no-member
+
+ assert len(recorded_calls) == 1
+ self._assert_calls_equal(
+ recorded_calls[0],
+ Call(an_int=4,
+ a_date=date(2019, 1, 1),
+ a_templated_string="dag {} ran on {}.".format(
+ self.dag.dag_id, DEFAULT_DATE.date().isoformat()))
+ )
+
+ def test_manual_task_id(self):
+ """Test manually setting task_id"""
+
+ @task_decorator(task_id='some_name')
+ def do_run():
+ return 4
+ with self.dag:
+ do_run()
+ assert ['some_name'] == self.dag.task_ids
+
+ def test_multiple_calls(self):
+ """Test calling task multiple times in a DAG"""
+
+ @task_decorator
+ def do_run():
+ return 4
+ with self.dag:
+ do_run()
+ assert ['do_run'] == self.dag.task_ids
+ do_run_1 = do_run()
+ do_run_2 = do_run()
+ assert ['do_run', 'do_run__1', 'do_run__2'] == self.dag.task_ids
+
+ assert do_run_1.operator.task_id == 'do_run__1' # pylint:
disable=maybe-no-member
+ assert do_run_2.operator.task_id == 'do_run__2' # pylint:
disable=maybe-no-member
+
+ def test_call_20(self):
+ """Test calling decorated function 21 times in a DAG"""
+ @task_decorator
+ def __do_run():
+ return 4
+
+ with self.dag:
+ __do_run()
+ for _ in range(20):
+ __do_run()
+
+ assert self.dag.task_ids[-1] == '__do_run__20'
+
+ def test_multiple_outputs(self):
+ """Tests pushing multiple outputs as a dictionary"""
+
+ @task_decorator(multiple_outputs=True)
+ def return_dict(number: int):
+ return {
+ 'number': number + 1,
+ '43': 43
+ }
+
+ test_number = 10
+ with self.dag:
+ ret = return_dict(test_number)
+
+ dr = self.dag.create_dagrun(
+ run_id=DagRunType.MANUAL.value,
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING
+ )
+
+ ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) #
pylint: disable=maybe-no-member
+
+ ti = dr.get_task_instances()[0]
+ assert ti.xcom_pull(key='number') == test_number + 1
+ assert ti.xcom_pull(key='43') == 43
+ assert ti.xcom_pull() == {'number': test_number + 1, '43': 43}
+
+ def test_default_args(self):
+ """Test that default_args are captured when calling the function
correctly"""
+ @task_decorator
+ def do_run():
+ return 4
+
+ with self.dag:
+ ret = do_run()
+ assert ret.operator.owner == 'airflow' # pylint:
disable=maybe-no-member
+
+ def test_xcom_arg(self):
+ """Tests that returned key in XComArg is returned correctly"""
+
+ @task_decorator
+ def add_2(number: int):
+ return number + 2
+
+ @task_decorator
+ def add_num(number: int, num2: int = 2):
+ return number + num2
+
+ test_number = 10
+
+ with self.dag:
+ bigger_number = add_2(test_number)
+ ret = add_num(bigger_number, XComArg(bigger_number.operator)) #
pylint: disable=maybe-no-member
+
+ dr = self.dag.create_dagrun(
+ run_id=DagRunType.MANUAL.value,
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING
+ )
+
+ bigger_number.operator.run( # pylint: disable=maybe-no-member
+ start_date=DEFAULT_DATE, end_date=DEFAULT_DATE
+ )
+ ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) #
pylint: disable=maybe-no-member
+ ti_add_num = [ti for ti in dr.get_task_instances() if ti.task_id ==
'add_num'][0]
+ assert ti_add_num.xcom_pull(key=ret.key) == (test_number + 2) * 2 #
pylint: disable=maybe-no-member
+
+ def test_dag_task(self):
+ """Tests dag.task property to generate task"""
+
+ @self.dag.task
+ def add_2(number: int):
+ return number + 2
+
+ test_number = 10
+ res = add_2(test_number)
+ add_2(res)
+
+ assert 'add_2' in self.dag.task_ids
+
+ def test_dag_task_multiple_outputs(self):
+ """Tests dag.task property to generate task with multiple outputs"""
+
+ @self.dag.task(multiple_outputs=True)
+ def add_2(number: int):
+ return {'1': number + 2, '2': 42}
+
+ test_number = 10
+ add_2(test_number)
+ add_2(test_number)
+
+ assert 'add_2' in self.dag.task_ids
+
+ def test_airflow_task(self):
+ """Tests airflow.task decorator to generate task"""
+ from airflow.decorators import task
+
+ @task
+ def add_2(number: int):
+ return number + 2
+
+ test_number = 10
+ with self.dag:
+ add_2(test_number)
+
+ assert 'add_2' in self.dag.task_ids
+
+
class TestBranchOperator(unittest.TestCase):
@classmethod
def setUpClass(cls):