This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 70b41e46b4 Move MappedOperator tests to mirror code location (#23884)
70b41e46b4 is described below
commit 70b41e46b46e65c0446a40ab91624cb2291a5039
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Tue May 24 11:28:54 2022 +0100
Move MappedOperator tests to mirror code location (#23884)
At some point during the development of AIP-42 we moved the code for
MappedOperator out of baseoperator.py to mappedoperator.py, but we
didn't move the tests at the same time.
---
tests/models/test_baseoperator.py | 252 --------------------------------
tests/models/test_mappedoperator.py | 278 ++++++++++++++++++++++++++++++++++++
2 files changed, 278 insertions(+), 252 deletions(-)
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index 5ba271a5a1..8c75c86ed4 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -22,7 +22,6 @@ from typing import Any, NamedTuple
from unittest import mock
import jinja2
-import pendulum
import pytest
from airflow.decorators import task as task_decorator
@@ -30,20 +29,13 @@ from airflow.exceptions import AirflowException
from airflow.lineage.entities import File
from airflow.models import DAG
from airflow.models.baseoperator import BaseOperator, BaseOperatorMeta, chain, cross_downstream
-from airflow.models.mappedoperator import MappedOperator
-from airflow.models.taskinstance import TaskInstance
-from airflow.models.taskmap import TaskMap
-from airflow.models.xcom import XCOM_RETURN_KEY
-from airflow.models.xcom_arg import XComArg
from airflow.utils.context import Context
from airflow.utils.edgemodifier import Label
-from airflow.utils.state import TaskInstanceState
from airflow.utils.task_group import TaskGroup
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule
from tests.models import DEFAULT_DATE
from tests.test_utils.config import conf_vars
-from tests.test_utils.mapping import expand_mapped_task
from tests.test_utils.mock_operators import DeprecatedOperator, MockOperator
@@ -752,250 +744,6 @@ def test_operator_retries(caplog, dag_maker, retries, expected):
assert caplog.record_tuples == expected
-def test_task_mapping_with_dag():
- with DAG("test-dag", start_date=DEFAULT_DATE) as dag:
- task1 = BaseOperator(task_id="op1")
- literal = ['a', 'b', 'c']
- mapped = MockOperator.partial(task_id='task_2').expand(arg2=literal)
- finish = MockOperator(task_id="finish")
-
- task1 >> mapped >> finish
-
- assert task1.downstream_list == [mapped]
- assert mapped in dag.tasks
- assert mapped.task_group == dag.task_group
- # At parse time there should only be three tasks!
- assert len(dag.tasks) == 3
-
- assert finish.upstream_list == [mapped]
- assert mapped.downstream_list == [finish]
-
-
-def test_task_mapping_without_dag_context():
- with DAG("test-dag", start_date=DEFAULT_DATE) as dag:
- task1 = BaseOperator(task_id="op1")
- literal = ['a', 'b', 'c']
- mapped = MockOperator.partial(task_id='task_2').expand(arg2=literal)
-
- task1 >> mapped
-
- assert isinstance(mapped, MappedOperator)
- assert mapped in dag.tasks
- assert task1.downstream_list == [mapped]
- assert mapped in dag.tasks
- # At parse time there should only be two tasks!
- assert len(dag.tasks) == 2
-
-
-def test_task_mapping_default_args():
- default_args = {'start_date': DEFAULT_DATE.now(), 'owner': 'test'}
- with DAG("test-dag", start_date=DEFAULT_DATE, default_args=default_args):
- task1 = BaseOperator(task_id="op1")
- literal = ['a', 'b', 'c']
- mapped = MockOperator.partial(task_id='task_2').expand(arg2=literal)
-
- task1 >> mapped
-
- assert mapped.partial_kwargs['owner'] == 'test'
- assert mapped.start_date == pendulum.instance(default_args['start_date'])
-
-
-def test_map_unknown_arg_raises():
- with pytest.raises(TypeError, match=r"argument 'file'"):
- BaseOperator.partial(task_id='a').expand(file=[1, 2, {'a': 'b'}])
-
-
-def test_map_xcom_arg():
- """Test that dependencies are correct when mapping with an XComArg"""
- with DAG("test-dag", start_date=DEFAULT_DATE):
- task1 = BaseOperator(task_id="op1")
- mapped =
MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
- finish = MockOperator(task_id="finish")
-
- mapped >> finish
-
- assert task1.downstream_list == [mapped]
-
-
-def test_partial_on_instance() -> None:
- """`.partial` on an instance should fail -- it's only designed to be
called on classes"""
- with pytest.raises(TypeError):
- MockOperator(
- task_id='a',
- ).partial()
-
-
-def test_partial_on_class() -> None:
- # Test that we accept args for superclasses too
- op = MockOperator.partial(task_id='a', arg1="a",
trigger_rule=TriggerRule.ONE_FAILED)
- assert op.kwargs["arg1"] == "a"
- assert op.kwargs["trigger_rule"] == TriggerRule.ONE_FAILED
-
-
-def test_partial_on_class_invalid_ctor_args() -> None:
- """Test that when we pass invalid args to partial().
-
- I.e. if an arg is not known on the class or any of its parent classes we
error at parse time
- """
- with pytest.raises(TypeError, match=r"arguments 'foo', 'bar'"):
- MockOperator.partial(task_id='a', foo='bar', bar=2)
-
-
[email protected](
- ["num_existing_tis", "expected"],
- (
- pytest.param(0, [(0, None), (1, None), (2, None)],
id='only-unmapped-ti-exists'),
- pytest.param(
- 3,
- [(0, 'success'), (1, 'success'), (2, 'success')],
- id='all-tis-exist',
- ),
- pytest.param(
- 5,
- [
- (0, 'success'),
- (1, 'success'),
- (2, 'success'),
- (3, TaskInstanceState.REMOVED),
- (4, TaskInstanceState.REMOVED),
- ],
- id="tis-to-be-removed",
- ),
- ),
-)
-def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis,
expected):
- literal = [1, 2, {'a': 'b'}]
- with dag_maker(session=session):
- task1 = BaseOperator(task_id="op1")
- mapped =
MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
-
- dr = dag_maker.create_dagrun()
-
- session.add(
- TaskMap(
- dag_id=dr.dag_id,
- task_id=task1.task_id,
- run_id=dr.run_id,
- map_index=-1,
- length=len(literal),
- keys=None,
- )
- )
-
- if num_existing_tis:
- # Remove the map_index=-1 TI when we're creating other TIs
- session.query(TaskInstance).filter(
- TaskInstance.dag_id == mapped.dag_id,
- TaskInstance.task_id == mapped.task_id,
- TaskInstance.run_id == dr.run_id,
- ).delete()
-
- for index in range(num_existing_tis):
- # Give the existing TIs a state to make sure we don't change them
- ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index,
state=TaskInstanceState.SUCCESS)
- session.add(ti)
- session.flush()
-
- mapped.expand_mapped_task(dr.run_id, session=session)
-
- indices = (
- session.query(TaskInstance.map_index, TaskInstance.state)
- .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id,
run_id=dr.run_id)
- .order_by(TaskInstance.map_index)
- .all()
- )
-
- assert indices == expected
-
-
-def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session):
- with dag_maker(session=session):
- task1 = BaseOperator(task_id="op1")
- mapped =
MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
-
- dr = dag_maker.create_dagrun()
-
- expand_mapped_task(mapped, dr.run_id, task1.task_id, length=0,
session=session)
-
- indices = (
- session.query(TaskInstance.map_index, TaskInstance.state)
- .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id,
run_id=dr.run_id)
- .order_by(TaskInstance.map_index)
- .all()
- )
-
- assert indices == [(-1, TaskInstanceState.SKIPPED)]
-
-
-def test_mapped_task_applies_default_args_classic(dag_maker):
- with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)})
as dag:
- MockOperator(task_id="simple", arg1=None, arg2=0)
- MockOperator.partial(task_id="mapped").expand(arg1=[1], arg2=[2, 3])
-
- assert dag.get_task("simple").execution_timeout == timedelta(minutes=30)
- assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30)
-
-
-def test_mapped_task_applies_default_args_taskflow(dag_maker):
- with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)})
as dag:
-
- @dag.task
- def simple(arg):
- pass
-
- @dag.task
- def mapped(arg):
- pass
-
- simple(arg=0)
- mapped.expand(arg=[1, 2])
-
- assert dag.get_task("simple").execution_timeout == timedelta(minutes=30)
- assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30)
-
-
-def test_mapped_render_template_fields_validating_operator(dag_maker, session):
- class MyOperator(MockOperator):
- def __init__(self, value, arg1, **kwargs):
- assert isinstance(value, str), "value should have been resolved
before unmapping"
- assert isinstance(arg1, str), "value should have been resolved
before unmapping"
- super().__init__(arg1=arg1, **kwargs)
- self.value = value
-
- with dag_maker(session=session):
- task1 = BaseOperator(task_id="op1")
- xcom_arg = XComArg(task1)
- mapped = MyOperator.partial(task_id='a', arg2='{{ ti.task_id
}}').expand(
- value=xcom_arg, arg1=xcom_arg
- )
-
- dr = dag_maker.create_dagrun()
- ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
-
- ti.xcom_push(key=XCOM_RETURN_KEY, value=['{{ ds }}'], session=session)
-
- session.add(
- TaskMap(
- dag_id=dr.dag_id,
- task_id=task1.task_id,
- run_id=dr.run_id,
- map_index=-1,
- length=1,
- keys=None,
- )
- )
- session.flush()
-
- mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id,
session=session)
- mapped_ti.map_index = 0
- op =
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
- assert isinstance(op, MyOperator)
-
- assert op.value == "{{ ds }}", "Should not be templated!"
- assert op.arg1 == "{{ ds }}"
- assert op.arg2 == "a"
-
-
def test_default_retry_delay(dag_maker):
with dag_maker(dag_id='test_default_retry_delay'):
task1 = BaseOperator(task_id='test_no_explicit_retry_delay')
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
new file mode 100644
index 0000000000..c720fd96d9
--- /dev/null
+++ b/tests/models/test_mappedoperator.py
@@ -0,0 +1,278 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from datetime import timedelta
+
+import pendulum
+import pytest
+
+from airflow.models import DAG
+from airflow.models.baseoperator import BaseOperator
+from airflow.models.mappedoperator import MappedOperator
+from airflow.models.taskinstance import TaskInstance
+from airflow.models.taskmap import TaskMap
+from airflow.models.xcom import XCOM_RETURN_KEY
+from airflow.models.xcom_arg import XComArg
+from airflow.utils.state import TaskInstanceState
+from airflow.utils.trigger_rule import TriggerRule
+from tests.models import DEFAULT_DATE
+from tests.test_utils.mapping import expand_mapped_task
+from tests.test_utils.mock_operators import MockOperator
+
+
+def test_task_mapping_with_dag():
+ with DAG("test-dag", start_date=DEFAULT_DATE) as dag:
+ task1 = BaseOperator(task_id="op1")
+ literal = ['a', 'b', 'c']
+ mapped = MockOperator.partial(task_id='task_2').expand(arg2=literal)
+ finish = MockOperator(task_id="finish")
+
+ task1 >> mapped >> finish
+
+ assert task1.downstream_list == [mapped]
+ assert mapped in dag.tasks
+ assert mapped.task_group == dag.task_group
+ # At parse time there should only be three tasks!
+ assert len(dag.tasks) == 3
+
+ assert finish.upstream_list == [mapped]
+ assert mapped.downstream_list == [finish]
+
+
+def test_task_mapping_without_dag_context():
+ with DAG("test-dag", start_date=DEFAULT_DATE) as dag:
+ task1 = BaseOperator(task_id="op1")
+ literal = ['a', 'b', 'c']
+ mapped = MockOperator.partial(task_id='task_2').expand(arg2=literal)
+
+ task1 >> mapped
+
+ assert isinstance(mapped, MappedOperator)
+ assert mapped in dag.tasks
+ assert task1.downstream_list == [mapped]
+ assert mapped in dag.tasks
+ # At parse time there should only be two tasks!
+ assert len(dag.tasks) == 2
+
+
+def test_task_mapping_default_args():
+ default_args = {'start_date': DEFAULT_DATE.now(), 'owner': 'test'}
+ with DAG("test-dag", start_date=DEFAULT_DATE, default_args=default_args):
+ task1 = BaseOperator(task_id="op1")
+ literal = ['a', 'b', 'c']
+ mapped = MockOperator.partial(task_id='task_2').expand(arg2=literal)
+
+ task1 >> mapped
+
+ assert mapped.partial_kwargs['owner'] == 'test'
+ assert mapped.start_date == pendulum.instance(default_args['start_date'])
+
+
+def test_map_unknown_arg_raises():
+ with pytest.raises(TypeError, match=r"argument 'file'"):
+ BaseOperator.partial(task_id='a').expand(file=[1, 2, {'a': 'b'}])
+
+
+def test_map_xcom_arg():
+ """Test that dependencies are correct when mapping with an XComArg"""
+ with DAG("test-dag", start_date=DEFAULT_DATE):
+ task1 = BaseOperator(task_id="op1")
+ mapped =
MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
+ finish = MockOperator(task_id="finish")
+
+ mapped >> finish
+
+ assert task1.downstream_list == [mapped]
+
+
+def test_partial_on_instance() -> None:
+ """`.partial` on an instance should fail -- it's only designed to be
called on classes"""
+ with pytest.raises(TypeError):
+ MockOperator(
+ task_id='a',
+ ).partial()
+
+
+def test_partial_on_class() -> None:
+ # Test that we accept args for superclasses too
+ op = MockOperator.partial(task_id='a', arg1="a",
trigger_rule=TriggerRule.ONE_FAILED)
+ assert op.kwargs["arg1"] == "a"
+ assert op.kwargs["trigger_rule"] == TriggerRule.ONE_FAILED
+
+
+def test_partial_on_class_invalid_ctor_args() -> None:
+ """Test that when we pass invalid args to partial().
+
+ I.e. if an arg is not known on the class or any of its parent classes we
error at parse time
+ """
+ with pytest.raises(TypeError, match=r"arguments 'foo', 'bar'"):
+ MockOperator.partial(task_id='a', foo='bar', bar=2)
+
+
[email protected](
+ ["num_existing_tis", "expected"],
+ (
+ pytest.param(0, [(0, None), (1, None), (2, None)],
id='only-unmapped-ti-exists'),
+ pytest.param(
+ 3,
+ [(0, 'success'), (1, 'success'), (2, 'success')],
+ id='all-tis-exist',
+ ),
+ pytest.param(
+ 5,
+ [
+ (0, 'success'),
+ (1, 'success'),
+ (2, 'success'),
+ (3, TaskInstanceState.REMOVED),
+ (4, TaskInstanceState.REMOVED),
+ ],
+ id="tis-to-be-removed",
+ ),
+ ),
+)
+def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis,
expected):
+ literal = [1, 2, {'a': 'b'}]
+ with dag_maker(session=session):
+ task1 = BaseOperator(task_id="op1")
+ mapped =
MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
+
+ dr = dag_maker.create_dagrun()
+
+ session.add(
+ TaskMap(
+ dag_id=dr.dag_id,
+ task_id=task1.task_id,
+ run_id=dr.run_id,
+ map_index=-1,
+ length=len(literal),
+ keys=None,
+ )
+ )
+
+ if num_existing_tis:
+ # Remove the map_index=-1 TI when we're creating other TIs
+ session.query(TaskInstance).filter(
+ TaskInstance.dag_id == mapped.dag_id,
+ TaskInstance.task_id == mapped.task_id,
+ TaskInstance.run_id == dr.run_id,
+ ).delete()
+
+ for index in range(num_existing_tis):
+ # Give the existing TIs a state to make sure we don't change them
+ ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index,
state=TaskInstanceState.SUCCESS)
+ session.add(ti)
+ session.flush()
+
+ mapped.expand_mapped_task(dr.run_id, session=session)
+
+ indices = (
+ session.query(TaskInstance.map_index, TaskInstance.state)
+ .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id,
run_id=dr.run_id)
+ .order_by(TaskInstance.map_index)
+ .all()
+ )
+
+ assert indices == expected
+
+
+def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session):
+ with dag_maker(session=session):
+ task1 = BaseOperator(task_id="op1")
+ mapped =
MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
+
+ dr = dag_maker.create_dagrun()
+
+ expand_mapped_task(mapped, dr.run_id, task1.task_id, length=0,
session=session)
+
+ indices = (
+ session.query(TaskInstance.map_index, TaskInstance.state)
+ .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id,
run_id=dr.run_id)
+ .order_by(TaskInstance.map_index)
+ .all()
+ )
+
+ assert indices == [(-1, TaskInstanceState.SKIPPED)]
+
+
+def test_mapped_task_applies_default_args_classic(dag_maker):
+ with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)})
as dag:
+ MockOperator(task_id="simple", arg1=None, arg2=0)
+ MockOperator.partial(task_id="mapped").expand(arg1=[1], arg2=[2, 3])
+
+ assert dag.get_task("simple").execution_timeout == timedelta(minutes=30)
+ assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30)
+
+
+def test_mapped_task_applies_default_args_taskflow(dag_maker):
+ with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)})
as dag:
+
+ @dag.task
+ def simple(arg):
+ pass
+
+ @dag.task
+ def mapped(arg):
+ pass
+
+ simple(arg=0)
+ mapped.expand(arg=[1, 2])
+
+ assert dag.get_task("simple").execution_timeout == timedelta(minutes=30)
+ assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30)
+
+
+def test_mapped_render_template_fields_validating_operator(dag_maker, session):
+ class MyOperator(MockOperator):
+ def __init__(self, value, arg1, **kwargs):
+ assert isinstance(value, str), "value should have been resolved
before unmapping"
+ assert isinstance(arg1, str), "value should have been resolved
before unmapping"
+ super().__init__(arg1=arg1, **kwargs)
+ self.value = value
+
+ with dag_maker(session=session):
+ task1 = BaseOperator(task_id="op1")
+ xcom_arg = XComArg(task1)
+ mapped = MyOperator.partial(task_id='a', arg2='{{ ti.task_id
}}').expand(
+ value=xcom_arg, arg1=xcom_arg
+ )
+
+ dr = dag_maker.create_dagrun()
+ ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
+
+ ti.xcom_push(key=XCOM_RETURN_KEY, value=['{{ ds }}'], session=session)
+
+ session.add(
+ TaskMap(
+ dag_id=dr.dag_id,
+ task_id=task1.task_id,
+ run_id=dr.run_id,
+ map_index=-1,
+ length=1,
+ keys=None,
+ )
+ )
+ session.flush()
+
+ mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id,
session=session)
+ mapped_ti.map_index = 0
+ op =
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
+ assert isinstance(op, MyOperator)
+
+ assert op.value == "{{ ds }}", "Should not be templated!"
+ assert op.arg1 == "{{ ds }}"
+ assert op.arg2 == "a"