ashb commented on a change in pull request #13479:
URL: https://github.com/apache/airflow/pull/13479#discussion_r577929436
##########
File path: tests/utils/test_task_group.py
##########
@@ -576,3 +541,372 @@ def test_task_without_dag():
assert op1.dag == op2.dag == op3.dag
assert dag.task_group.children.keys() == {"op1", "op2", "op3"}
assert dag.task_group.children.keys() == dag.task_dict.keys()
+
+
+# taskgroup decorator tests
+
+
+def test_build_task_group_deco_context_manager():
+ """
+ Tests Following :
+ 1. Nested TaskGroup creation using taskgroup decorator should create same
TaskGroup which can be
+ created using TaskGroup context manager.
+ 2. TaskGroup consisting Tasks created using task decorator.
+ 3. Node Ids of dags created with taskgroup decorator.
+ """
+
+ from airflow.operators.python import task
+
+ # Creating Tasks
+ @task
+ def task_start():
+ """Dummy Task which is First Task of Dag """
+ return '[Task_start]'
+
+ @task
+ def task_end():
+ """Dummy Task which is Last Task of Dag"""
+ print('[ Task_End ]')
+
+ @task
+ def task_1(value):
+ """ Dummy Task1"""
+ return f'[ Task1 {value} ]'
+
+ @task
+ def task_2(value):
+ """ Dummy Task2"""
+ print(f'[ Task2 {value} ]')
+
+ @task
+ def task_3(value):
+ """ Dummy Task3"""
+ return f'[ Task3 {value} ]'
+
+ @task
+ def task_4(value):
+ """ Dummy Task3"""
+ print(f'[ Task4 {value} ]')
+
+ # Creating TaskGroups
+ @taskgroup
+ def section_1(value):
+ """ TaskGroup for grouping related Tasks"""
+
+ @taskgroup()
+ def section_2(value2):
+ """ TaskGroup for grouping related Tasks"""
+ return task_4(task_3(value2))
+
+ op1 = task_2(task_1(value))
+ return section_2(op1)
+
+ execution_date = pendulum.parse("20201109")
+ with DAG(
+ dag_id="example_nested_task_group_decorator",
start_date=execution_date, tags=["example"]
+ ) as dag:
+ t_start = task_start()
+ sec_1 = section_1(t_start)
+ sec_1.set_downstream(task_end())
+
+ # Testing TaskGroup created using taskgroup decorator
+ assert set(dag.task_group.children.keys()) == {"task_start", "task_end",
"section_1"}
+ assert set(dag.task_group.children['section_1'].children.keys()) == {
+ 'section_1.task_1',
+ 'section_1.task_2',
+ 'section_1.section_2',
+ }
+
+ # Testing TaskGroup consisting Tasks created using task decorator
+ assert dag.task_dict['task_start'].downstream_task_ids ==
{'section_1.task_1'}
+ assert dag.task_dict['section_1.task_2'].downstream_task_ids ==
{'section_1.section_2.task_3'}
+ assert dag.task_dict['section_1.section_2.task_4'].downstream_task_ids ==
{'task_end'}
+
+ # Node IDs test
+ node_ids = {
+ 'id': None,
+ 'children': [
+ {
+ 'id': 'section_1',
+ 'children': [
+ {
+ 'id': 'section_1.section_2',
+ 'children': [
+ {'id': 'section_1.section_2.task_3'},
+ {'id': 'section_1.section_2.task_4'},
+ ],
+ },
+ {'id': 'section_1.task_1'},
+ {'id': 'section_1.task_2'},
+ {'id': 'section_1.downstream_join_id'},
+ ],
+ },
+ {'id': 'task_end'},
+ {'id': 'task_start'},
+ ],
+ }
+
+ assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids
+
+
+def test_build_task_group_with_operators():
+ """ Tests DAG with Tasks created with *Operators and TaskGroup created
with taskgroup decorator """
+
+ from airflow.operators.python import task
+
+ def task_start():
+ """Dummy Task which is First Task of Dag """
+ return '[Task_start]'
+
+ def task_end():
+ """Dummy Task which is Last Task of Dag"""
+ print('[ Task_End ]')
+
+ # Creating Tasks
+ @task
+ def task_1(value):
+ """ Dummy Task1"""
+ return f'[ Task1 {value} ]'
+
+ @task
+ def task_2(value):
+ """ Dummy Task2"""
+ return f'[ Task2 {value} ]'
+
+ @task
+ def task_3(value):
+ """ Dummy Task3"""
+ print(f'[ Task3 {value} ]')
+
+ # Creating TaskGroups
+ @taskgroup(group_id='section_1')
+ def section_a(value):
+ """ TaskGroup for grouping related Tasks"""
+ return task_3(task_2(task_1(value)))
+
+ execution_date = pendulum.parse("20201109")
+ with DAG(dag_id="example_task_group_decorator_mix",
start_date=execution_date, tags=["example"]) as dag:
+ t_start = PythonOperator(task_id='task_start',
python_callable=task_start, dag=dag)
+ sec_1 = section_a(t_start.output)
+ t_end = PythonOperator(task_id='task_end', python_callable=task_end,
dag=dag)
+ sec_1.set_downstream(t_end)
+
+ # Testing Tasks ing DAG
+ assert set(dag.task_group.children.keys()) == {'section_1', 'task_start',
'task_end'}
+ assert set(dag.task_group.children['section_1'].children.keys()) == {
+ 'section_1.task_2',
+ 'section_1.task_3',
+ 'section_1.task_1',
+ }
+
+ # Testing Tasks downstream
+ assert dag.task_dict['task_start'].downstream_task_ids ==
{'section_1.task_1'}
+ assert dag.task_dict['section_1.task_3'].downstream_task_ids ==
{'task_end'}
+
+
+def test_task_group_context_mix():
+ """ Test cases to check nested TaskGroup context manager with taskgroup
decorator"""
+
+ from airflow.operators.python import task
+
+ def task_start():
+ """Dummy Task which is First Task of Dag """
+ return '[Task_start]'
+
+ def task_end():
+ """Dummy Task which is Last Task of Dag"""
+ print('[ Task_End ]')
+
+ # Creating Tasks
+ @task
+ def task_1(value):
+ """ Dummy Task1"""
+ return f'[ Task1 {value} ]'
+
+ @task
+ def task_2(value):
+ """ Dummy Task2"""
+ return f'[ Task2 {value} ]'
+
+ @task
+ def task_3(value):
+ """ Dummy Task3"""
+ print(f'[ Task3 {value} ]')
+
+ # Creating TaskGroups
+ @taskgroup
+ def section_2(value):
+ """ TaskGroup for grouping related Tasks"""
+ return task_3(task_2(task_1(value)))
+
+ execution_date = pendulum.parse("20201109")
+ with DAG(dag_id="example_task_group_decorator_mix",
start_date=execution_date, tags=["example"]) as dag:
+ t_start = PythonOperator(task_id='task_start',
python_callable=task_start, dag=dag)
+
+ with TaskGroup("section_1", tooltip="section_1") as section_1:
+ sec_2 = section_2(t_start.output)
+ task_s1 = DummyOperator(task_id="task_1")
+ task_s2 = BashOperator(task_id="task_2", bash_command='echo 1')
+ task_s3 = DummyOperator(task_id="task_3")
+
+ sec_2.set_downstream(task_s1)
+ task_s1 >> [task_s2, task_s3]
+
+ t_end = PythonOperator(task_id='task_end', python_callable=task_end,
dag=dag)
+ t_start >> section_1 >> t_end
+
+ node_ids = {
+ 'id': None,
+ 'children': [
+ {
+ 'id': 'section_1',
+ 'children': [
+ {
+ 'id': 'section_1.section_2',
+ 'children': [
+ {'id': 'section_1.section_2.task_1'},
+ {'id': 'section_1.section_2.task_2'},
+ {'id': 'section_1.section_2.task_3'},
+ {'id': 'section_1.section_2.downstream_join_id'},
+ ],
+ },
+ {'id': 'section_1.task_1'},
+ {'id': 'section_1.task_2'},
+ {'id': 'section_1.task_3'},
+ {'id': 'section_1.upstream_join_id'},
+ {'id': 'section_1.downstream_join_id'},
+ ],
+ },
+ {'id': 'task_end'},
+ {'id': 'task_start'},
+ ],
+ }
+
+ assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids
+
+
+def test_duplicate_task_group_id():
+ """ Testing automatic suffix assignment for duplicate group_id"""
+
+ from airflow.operators.python import task
+
+ @task(task_id='start_task')
+ def task_start():
+ """Dummy Task which is First Task of Dag """
+ print('[Task_start]')
+
+ @task(task_id='end_task')
+ def task_end():
+ """Dummy Task which is Last Task of Dag"""
+ print('[Task_End]')
+
+ # Creating Tasks
+ @task(task_id='task')
+ def task_1():
+ """ Dummy Task1"""
+ print('[Task1]')
+
+ @task(task_id='task')
+ def task_2():
+ """ Dummy Task2"""
+ print('[Task2]')
+
+ @task(task_id='task1')
+ def task_3():
+ """ Dummy Task3"""
+ print('[Task3]')
+
+ @taskgroup(group_id='task_group1')
+ def task_group1():
+ task_start()
+ task_1()
+ task_2()
+
+ @taskgroup(group_id='task_group1')
+ def task_group2():
+ task_3()
+
+ @taskgroup(group_id='task_group1')
+ def task_group3():
+ task_end()
+
+ execution_date = pendulum.parse("20201109")
+ with DAG(dag_id="example_duplicate_task_group_id",
start_date=execution_date, tags=["example"]) as dag:
+ task_group1()
+ task_group2()
+ task_group3()
Review comment:
We can't check which task group decorator created the previous function,
but we could check whether the _current_ TG was created by the current
decorated function, and fail if it was not.
We could keep a set of TG objects in the wrapper function: if the TG already
exists on the DAG and is not in that set, then it was created by some other
decorator, so we raise an error.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]