This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-3-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 4d1f6004dedf899e385bd85300e796e2feecd677
Author: Jed Cunningham <[email protected]>
AuthorDate: Wed May 4 13:02:09 2022 -0600

    Fix literal cross product expansion (#23434)
    
    (cherry picked from commit 3fb8e0b0b4e8810bedece873949871a94dd7387a)
---
 airflow/models/mappedoperator.py  |  5 ++++-
 tests/models/test_taskinstance.py | 36 ++++++++++++++++++++++++++++++++++++
 2 files changed, 40 insertions(+), 1 deletion(-)

diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index aa51a73454..b63e26ec9e 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -795,7 +795,10 @@ class MappedOperator(AbstractOperator):
             if not isinstance(value, MAPPABLE_LITERAL_TYPES):
                 # None literal type encountered, so give up
                 return None
-            total += len(value)
+            if total == 0:
+                total = len(value)
+            else:
+                total *= len(value)
         return total
 
     @cache
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index e53b52e11b..dac987e431 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -2689,6 +2689,7 @@ class TestMappedTaskInstanceReceiveValue:
         ti.run()
 
         show_task = dag.get_task("show")
+        assert show_task.parse_time_mapped_ti_count is None
         mapped_tis, num = show_task.expand_mapped_task(dag_run.run_id, session=session)
         assert num == len(mapped_tis) == 4
 
@@ -2697,6 +2698,41 @@ class TestMappedTaskInstanceReceiveValue:
             ti.run()
         assert outputs == [(1, 1), (1, 2), (2, 1), (2, 2)]
 
+    def test_map_literal_cross_product(self, dag_maker, session):
+        """Test a mapped task with literal cross product args expand properly."""
+        outputs = []
+
+        with dag_maker(dag_id="product_same_types", session=session) as dag:
+
+            @dag.task
+            def show(a, b):
+                outputs.append((a, b))
+
+            show.expand(a=[2, 4, 8], b=[5, 10])
+
+        dag_run = dag_maker.create_dagrun()
+
+        show_task = dag.get_task("show")
+        assert show_task.parse_time_mapped_ti_count == 6
+        mapped_tis, num = show_task.expand_mapped_task(dag_run.run_id, session=session)
+        assert len(mapped_tis) == 0  # Expanded at parse!
+        assert num == 6
+
+        tis = (
+            session.query(TaskInstance)
+            .filter(
+                TaskInstance.dag_id == dag.dag_id,
+                TaskInstance.task_id == 'show',
+                TaskInstance.run_id == dag_run.run_id,
+            )
+            .order_by(TaskInstance.map_index)
+            .all()
+        )
+        for ti in tis:
+            ti.refresh_from_task(show_task)
+            ti.run()
+        assert outputs == [(2, 5), (2, 10), (4, 5), (4, 10), (8, 5), (8, 10)]
+
     def test_map_in_group(self, tmp_path: pathlib.Path, dag_maker, session):
         out = tmp_path.joinpath("out")
         out.touch()

Reply via email to