This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8e75e23497 Allow re-use of decorated tasks (#22941)
8e75e23497 is described below

commit 8e75e2349791ee606203d5ba9035146e8a3be3dc
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Sat Apr 16 09:14:53 2022 +0100

    Allow re-use of decorated tasks (#22941)
    
    This opens up the possibility of using one decorated task
    in different dag files.
    Take for example the below task:
    - common.py
    
    @task(task_id='hello')
    def hello():
        print('Hello')
    
    defined in a file and called in different dag files using different task ids:
    - dag_file1.py:
    from common import hello
    
    @dag()
    def mydag():
        for i in range(3):
            hello.override(task_id=f'myhellotask_{i}')()
    
    - dag_file2.py:
    from common import hello
    
    @dag()
    def mydag2():
        for i in range(3):
            hello.override(task_id=f'welcome_message_{i}')()
    
    They would all run with different task ids.
---
 airflow/decorators/base.py                    |  3 ++
 docs/apache-airflow/tutorial_taskflow_api.rst | 58 +++++++++++++++++++++++++++
 tests/decorators/test_python.py               | 57 ++++++++++++++++++++++++++
 3 files changed, 118 insertions(+)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 9072439b2f..2029f6c5ed 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -385,6 +385,9 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
 
         return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": 
op_kwargs})
 
+    def override(self, **kwargs) -> "_TaskDecorator[Function, OperatorSubclass]":
+        return attr.evolve(self, kwargs={**self.kwargs, **kwargs})  # copy; new kwargs win
+
 
 def _merge_kwargs(kwargs1: Dict[str, Any], kwargs2: Dict[str, Any], *, 
fail_reason: str) -> Dict[str, Any]:
     duplicated_keys = set(kwargs1).intersection(kwargs2)
diff --git a/docs/apache-airflow/tutorial_taskflow_api.rst b/docs/apache-airflow/tutorial_taskflow_api.rst
index 6040a0d192..8ff4bca852 100644
--- a/docs/apache-airflow/tutorial_taskflow_api.rst
+++ b/docs/apache-airflow/tutorial_taskflow_api.rst
@@ -160,6 +160,64 @@ the dependencies as shown below.
     :start-after: [START main_flow]
     :end-before: [END main_flow]
 
+
+Reusing a decorated task
+-------------------------
+
+Decorated tasks are flexible. You can reuse a decorated task in multiple DAGs, overriding the task
+parameters such as the ``task_id``, ``queue``, ``pool``, etc.
+
+Below is an example of how you can reuse a decorated task in multiple DAGs:
+
+.. code-block:: python
+
+    from airflow.decorators import task, dag
+    from datetime import datetime
+
+
+    @task
+    def add_task(x, y):
+        print(f"Task args: x={x}, y={y}")
+        return x + y
+
+
+    @dag(start_date=datetime(2022, 1, 1))
+    def mydag():
+        start = add_task.override(task_id="start")(1, 2)
+        for i in range(3):
+            start >> add_task.override(task_id=f"add_start_{i}")(start, i)
+
+
+    @dag(start_date=datetime(2022, 1, 1))
+    def mydag2():
+        start = add_task(1, 2)
+        for i in range(3):
+            start >> add_task.override(task_id=f"new_add_task_{i}")(start, i)
+
+
+    first_dag = mydag()
+    second_dag = mydag2()
+
+You can also import the above ``add_task`` and use it in another DAG file.
+Suppose the ``add_task`` code lives in a file called ``common.py``. You can do this:
+
+.. code-block:: python
+
+    from common import add_task
+    from airflow.decorators import dag
+    from datetime import datetime
+
+
+    @dag(start_date=datetime(2022, 1, 1))
+    def use_add_task():
+        start = add_task.override(priority_weight=3)(1, 2)
+        for i in range(3):
+            start >> add_task.override(task_id=f"new_add_task_{i}", retries=4)(start, i)
+
+
+    created_dag = use_add_task()
+
+
 Using the TaskFlow API with Docker or Virtual Environments
 ----------------------------------------------------------
 
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 418907fd78..3f8b44c464 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -506,6 +506,63 @@ class TestAirflowTaskDecorator:
 
         assert ret.operator.doc_md.strip(), "Adds 2 to number."
 
+    def test_user_provided_task_id_in_a_loop_is_used(self):
+        """Tests that a user-provided task_id passed to override() is used for each loop iteration"""
+
+        @task_decorator(task_id='hello_task')
+        def hello():
+            """
+            Print Hello world
+            """
+            print("Hello world")
+
+        with self.dag:
+            for i in range(3):
+                hello.override(task_id=f'my_task_id_{i * 2}')()
+            hello()  # not overridden, so this task keeps the decorator's task_id 'hello_task'
+
+        assert self.dag.task_ids == ['my_task_id_0', 'my_task_id_2', 'my_task_id_4', 'hello_task']
+
+    def test_user_provided_pool_and_priority_weight_works(self):
+        """Tests that pool and priority_weight passed to override() are applied to each task"""
+
+        @task_decorator(task_id='hello_task')
+        def hello():
+            """
+            Print Hello world
+            """
+            print("Hello world")
+
+        with self.dag:
+            for i in range(3):
+                hello.override(pool='my_pool', priority_weight=i)()
+
+        weights = []
+        for task in self.dag.tasks:
+            assert task.pool == 'my_pool'
+            weights.append(task.priority_weight)
+        assert weights == [0, 1, 2]
+
+    def test_python_callable_args_work_as_well_as_baseoperator_args(self):
+        """Tests that callable args/kwargs still work with override(), including an empty override()"""
+
+        @task_decorator(task_id='hello_task')
+        def hello(x, y):
+            """
+            Print Hello world
+            """
+            print("Hello world", x, y)
+            return x, y
+
+        with self.dag:
+            output = hello.override(task_id='mytask')(x=2, y=3)
+            output2 = hello.override()(2, 3)  # nothing overridden but should work
+
+        assert output.operator.op_kwargs == {'x': 2, 'y': 3}
+        assert output2.operator.op_args == (2, 3)
+        output.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        output2.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
 
 def test_mapped_decorator_shadow_context() -> None:
     @task_decorator

Reply via email to