ashb commented on a change in pull request #15034:
URL: https://github.com/apache/airflow/pull/15034#discussion_r604765458
########## File path: airflow/decorators/task_group.py ##########
@@ -0,0 +1,78 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped
+together when the DAG is displayed graphically.
+"""
+import functools
+from inspect import signature
+from typing import Callable, Optional, TypeVar, cast
+
+from airflow.utils.task_group import TaskGroup
+
+T = TypeVar("T", bound=Callable)  # pylint: disable=invalid-name
+
+task_group_sig = signature(TaskGroup.__init__)
+
+
+def task_group(python_callable: Optional[Callable] = None, *tg_args, **tg_kwargs) -> Callable[[T], T]:
+    """
+    Python TaskGroup decorator. Wraps a function into an Airflow TaskGroup.
+    Accepts kwargs for operator TaskGroup. Can be used to parametrize TaskGroup.
+
+    :param python_callable: Function to decorate
+    :param tg_args: Arguments for TaskGroup object
+    :type tg_args: list
+    :param tg_kwargs: Kwargs for TaskGroup object.
+    :type tg_kwargs: dict
+    """
+    # Setting group_id as function name if not given in kwarg group_id
+
+    # Get dag initializer signature and bind it to validate that task_group_args,
+    # and task_group_kwargs are correct
+
+    def wrapper(f: T):
+        if len(tg_args) == 0 and 'group_id' not in tg_kwargs.keys():

Review comment:
```suggestion
        if not tg_args and 'group_id' not in tg_kwargs:
```

########## File path: airflow/decorators/task_group.py ##########
+    # Setting group_id as function name if not given in kwarg group_id
+
+    # Get dag initializer signature and bind it to validate that task_group_args,
+    # and task_group_kwargs are correct

Review comment:
These comments are now out of place.
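For context, a minimal usage sketch of the API under review (the DAG and task names here are illustrative only, not taken from the PR): the decorated function becomes a factory that, when called inside a DAG context, creates a `TaskGroup` named after the function unless a `group_id` is supplied.

```python
# Hypothetical illustration of the decorator under review; `example_dag`
# and `say_hello` are made-up names for this sketch.
import pendulum

from airflow.decorators.task_group import task_group
from airflow.models import DAG
from airflow.operators.python import task


@task
def say_hello():
    print("hello")


@task_group  # group_id defaults to the function name, i.e. "greeting_section"
def greeting_section():
    say_hello()


with DAG(dag_id="example_dag", start_date=pendulum.parse("20201109")) as dag:
    # Creates TaskGroup "greeting_section" containing "greeting_section.say_hello"
    greeting_section()
```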
########## File path: tests/utils/test_task_group.py ##########
@@ -576,3 +579,372 @@ def test_task_without_dag():
     assert op1.dag == op2.dag == op3.dag
     assert dag.task_group.children.keys() == {"op1", "op2", "op3"}
     assert dag.task_group.children.keys() == dag.task_dict.keys()
+
+
+# taskgroup decorator tests
+
+
+def test_build_task_group_deco_context_manager():
+    """
+    Tests the following:
+    1. Nested TaskGroup creation using the taskgroup decorator should create the same TaskGroup
+       as one created using the TaskGroup context manager.
+    2. TaskGroup consisting of Tasks created using the task decorator.
+    3. Node IDs of DAGs created with the taskgroup decorator.
+    """
+
+    from airflow.operators.python import task
+
+    # Creating Tasks
+    @task
+    def task_start():
+        """Dummy Task which is First Task of Dag"""
+        return '[Task_start]'
+
+    @task
+    def task_end():
+        """Dummy Task which is Last Task of Dag"""
+        print('[ Task_End ]')
+
+    @task
+    def task_1(value):
+        """Dummy Task1"""
+        return f'[ Task1 {value} ]'
+
+    @task
+    def task_2(value):
+        """Dummy Task2"""
+        print(f'[ Task2 {value} ]')
+
+    @task
+    def task_3(value):
+        """Dummy Task3"""
+        return f'[ Task3 {value} ]'
+
+    @task
+    def task_4(value):
+        """Dummy Task4"""
+        print(f'[ Task4 {value} ]')
+
+    # Creating TaskGroups
+    @task_group_decorator
+    def section_1(value):
+        """TaskGroup for grouping related Tasks"""
+
+        @task_group_decorator()
+        def section_2(value2):
+            """TaskGroup for grouping related Tasks"""
+            return task_4(task_3(value2))
+
+        op1 = task_2(task_1(value))
+        return section_2(op1)
+
+    execution_date = pendulum.parse("20201109")
+    with DAG(
+        dag_id="example_nested_task_group_decorator", start_date=execution_date, tags=["example"]
+    ) as dag:
+        t_start = task_start()
+        sec_1 = section_1(t_start)
+        sec_1.set_downstream(task_end())
+
+    # Testing TaskGroup created using the taskgroup decorator
+    assert set(dag.task_group.children.keys()) == {"task_start", "task_end", "section_1"}
+    assert set(dag.task_group.children['section_1'].children.keys()) == {
+        'section_1.task_1',
+        'section_1.task_2',
+        'section_1.section_2',
+    }
+
+    # Testing TaskGroup consisting of Tasks created using the task decorator
+    assert dag.task_dict['task_start'].downstream_task_ids == {'section_1.task_1'}
+    assert dag.task_dict['section_1.task_2'].downstream_task_ids == {'section_1.section_2.task_3'}
+    assert dag.task_dict['section_1.section_2.task_4'].downstream_task_ids == {'task_end'}
+
+    # Node IDs test
+    node_ids = {
+        'id': None,
+        'children': [
+            {
+                'id': 'section_1',
+                'children': [
+                    {
+                        'id': 'section_1.section_2',
+                        'children': [
+                            {'id': 'section_1.section_2.task_3'},
+                            {'id': 'section_1.section_2.task_4'},
+                        ],
+                    },
+                    {'id': 'section_1.task_1'},
+                    {'id': 'section_1.task_2'},
+                    {'id': 'section_1.downstream_join_id'},
+                ],
+            },
+            {'id': 'task_end'},
+            {'id': 'task_start'},
+        ],
+    }
+
+    assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids
+
+
+def test_build_task_group_with_operators():
+    """Tests a DAG with Tasks created with *Operators and a TaskGroup created with the taskgroup decorator"""
+
+    from airflow.operators.python import task
+
+    def task_start():
+        """Dummy Task which is First Task of Dag"""
+        return '[Task_start]'
+
+    def task_end():
+        """Dummy Task which is Last Task of Dag"""
+        print('[ Task_End ]')
+
+    # Creating Tasks
+    @task
+    def task_1(value):
+        """Dummy Task1"""
+        return f'[ Task1 {value} ]'
+
+    @task
+    def task_2(value):
+        """Dummy Task2"""
+        return f'[ Task2 {value} ]'
+
+    @task
+    def task_3(value):
+        """Dummy Task3"""
+        print(f'[ Task3 {value} ]')
+
+    # Creating TaskGroups
+    @task_group_decorator(group_id='section_1')
+    def section_a(value):
+        """TaskGroup for grouping related Tasks"""
+        return task_3(task_2(task_1(value)))
+
+    execution_date = pendulum.parse("20201109")
+    with DAG(dag_id="example_task_group_decorator_mix", start_date=execution_date, tags=["example"]) as dag:
+        t_start = PythonOperator(task_id='task_start', python_callable=task_start, dag=dag)
+        sec_1 = section_a(t_start.output)
+        t_end = PythonOperator(task_id='task_end', python_callable=task_end, dag=dag)
+        sec_1.set_downstream(t_end)
+
+    # Testing Tasks in the DAG
+    assert set(dag.task_group.children.keys()) == {'section_1', 'task_start', 'task_end'}
+    assert set(dag.task_group.children['section_1'].children.keys()) == {
+        'section_1.task_2',
+        'section_1.task_3',
+        'section_1.task_1',
+    }
+
+    # Testing Tasks downstream
+    assert dag.task_dict['task_start'].downstream_task_ids == {'section_1.task_1'}
+    assert dag.task_dict['section_1.task_3'].downstream_task_ids == {'task_end'}
+
+
+def test_task_group_context_mix():
+    """Test cases to check nested TaskGroup context manager with the taskgroup decorator"""
+
+    from airflow.operators.python import task
+
+    def task_start():
+        """Dummy Task which is First Task of Dag"""
+        return '[Task_start]'
+
+    def task_end():
+        """Dummy Task which is Last Task of Dag"""
+        print('[ Task_End ]')
+
+    # Creating Tasks
+    @task
+    def task_1(value):
+        """Dummy Task1"""
+        return f'[ Task1 {value} ]'
+
+    @task
+    def task_2(value):
+        """Dummy Task2"""
+        return f'[ Task2 {value} ]'
+
+    @task
+    def task_3(value):
+        """Dummy Task3"""
+        print(f'[ Task3 {value} ]')
+
+    # Creating TaskGroups
+    @task_group_decorator
+    def section_2(value):
+        """TaskGroup for grouping related Tasks"""
+        return task_3(task_2(task_1(value)))
+
+    execution_date = pendulum.parse("20201109")
+    with DAG(dag_id="example_task_group_decorator_mix", start_date=execution_date, tags=["example"]) as dag:
+        t_start = PythonOperator(task_id='task_start', python_callable=task_start, dag=dag)
+
+        with TaskGroup("section_1", tooltip="section_1") as section_1:
+            sec_2 = section_2(t_start.output)
+            task_s1 = DummyOperator(task_id="task_1")
+            task_s2 = BashOperator(task_id="task_2", bash_command='echo 1')
+            task_s3 = DummyOperator(task_id="task_3")
+
+            sec_2.set_downstream(task_s1)
+            task_s1 >> [task_s2, task_s3]
+
+        t_end = PythonOperator(task_id='task_end', python_callable=task_end, dag=dag)
+        t_start >> section_1 >> t_end
+
+    node_ids = {
+        'id': None,
+        'children': [
+            {
+                'id': 'section_1',
+                'children': [
+                    {
+                        'id': 'section_1.section_2',
+                        'children': [
+                            {'id': 'section_1.section_2.task_1'},
+                            {'id': 'section_1.section_2.task_2'},
+                            {'id': 'section_1.section_2.task_3'},
+                            {'id': 'section_1.section_2.downstream_join_id'},
+                        ],
+                    },
+                    {'id': 'section_1.task_1'},
+                    {'id': 'section_1.task_2'},
+                    {'id': 'section_1.task_3'},
+                    {'id': 'section_1.upstream_join_id'},
+                    {'id': 'section_1.downstream_join_id'},
+                ],
+            },
+            {'id': 'task_end'},
+            {'id': 'task_start'},
+        ],
+    }
+
+    assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids
+
+
+def test_duplicate_task_group_id():
+    """Testing automatic suffix assignment for duplicate group_id"""
+
+    from airflow.operators.python import task
+
+    @task(task_id='start_task')
+    def task_start():
+        """Dummy Task which is First Task of Dag"""
+        print('[Task_start]')
+
+    @task(task_id='end_task')
+    def task_end():
+        """Dummy Task which is Last Task of Dag"""
+        print('[Task_End]')
+
+    # Creating Tasks
+    @task(task_id='task')
+    def task_1():
+        """Dummy Task1"""
+        print('[Task1]')
+
+    @task(task_id='task')
+    def task_2():
+        """Dummy Task2"""
+        print('[Task2]')
+
+    @task(task_id='task1')
+    def task_3():
+        """Dummy Task3"""
+        print('[Task3]')
+
+    @task_group_decorator(group_id='task_group1')
+    def task_group1():
+        task_start()
+        task_1()
+        task_2()
+
+    @task_group_decorator(group_id='task_group1')
+    def task_group2():
+        task_3()
+
+    @task_group_decorator(group_id='task_group1')
+    def task_group3():
+        task_end()
+
+    execution_date = pendulum.parse("20201109")
+    with DAG(dag_id="example_duplicate_task_group_id", start_date=execution_date, tags=["example"]) as dag:
+        task_group1()
+        task_group2()

Review comment:
I would have expected this to fail: this is three _separate_ task groups, all with the same name. Calling the same TG more than once is when I would expect the `__1` suffix to be added, but having _actual_ duplicate TG ids for different groups I think should be an error case.
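To make the two cases concrete, a sketch of the behaviour the comment argues for (this is the expected semantics, not necessarily what the PR currently implements, and the exception type is an assumption):

```python
# Case 1: reusing ONE decorated factory -- auto-suffixing makes sense here.
@task_group_decorator(group_id='my_group')
def my_group():
    task_1()

with DAG(dag_id='reuse_example', start_date=execution_date) as dag:
    my_group()  # created as 'my_group'
    my_group()  # created as 'my_group__1' -- suffix added on reuse

# Case 2: two DIFFERENT functions both claiming group_id='my_group' -- arguably
# a user error that should raise (e.g. airflow.exceptions.DuplicateTaskIdFound)
# rather than being silently renamed.
```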
########## File path: tests/utils/test_task_group.py ##########
+    @task_group_decorator(group_id='task_group1')
+    def task_group2():
+        task_3()
+
+    @task_group_decorator(group_id='task_group1')

Review comment:
```suggestion
    @task_group_decorator('task_group1')
```
So we also test passing the group id via a positional arg (instead of a kwarg).
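One subtlety the positional form would exercise (a quick, self-contained illustration; `Demo` is a stand-in class, not Airflow code): `task_group_sig` is built from `signature(TaskGroup.__init__)`, which still includes `self`, so a positional entry in `tg_args` binds one slot earlier than it might appear.

```python
from inspect import signature


class Demo:
    """Stand-in for TaskGroup; only the signature shape matters here."""

    def __init__(self, group_id, tooltip=''):
        self.group_id = group_id


sig = signature(Demo.__init__)           # <Signature (self, group_id, tooltip='')>
bound = sig.bind_partial('task_group1')
print(bound.arguments)                   # {'self': 'task_group1'}
# The positional value lands on `self`, not `group_id`, because the signature
# of the unbound __init__ still includes `self`.
```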
########## File path: airflow/decorators/task_group.py ##########
+    def wrapper(f: T):
+        if len(tg_args) == 0 and 'group_id' not in tg_kwargs.keys():
+            tg_kwargs['group_id'] = f.__name__
+        task_group_bound_args = task_group_sig.bind_partial(*tg_args, **tg_kwargs)
+        f_sig = signature(f)
+
+        @functools.wraps(f)
+        def factory(*args, **kwargs):
+            # Generate signature for decorated function and bind the arguments when called
+            # we do this to extract parameters so we can annotate them on the DAG object.
+            # In addition, this fails if we are missing any args/kwargs with TypeError as expected.
+            # Apply defaults to capture default values if set.
+            current_f_sig = f_sig.bind(*args, **kwargs)
+            current_f_sig.apply_defaults()

Review comment:
I'm a bit confused as to why we need this at all, and why we can't simply do `f(*args, **kwargs)` on L69. The comment says "so we can annotate them on the DAG object." but I don't see any reference to that in the code.
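For comparison, the simpler form the comment asks about would look roughly like this (a sketch only: the quoted hunk does not show how the PR actually constructs the `TaskGroup` or what the factory returns, so `tg_obj` and the `with` block are assumptions; the early `TypeError` from `f_sig.bind()` would instead surface when `f` itself is called):

```python
        @functools.wraps(f)
        def factory(*args, **kwargs):
            # Create the TaskGroup from the already-validated bound arguments,
            # then simply run the wrapped function inside it. Bad call
            # arguments surface as a TypeError when f() executes rather
            # than up front via f_sig.bind().
            with TaskGroup(*task_group_bound_args.args, **task_group_bound_args.kwargs) as tg_obj:
                f(*args, **kwargs)
            return tg_obj
```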
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]