This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 70b41e46b4 Move MappedOperator tests to mirror code location (#23884)
70b41e46b4 is described below

commit 70b41e46b46e65c0446a40ab91624cb2291a5039
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Tue May 24 11:28:54 2022 +0100

    Move MappedOperator tests to mirror code location (#23884)
    
    At some point during the development of AIP-42 we moved the code for
    MappedOperator out of baseoperator.py to mappedoperator.py, but we
    didn't move the tests at the same time.
---
 tests/models/test_baseoperator.py   | 252 --------------------------------
 tests/models/test_mappedoperator.py | 278 ++++++++++++++++++++++++++++++++++++
 2 files changed, 278 insertions(+), 252 deletions(-)

diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index 5ba271a5a1..8c75c86ed4 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -22,7 +22,6 @@ from typing import Any, NamedTuple
 from unittest import mock
 
 import jinja2
-import pendulum
 import pytest
 
 from airflow.decorators import task as task_decorator
@@ -30,20 +29,13 @@ from airflow.exceptions import AirflowException
 from airflow.lineage.entities import File
 from airflow.models import DAG
 from airflow.models.baseoperator import BaseOperator, BaseOperatorMeta, chain, 
cross_downstream
-from airflow.models.mappedoperator import MappedOperator
-from airflow.models.taskinstance import TaskInstance
-from airflow.models.taskmap import TaskMap
-from airflow.models.xcom import XCOM_RETURN_KEY
-from airflow.models.xcom_arg import XComArg
 from airflow.utils.context import Context
 from airflow.utils.edgemodifier import Label
-from airflow.utils.state import TaskInstanceState
 from airflow.utils.task_group import TaskGroup
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.weight_rule import WeightRule
 from tests.models import DEFAULT_DATE
 from tests.test_utils.config import conf_vars
-from tests.test_utils.mapping import expand_mapped_task
 from tests.test_utils.mock_operators import DeprecatedOperator, MockOperator
 
 
@@ -752,250 +744,6 @@ def test_operator_retries(caplog, dag_maker, retries, expected):
     assert caplog.record_tuples == expected
 
 
-def test_task_mapping_with_dag():
-    with DAG("test-dag", start_date=DEFAULT_DATE) as dag:
-        task1 = BaseOperator(task_id="op1")
-        literal = ['a', 'b', 'c']
-        mapped = MockOperator.partial(task_id='task_2').expand(arg2=literal)
-        finish = MockOperator(task_id="finish")
-
-        task1 >> mapped >> finish
-
-    assert task1.downstream_list == [mapped]
-    assert mapped in dag.tasks
-    assert mapped.task_group == dag.task_group
-    # At parse time there should only be three tasks!
-    assert len(dag.tasks) == 3
-
-    assert finish.upstream_list == [mapped]
-    assert mapped.downstream_list == [finish]
-
-
-def test_task_mapping_without_dag_context():
-    with DAG("test-dag", start_date=DEFAULT_DATE) as dag:
-        task1 = BaseOperator(task_id="op1")
-    literal = ['a', 'b', 'c']
-    mapped = MockOperator.partial(task_id='task_2').expand(arg2=literal)
-
-    task1 >> mapped
-
-    assert isinstance(mapped, MappedOperator)
-    assert mapped in dag.tasks
-    assert task1.downstream_list == [mapped]
-    assert mapped in dag.tasks
-    # At parse time there should only be two tasks!
-    assert len(dag.tasks) == 2
-
-
-def test_task_mapping_default_args():
-    default_args = {'start_date': DEFAULT_DATE.now(), 'owner': 'test'}
-    with DAG("test-dag", start_date=DEFAULT_DATE, default_args=default_args):
-        task1 = BaseOperator(task_id="op1")
-        literal = ['a', 'b', 'c']
-        mapped = MockOperator.partial(task_id='task_2').expand(arg2=literal)
-
-        task1 >> mapped
-
-    assert mapped.partial_kwargs['owner'] == 'test'
-    assert mapped.start_date == pendulum.instance(default_args['start_date'])
-
-
-def test_map_unknown_arg_raises():
-    with pytest.raises(TypeError, match=r"argument 'file'"):
-        BaseOperator.partial(task_id='a').expand(file=[1, 2, {'a': 'b'}])
-
-
-def test_map_xcom_arg():
-    """Test that dependencies are correct when mapping with an XComArg"""
-    with DAG("test-dag", start_date=DEFAULT_DATE):
-        task1 = BaseOperator(task_id="op1")
-        mapped = 
MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
-        finish = MockOperator(task_id="finish")
-
-        mapped >> finish
-
-    assert task1.downstream_list == [mapped]
-
-
-def test_partial_on_instance() -> None:
-    """`.partial` on an instance should fail -- it's only designed to be 
called on classes"""
-    with pytest.raises(TypeError):
-        MockOperator(
-            task_id='a',
-        ).partial()
-
-
-def test_partial_on_class() -> None:
-    # Test that we accept args for superclasses too
-    op = MockOperator.partial(task_id='a', arg1="a", 
trigger_rule=TriggerRule.ONE_FAILED)
-    assert op.kwargs["arg1"] == "a"
-    assert op.kwargs["trigger_rule"] == TriggerRule.ONE_FAILED
-
-
-def test_partial_on_class_invalid_ctor_args() -> None:
-    """Test that when we pass invalid args to partial().
-
-    I.e. if an arg is not known on the class or any of its parent classes we 
error at parse time
-    """
-    with pytest.raises(TypeError, match=r"arguments 'foo', 'bar'"):
-        MockOperator.partial(task_id='a', foo='bar', bar=2)
-
-
[email protected](
-    ["num_existing_tis", "expected"],
-    (
-        pytest.param(0, [(0, None), (1, None), (2, None)], 
id='only-unmapped-ti-exists'),
-        pytest.param(
-            3,
-            [(0, 'success'), (1, 'success'), (2, 'success')],
-            id='all-tis-exist',
-        ),
-        pytest.param(
-            5,
-            [
-                (0, 'success'),
-                (1, 'success'),
-                (2, 'success'),
-                (3, TaskInstanceState.REMOVED),
-                (4, TaskInstanceState.REMOVED),
-            ],
-            id="tis-to-be-removed",
-        ),
-    ),
-)
-def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, 
expected):
-    literal = [1, 2, {'a': 'b'}]
-    with dag_maker(session=session):
-        task1 = BaseOperator(task_id="op1")
-        mapped = 
MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
-
-    dr = dag_maker.create_dagrun()
-
-    session.add(
-        TaskMap(
-            dag_id=dr.dag_id,
-            task_id=task1.task_id,
-            run_id=dr.run_id,
-            map_index=-1,
-            length=len(literal),
-            keys=None,
-        )
-    )
-
-    if num_existing_tis:
-        # Remove the map_index=-1 TI when we're creating other TIs
-        session.query(TaskInstance).filter(
-            TaskInstance.dag_id == mapped.dag_id,
-            TaskInstance.task_id == mapped.task_id,
-            TaskInstance.run_id == dr.run_id,
-        ).delete()
-
-    for index in range(num_existing_tis):
-        # Give the existing TIs a state to make sure we don't change them
-        ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, 
state=TaskInstanceState.SUCCESS)
-        session.add(ti)
-    session.flush()
-
-    mapped.expand_mapped_task(dr.run_id, session=session)
-
-    indices = (
-        session.query(TaskInstance.map_index, TaskInstance.state)
-        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, 
run_id=dr.run_id)
-        .order_by(TaskInstance.map_index)
-        .all()
-    )
-
-    assert indices == expected
-
-
-def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session):
-    with dag_maker(session=session):
-        task1 = BaseOperator(task_id="op1")
-        mapped = 
MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
-
-    dr = dag_maker.create_dagrun()
-
-    expand_mapped_task(mapped, dr.run_id, task1.task_id, length=0, 
session=session)
-
-    indices = (
-        session.query(TaskInstance.map_index, TaskInstance.state)
-        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, 
run_id=dr.run_id)
-        .order_by(TaskInstance.map_index)
-        .all()
-    )
-
-    assert indices == [(-1, TaskInstanceState.SKIPPED)]
-
-
-def test_mapped_task_applies_default_args_classic(dag_maker):
-    with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)}) 
as dag:
-        MockOperator(task_id="simple", arg1=None, arg2=0)
-        MockOperator.partial(task_id="mapped").expand(arg1=[1], arg2=[2, 3])
-
-    assert dag.get_task("simple").execution_timeout == timedelta(minutes=30)
-    assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30)
-
-
-def test_mapped_task_applies_default_args_taskflow(dag_maker):
-    with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)}) 
as dag:
-
-        @dag.task
-        def simple(arg):
-            pass
-
-        @dag.task
-        def mapped(arg):
-            pass
-
-        simple(arg=0)
-        mapped.expand(arg=[1, 2])
-
-    assert dag.get_task("simple").execution_timeout == timedelta(minutes=30)
-    assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30)
-
-
-def test_mapped_render_template_fields_validating_operator(dag_maker, session):
-    class MyOperator(MockOperator):
-        def __init__(self, value, arg1, **kwargs):
-            assert isinstance(value, str), "value should have been resolved 
before unmapping"
-            assert isinstance(arg1, str), "value should have been resolved 
before unmapping"
-            super().__init__(arg1=arg1, **kwargs)
-            self.value = value
-
-    with dag_maker(session=session):
-        task1 = BaseOperator(task_id="op1")
-        xcom_arg = XComArg(task1)
-        mapped = MyOperator.partial(task_id='a', arg2='{{ ti.task_id 
}}').expand(
-            value=xcom_arg, arg1=xcom_arg
-        )
-
-    dr = dag_maker.create_dagrun()
-    ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
-
-    ti.xcom_push(key=XCOM_RETURN_KEY, value=['{{ ds }}'], session=session)
-
-    session.add(
-        TaskMap(
-            dag_id=dr.dag_id,
-            task_id=task1.task_id,
-            run_id=dr.run_id,
-            map_index=-1,
-            length=1,
-            keys=None,
-        )
-    )
-    session.flush()
-
-    mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, 
session=session)
-    mapped_ti.map_index = 0
-    op = 
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
-    assert isinstance(op, MyOperator)
-
-    assert op.value == "{{ ds }}", "Should not be templated!"
-    assert op.arg1 == "{{ ds }}"
-    assert op.arg2 == "a"
-
-
 def test_default_retry_delay(dag_maker):
     with dag_maker(dag_id='test_default_retry_delay'):
         task1 = BaseOperator(task_id='test_no_explicit_retry_delay')
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
new file mode 100644
index 0000000000..c720fd96d9
--- /dev/null
+++ b/tests/models/test_mappedoperator.py
@@ -0,0 +1,278 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from datetime import timedelta
+
+import pendulum
+import pytest
+
+from airflow.models import DAG
+from airflow.models.baseoperator import BaseOperator
+from airflow.models.mappedoperator import MappedOperator
+from airflow.models.taskinstance import TaskInstance
+from airflow.models.taskmap import TaskMap
+from airflow.models.xcom import XCOM_RETURN_KEY
+from airflow.models.xcom_arg import XComArg
+from airflow.utils.state import TaskInstanceState
+from airflow.utils.trigger_rule import TriggerRule
+from tests.models import DEFAULT_DATE
+from tests.test_utils.mapping import expand_mapped_task
+from tests.test_utils.mock_operators import MockOperator
+
+
+def test_task_mapping_with_dag():
+    with DAG("test-dag", start_date=DEFAULT_DATE) as dag:
+        task1 = BaseOperator(task_id="op1")
+        literal = ['a', 'b', 'c']
+        mapped = MockOperator.partial(task_id='task_2').expand(arg2=literal)
+        finish = MockOperator(task_id="finish")
+
+        task1 >> mapped >> finish
+
+    assert task1.downstream_list == [mapped]
+    assert mapped in dag.tasks
+    assert mapped.task_group == dag.task_group
+    # At parse time there should only be three tasks!
+    assert len(dag.tasks) == 3
+
+    assert finish.upstream_list == [mapped]
+    assert mapped.downstream_list == [finish]
+
+
+def test_task_mapping_without_dag_context():
+    with DAG("test-dag", start_date=DEFAULT_DATE) as dag:
+        task1 = BaseOperator(task_id="op1")
+    literal = ['a', 'b', 'c']
+    mapped = MockOperator.partial(task_id='task_2').expand(arg2=literal)
+
+    task1 >> mapped
+
+    assert isinstance(mapped, MappedOperator)
+    assert mapped in dag.tasks
+    assert task1.downstream_list == [mapped]
+    assert mapped in dag.tasks
+    # At parse time there should only be two tasks!
+    assert len(dag.tasks) == 2
+
+
+def test_task_mapping_default_args():
+    default_args = {'start_date': DEFAULT_DATE.now(), 'owner': 'test'}
+    with DAG("test-dag", start_date=DEFAULT_DATE, default_args=default_args):
+        task1 = BaseOperator(task_id="op1")
+        literal = ['a', 'b', 'c']
+        mapped = MockOperator.partial(task_id='task_2').expand(arg2=literal)
+
+        task1 >> mapped
+
+    assert mapped.partial_kwargs['owner'] == 'test'
+    assert mapped.start_date == pendulum.instance(default_args['start_date'])
+
+
+def test_map_unknown_arg_raises():
+    with pytest.raises(TypeError, match=r"argument 'file'"):
+        BaseOperator.partial(task_id='a').expand(file=[1, 2, {'a': 'b'}])
+
+
+def test_map_xcom_arg():
+    """Test that dependencies are correct when mapping with an XComArg"""
+    with DAG("test-dag", start_date=DEFAULT_DATE):
+        task1 = BaseOperator(task_id="op1")
+        mapped = 
MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
+        finish = MockOperator(task_id="finish")
+
+        mapped >> finish
+
+    assert task1.downstream_list == [mapped]
+
+
+def test_partial_on_instance() -> None:
+    """`.partial` on an instance should fail -- it's only designed to be 
called on classes"""
+    with pytest.raises(TypeError):
+        MockOperator(
+            task_id='a',
+        ).partial()
+
+
+def test_partial_on_class() -> None:
+    # Test that we accept args for superclasses too
+    op = MockOperator.partial(task_id='a', arg1="a", 
trigger_rule=TriggerRule.ONE_FAILED)
+    assert op.kwargs["arg1"] == "a"
+    assert op.kwargs["trigger_rule"] == TriggerRule.ONE_FAILED
+
+
+def test_partial_on_class_invalid_ctor_args() -> None:
+    """Test that when we pass invalid args to partial().
+
+    I.e. if an arg is not known on the class or any of its parent classes we 
error at parse time
+    """
+    with pytest.raises(TypeError, match=r"arguments 'foo', 'bar'"):
+        MockOperator.partial(task_id='a', foo='bar', bar=2)
+
+
[email protected](
+    ["num_existing_tis", "expected"],
+    (
+        pytest.param(0, [(0, None), (1, None), (2, None)], 
id='only-unmapped-ti-exists'),
+        pytest.param(
+            3,
+            [(0, 'success'), (1, 'success'), (2, 'success')],
+            id='all-tis-exist',
+        ),
+        pytest.param(
+            5,
+            [
+                (0, 'success'),
+                (1, 'success'),
+                (2, 'success'),
+                (3, TaskInstanceState.REMOVED),
+                (4, TaskInstanceState.REMOVED),
+            ],
+            id="tis-to-be-removed",
+        ),
+    ),
+)
+def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, 
expected):
+    literal = [1, 2, {'a': 'b'}]
+    with dag_maker(session=session):
+        task1 = BaseOperator(task_id="op1")
+        mapped = 
MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
+
+    dr = dag_maker.create_dagrun()
+
+    session.add(
+        TaskMap(
+            dag_id=dr.dag_id,
+            task_id=task1.task_id,
+            run_id=dr.run_id,
+            map_index=-1,
+            length=len(literal),
+            keys=None,
+        )
+    )
+
+    if num_existing_tis:
+        # Remove the map_index=-1 TI when we're creating other TIs
+        session.query(TaskInstance).filter(
+            TaskInstance.dag_id == mapped.dag_id,
+            TaskInstance.task_id == mapped.task_id,
+            TaskInstance.run_id == dr.run_id,
+        ).delete()
+
+    for index in range(num_existing_tis):
+        # Give the existing TIs a state to make sure we don't change them
+        ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, 
state=TaskInstanceState.SUCCESS)
+        session.add(ti)
+    session.flush()
+
+    mapped.expand_mapped_task(dr.run_id, session=session)
+
+    indices = (
+        session.query(TaskInstance.map_index, TaskInstance.state)
+        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, 
run_id=dr.run_id)
+        .order_by(TaskInstance.map_index)
+        .all()
+    )
+
+    assert indices == expected
+
+
+def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session):
+    with dag_maker(session=session):
+        task1 = BaseOperator(task_id="op1")
+        mapped = 
MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
+
+    dr = dag_maker.create_dagrun()
+
+    expand_mapped_task(mapped, dr.run_id, task1.task_id, length=0, 
session=session)
+
+    indices = (
+        session.query(TaskInstance.map_index, TaskInstance.state)
+        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, 
run_id=dr.run_id)
+        .order_by(TaskInstance.map_index)
+        .all()
+    )
+
+    assert indices == [(-1, TaskInstanceState.SKIPPED)]
+
+
+def test_mapped_task_applies_default_args_classic(dag_maker):
+    with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)}) 
as dag:
+        MockOperator(task_id="simple", arg1=None, arg2=0)
+        MockOperator.partial(task_id="mapped").expand(arg1=[1], arg2=[2, 3])
+
+    assert dag.get_task("simple").execution_timeout == timedelta(minutes=30)
+    assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30)
+
+
+def test_mapped_task_applies_default_args_taskflow(dag_maker):
+    with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)}) 
as dag:
+
+        @dag.task
+        def simple(arg):
+            pass
+
+        @dag.task
+        def mapped(arg):
+            pass
+
+        simple(arg=0)
+        mapped.expand(arg=[1, 2])
+
+    assert dag.get_task("simple").execution_timeout == timedelta(minutes=30)
+    assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30)
+
+
+def test_mapped_render_template_fields_validating_operator(dag_maker, session):
+    class MyOperator(MockOperator):
+        def __init__(self, value, arg1, **kwargs):
+            assert isinstance(value, str), "value should have been resolved 
before unmapping"
+            assert isinstance(arg1, str), "value should have been resolved 
before unmapping"
+            super().__init__(arg1=arg1, **kwargs)
+            self.value = value
+
+    with dag_maker(session=session):
+        task1 = BaseOperator(task_id="op1")
+        xcom_arg = XComArg(task1)
+        mapped = MyOperator.partial(task_id='a', arg2='{{ ti.task_id 
}}').expand(
+            value=xcom_arg, arg1=xcom_arg
+        )
+
+    dr = dag_maker.create_dagrun()
+    ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
+
+    ti.xcom_push(key=XCOM_RETURN_KEY, value=['{{ ds }}'], session=session)
+
+    session.add(
+        TaskMap(
+            dag_id=dr.dag_id,
+            task_id=task1.task_id,
+            run_id=dr.run_id,
+            map_index=-1,
+            length=1,
+            keys=None,
+        )
+    )
+    session.flush()
+
+    mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, 
session=session)
+    mapped_ti.map_index = 0
+    op = 
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
+    assert isinstance(op, MyOperator)
+
+    assert op.value == "{{ ds }}", "Should not be templated!"
+    assert op.arg1 == "{{ ds }}"
+    assert op.arg2 == "a"

Reply via email to