ephraimbuddy commented on issue #25200:
URL: https://github.com/apache/airflow/issues/25200#issuecomment-1193272894
Here's a solution:
```diff
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index a883ff2404..1617944013 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -665,6 +665,14 @@ class MappedOperator(AbstractOperator):
                     total_length,
                 )
                 unmapped_ti.state = TaskInstanceState.SKIPPED
+                # Skip the downstream tasks as well since they'd eventually be skipped.
+                # NOTE(review): this skips every direct downstream task regardless of its
+                # trigger_rule (e.g. "all_done"); confirm that is acceptable before merging.
+                session.query(TaskInstance).filter(
+                    TaskInstance.dag_id == self.dag_id,
+                    TaskInstance.run_id == run_id,
+                    TaskInstance.task_id.in_(self.downstream_task_ids),
+                ).update({TaskInstance.state: TaskInstanceState.SKIPPED}, synchronize_session='fetch')
             else:
                 # Otherwise convert this into the first mapped index, and create
                 # TaskInstance for other indexes.
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 09ab87524b..8c9700ebde 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -206,5 +206,37 @@ def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session):
     assert indices == [(-1, TaskInstanceState.SKIPPED)]
 
 
+def test_downstream_tis_skipped_when_expand_mapped_ti_skipped_on_zero_length(dag_maker, session):
+    """
+    Test that when expand_mapped_task skips the task instance because of a zero expansion length,
+    the downstream task instances are skipped as well.
+    """
+    with dag_maker(session=session):
+        task1 = BaseOperator(task_id="op1")
+        mapped = MockOperator.partial(task_id='task_2').expand(arg2=[])
+        task2 = BaseOperator(task_id="op2")
+        task1 >> mapped >> task2
+
+    dr = dag_maker.create_dagrun()
+
+    expand_mapped_task(mapped, dr.run_id, task1.task_id, length=0, session=session)
+
+    indices = (
+        session.query(TaskInstance.map_index, TaskInstance.state)
+        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+        .order_by(TaskInstance.map_index)
+        .all()
+    )
+
+    assert indices == [(-1, TaskInstanceState.SKIPPED)]
+    # The task downstream of the zero-length mapped task must be skipped as well.
+    task2_ti = (
+        session.query(TaskInstance)
+        .filter_by(task_id=task2.task_id, dag_id=task2.dag_id, run_id=dr.run_id)
+        .one()
+    )
+    assert task2_ti.state == TaskInstanceState.SKIPPED
+
+
 def test_mapped_task_applies_default_args_classic(dag_maker):
     with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)}) as dag:
```
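However, I'm worried that trigger rules are not respected: the patch marks every direct downstream task instance as SKIPPED unconditionally, even one whose trigger rule (e.g. `all_done`) would otherwise let it run.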
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]