This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-3-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 4d1f6004dedf899e385bd85300e796e2feecd677 Author: Jed Cunningham <[email protected]> AuthorDate: Wed May 4 13:02:09 2022 -0600 Fix literal cross product expansion (#23434) (cherry picked from commit 3fb8e0b0b4e8810bedece873949871a94dd7387a) --- airflow/models/mappedoperator.py | 5 ++++- tests/models/test_taskinstance.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index aa51a73454..b63e26ec9e 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -795,7 +795,10 @@ class MappedOperator(AbstractOperator): if not isinstance(value, MAPPABLE_LITERAL_TYPES): # None literal type encountered, so give up return None - total += len(value) + if total == 0: + total = len(value) + else: + total *= len(value) return total @cache diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index e53b52e11b..dac987e431 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2689,6 +2689,7 @@ class TestMappedTaskInstanceReceiveValue: ti.run() show_task = dag.get_task("show") + assert show_task.parse_time_mapped_ti_count is None mapped_tis, num = show_task.expand_mapped_task(dag_run.run_id, session=session) assert num == len(mapped_tis) == 4 @@ -2697,6 +2698,41 @@ class TestMappedTaskInstanceReceiveValue: ti.run() assert outputs == [(1, 1), (1, 2), (2, 1), (2, 2)] + def test_map_literal_cross_product(self, dag_maker, session): + """Test a mapped task with literal cross product args expand properly.""" + outputs = [] + + with dag_maker(dag_id="product_same_types", session=session) as dag: + + @dag.task + def show(a, b): + outputs.append((a, b)) + + show.expand(a=[2, 4, 8], b=[5, 10]) + + dag_run = dag_maker.create_dagrun() + + show_task = dag.get_task("show") + assert show_task.parse_time_mapped_ti_count == 6 + mapped_tis, num = show_task.expand_mapped_task(dag_run.run_id, session=session) + assert len(mapped_tis) == 0 # Expanded at parse! + assert num == 6 + + tis = ( + session.query(TaskInstance) + .filter( + TaskInstance.dag_id == dag.dag_id, + TaskInstance.task_id == 'show', + TaskInstance.run_id == dag_run.run_id, + ) + .order_by(TaskInstance.map_index) + .all() + ) + for ti in tis: + ti.refresh_from_task(show_task) + ti.run() + assert outputs == [(2, 5), (2, 10), (4, 5), (4, 10), (8, 5), (8, 10)] + def test_map_in_group(self, tmp_path: pathlib.Path, dag_maker, session): out = tmp_path.joinpath("out") out.touch()
