This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3e288abd0b Remove is_mapped attribute (#27881)
3e288abd0b is described below
commit 3e288abd0bc3e5788dcd7f6d9f6bef26ec4c7281
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Fri Nov 25 09:21:01 2022 +0800
Remove is_mapped attribute (#27881)
---
.../endpoints/task_instance_endpoint.py | 3 +-
airflow/api_connexion/schemas/task_schema.py | 17 ++--
airflow/cli/commands/task_command.py | 3 +-
airflow/models/baseoperator.py | 2 -
airflow/models/mappedoperator.py | 2 -
airflow/models/operator.py | 23 ++++-
airflow/models/taskinstance.py | 5 +-
airflow/models/xcom_arg.py | 3 +-
airflow/ti_deps/deps/ready_to_reschedule.py | 4 +-
airflow/ti_deps/deps/trigger_rule_dep.py | 3 +-
airflow/www/views.py | 7 +-
tests/decorators/test_python.py | 7 +-
tests/models/test_taskinstance.py | 99 +++++++++++++++++++++-
13 files changed, 151 insertions(+), 27 deletions(-)
diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py
b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index 4e9d6cb9a1..9d5d54ba58 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -45,6 +45,7 @@ from airflow.api_connexion.schemas.task_instance_schema
import (
from airflow.api_connexion.types import APIResponse
from airflow.models import SlaMiss
from airflow.models.dagrun import DagRun as DR
+from airflow.models.operator import needs_expansion
from airflow.models.taskinstance import TaskInstance as TI,
clear_task_instances
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
@@ -202,7 +203,7 @@ def get_mapped_task_instances(
if not task:
error_message = f"Task id {task_id} not found"
raise NotFound(error_message)
- if not task.is_mapped:
+ if not needs_expansion(task):
error_message = f"Task id {task_id} is not mapped"
raise NotFound(error_message)
diff --git a/airflow/api_connexion/schemas/task_schema.py
b/airflow/api_connexion/schemas/task_schema.py
index 0fcb9ff18f..5715ca2ea0 100644
--- a/airflow/api_connexion/schemas/task_schema.py
+++ b/airflow/api_connexion/schemas/task_schema.py
@@ -27,6 +27,7 @@ from airflow.api_connexion.schemas.common_schema import (
WeightRuleField,
)
from airflow.api_connexion.schemas.dag_schema import DAGSchema
+from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
@@ -59,22 +60,28 @@ class TaskSchema(Schema):
template_fields = fields.List(fields.String(), dump_only=True)
sub_dag = fields.Nested(DAGSchema, dump_only=True)
downstream_task_ids = fields.List(fields.String(), dump_only=True)
- params = fields.Method("get_params", dump_only=True)
- is_mapped = fields.Boolean(dump_only=True)
+ params = fields.Method("_get_params", dump_only=True)
+ is_mapped = fields.Method("_get_is_mapped", dump_only=True)
- def _get_class_reference(self, obj):
+ @staticmethod
+ def _get_class_reference(obj):
result = ClassReferenceSchema().dump(obj)
return result.data if hasattr(result, "data") else result
- def _get_operator_name(self, obj):
+ @staticmethod
+ def _get_operator_name(obj):
return obj.operator_name
@staticmethod
- def get_params(obj):
+ def _get_params(obj):
"""Get the Params defined in a Task."""
params = obj.params
return {k: v.dump() for k, v in params.items()}
+ @staticmethod
+ def _get_is_mapped(obj):
+ return isinstance(obj, MappedOperator)
+
class TaskCollection(NamedTuple):
"""List of Tasks with metadata."""
diff --git a/airflow/cli/commands/task_command.py
b/airflow/cli/commands/task_command.py
index a217d2c78d..078565dc38 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -42,6 +42,7 @@ from airflow.models import DagPickle, TaskInstance
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
+from airflow.models.operator import needs_expansion
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS
from airflow.typing_compat import Literal
@@ -150,7 +151,7 @@ def _get_ti(
"""Get the task instance through DagRun.run_id, if that fails, get the TI
the old way."""
if not exec_date_or_run_id and not create_if_necessary:
raise ValueError("Must provide `exec_date_or_run_id` if not
`create_if_necessary`.")
- if task.is_mapped:
+ if needs_expansion(task):
if map_index < 0:
raise RuntimeError("No map_index passed to mapped task")
elif map_index >= 0:
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 9264c92322..a1b49abf66 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -1475,8 +1475,6 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
"""Required by DAGNode."""
return DagAttributeTypes.OP, self.task_id
- is_mapped: ClassVar[bool] = False
-
@property
def inherits_from_empty_operator(self):
"""Used to determine if an Operator is inherited from EmptyOperator"""
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 29dc977780..9bdfd932de 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -286,7 +286,6 @@ class MappedOperator(AbstractOperator):
This should be a name to call ``getattr()`` on.
"""
- is_mapped: ClassVar[bool] = True
subdag: None = None # Since we don't support SubDagOperator, this is
always None.
HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] =
AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
@@ -329,7 +328,6 @@ class MappedOperator(AbstractOperator):
return frozenset(attr.fields_dict(MappedOperator)) - {
"dag",
"deps",
- "is_mapped",
"expand_input", # This is needed to be able to accept XComArg.
"subdag",
"task_group",
diff --git a/airflow/models/operator.py b/airflow/models/operator.py
index 9c6493a7fa..7352ecbdcb 100644
--- a/airflow/models/operator.py
+++ b/airflow/models/operator.py
@@ -19,9 +19,30 @@ from __future__ import annotations
from typing import Union
+from airflow.models.abstractoperator import AbstractOperator
from airflow.models.baseoperator import BaseOperator
from airflow.models.mappedoperator import MappedOperator
+from airflow.typing_compat import TypeGuard
Operator = Union[BaseOperator, MappedOperator]
-__all__ = ["Operator"]
+
+def needs_expansion(task: AbstractOperator) -> TypeGuard[Operator]:
+ """Whether a task needs expansion at runtime.
+
+ A task needs expansion if it either
+
+ * Is a mapped operator, or
+ * Is in a mapped task group.
+
+ This is implemented as a free function (instead of a property) so we can
+ make it a type guard.
+ """
+ if isinstance(task, MappedOperator):
+ return True
+ if task.get_closest_mapped_task_group() is not None:
+ return True
+ return False
+
+
+__all__ = ["Operator", "needs_expansion"]
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 8b6197fe85..8f08f5105b 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -85,6 +85,7 @@ from airflow.exceptions import (
)
from airflow.models.base import Base, StringID
from airflow.models.log import Log
+from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import process_params
from airflow.models.taskfail import TaskFail
from airflow.models.taskmap import TaskMap
@@ -2250,7 +2251,7 @@ class TaskInstance(Base, LoggingMixin):
# currently possible for a downstream to depend on one individual
mapped
# task instance. This will change when we implement task mapping inside
# a mapped task group, and we'll need to further analyze the case.
- if task.is_mapped:
+ if isinstance(task, MappedOperator):
return
if value is None:
raise XComForMappingNotPushed()
@@ -2679,7 +2680,7 @@ def _find_common_ancestor_mapped_group(node1: Operator,
node2: Operator) -> Mapp
def _is_further_mapped_inside(operator: Operator, container: TaskGroup) ->
bool:
"""Whether given operator is *further* mapped inside a task group."""
- if operator.is_mapped:
+ if isinstance(operator, MappedOperator):
return True
task_group = operator.task_group
while task_group is not None and task_group.group_id != container.group_id:
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index fff7ca2112..d2b80474a9 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -26,6 +26,7 @@ from sqlalchemy.orm import Session
from airflow.exceptions import XComNotFound
from airflow.models.abstractoperator import AbstractOperator
+from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskmixin import DAGNode, DependencyMixin
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.utils.context import Context
@@ -301,7 +302,7 @@ class PlainXComArg(XComArg):
from airflow.models.xcom import XCom
task = self.operator
- if task.is_mapped:
+ if isinstance(task, MappedOperator):
query = session.query(func.count(XCom.map_index)).filter(
XCom.dag_id == task.dag_id,
XCom.run_id == run_id,
diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py
b/airflow/ti_deps/deps/ready_to_reschedule.py
index ca57f6e29d..6ac9f492f0 100644
--- a/airflow/ti_deps/deps/ready_to_reschedule.py
+++ b/airflow/ti_deps/deps/ready_to_reschedule.py
@@ -41,7 +41,9 @@ class ReadyToRescheduleDep(BaseTIDep):
considered as passed. This dependency fails if the latest reschedule
request's reschedule date is still in future.
"""
- is_mapped = ti.task.is_mapped
+ from airflow.models.mappedoperator import MappedOperator
+
+ is_mapped = isinstance(ti.task, MappedOperator)
if not is_mapped and not getattr(ti.task, "reschedule", False):
# Mapped sensors don't have the reschedule property (it can only
# be calculated after unmapping), so we don't check them here.
diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py
b/airflow/ti_deps/deps/trigger_rule_dep.py
index 29fd479485..d932a6dd21 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -105,6 +105,7 @@ class TriggerRuleDep(BaseTIDep):
:param dep_context: The current dependency context.
:param session: Database session.
"""
+ from airflow.models.operator import needs_expansion
from airflow.models.taskinstance import TaskInstance
task = ti.task
@@ -203,7 +204,7 @@ class TriggerRuleDep(BaseTIDep):
# Optimization: Don't need to hit the database if all upstreams are
# "simple" tasks (no task or task group mapping involved).
- if not any(t.is_mapped or t.get_closest_mapped_task_group() for t in
upstream_tasks.values()):
+ if not any(needs_expansion(t) for t in upstream_tasks.values()):
upstream = len(upstream_tasks)
else:
upstream = (
diff --git a/airflow/www/views.py b/airflow/www/views.py
index e334dcce7b..f3f84bb851 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -99,6 +99,7 @@ from airflow.models.dag import DAG,
get_dataset_triggered_next_run_info
from airflow.models.dagcode import DagCode
from airflow.models.dagrun import DagRun, DagRunType
from airflow.models.dataset import DagScheduleDatasetReference,
DatasetDagRunQueue, DatasetEvent, DatasetModel
+from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance, TaskInstanceNote
@@ -340,7 +341,7 @@ def dag_to_grid(dag, dag_runs, session):
set_overall_state(record)
yield record
- if item.is_mapped:
+ if isinstance(item, MappedOperator):
instances = list(_mapped_summary(grouped_tis.get(item.task_id,
[])))
else:
instances = list(map(_get_summary,
grouped_tis.get(item.task_id, [])))
@@ -350,7 +351,7 @@ def dag_to_grid(dag, dag_runs, session):
"instances": instances,
"label": item.label,
"extra_links": item.extra_links,
- "is_mapped": item.is_mapped,
+ "is_mapped": isinstance(item, MappedOperator),
"has_outlet_datasets": any(isinstance(i, Dataset) for i in
(item.outlets or [])),
"operator": item.operator_name,
}
@@ -2848,7 +2849,7 @@ class Airflow(AirflowBaseView):
"dag_id": t.dag_id,
"task_type": t.task_type,
"extra_links": t.extra_links,
- "is_mapped": t.is_mapped,
+ "is_mapped": isinstance(t, MappedOperator),
"trigger_rule": t.trigger_rule,
}
for t in dag.tasks
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 3ded6c80fa..47a908db77 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -28,6 +28,7 @@ from airflow.exceptions import AirflowException
from airflow.models import DAG
from airflow.models.baseoperator import BaseOperator
from airflow.models.expandinput import DictOfListsExpandInput
+from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XCOM_RETURN_KEY
@@ -716,7 +717,7 @@ def
test_mapped_decorator_converts_partial_kwargs(dag_maker, session):
assert [ti.task_id for ti in dec.schedulable_tis] == ["task1", "task1"]
for ti in dec.schedulable_tis:
ti.run(session=session)
- assert not ti.task.is_mapped
+ assert not isinstance(ti.task, MappedOperator)
assert ti.task.retry_delay == timedelta(seconds=300) # Operator
default.
# Expand task2.
@@ -756,9 +757,9 @@ def test_mapped_render_template_fields(dag_maker, session):
mapped_ti: TaskInstance = dr.get_task_instance(mapped.operator.task_id,
session=session)
mapped_ti.map_index = 0
- assert mapped_ti.task.is_mapped
+ assert isinstance(mapped_ti.task, MappedOperator)
mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session))
- assert not mapped_ti.task.is_mapped
+ assert isinstance(mapped_ti.task, BaseOperator)
assert mapped_ti.task.op_kwargs["arg1"] == "{{ ds }}"
assert mapped_ti.task.op_kwargs["arg2"] == "fn"
diff --git a/tests/models/test_taskinstance.py
b/tests/models/test_taskinstance.py
index 8f4fb47881..dd542666cd 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -35,7 +35,7 @@ import pytest
from freezegun import freeze_time
from airflow import models, settings
-from airflow.decorators import task
+from airflow.decorators import task, task_group
from airflow.example_dags.plugins.workday import AfterWorkdayTimetable
from airflow.exceptions import (
AirflowException,
@@ -3002,24 +3002,39 @@ class TestTaskInstanceRecordTaskMapXComPush:
assert dag_maker.session.query(TaskMap).count() == 0
- @pytest.mark.parametrize("xcom_value", [[1, 2, 3], {"a": 1, "b": 2},
"abc"])
- def test_not_recorded_if_irrelevant(self, dag_maker, xcom_value):
+ @pytest.mark.parametrize("xcom_1", [[1, 2, 3], {"a": 1, "b": 2}, "abc"])
+ @pytest.mark.parametrize("xcom_4", [[1, 2, 3], {"a": 1, "b": 2}])
+ def test_not_recorded_if_irrelevant(self, dag_maker, xcom_1, xcom_4):
"""Return value should only be recorded if a mapped downstream uses
it."""
with dag_maker(dag_id="test_not_recorded_for_unused") as dag:
@dag.task()
def push_1():
- return xcom_value
+ return xcom_1
@dag.task()
def push_2():
return [-1, -2]
+ @dag.task()
+ def push_3():
+ return ["x", "y"]
+
+ @dag.task()
+ def push_4():
+ return xcom_4
+
@dag.task()
def show(arg1, arg2):
print(arg1, arg2)
+ @task_group()
+ def tg(arg):
+ show(arg1=task_3, arg2=arg)
+
+ task_3 = push_3()
show.partial(arg1=push_1()).expand(arg2=push_2())
+ tg.expand(arg=push_4())
tis = {ti.task_id: ti for ti in
dag_maker.create_dagrun().task_instances}
@@ -3029,6 +3044,12 @@ class TestTaskInstanceRecordTaskMapXComPush:
tis["push_2"].run()
assert dag_maker.session.query(TaskMap).count() == 1
+ tis["push_3"].run()
+ assert dag_maker.session.query(TaskMap).count() == 1
+
+ tis["push_4"].run()
+ assert dag_maker.session.query(TaskMap).count() == 2
+
@pytest.mark.parametrize(
"return_value, exception_type, error_message",
[
@@ -3089,6 +3110,76 @@ class TestTaskInstanceRecordTaskMapXComPush:
assert ti.state == TaskInstanceState.FAILED
assert str(ctx.value) == error_message
+ @pytest.mark.parametrize(
+ "return_value, exception_type, error_message",
+ [
+ (123, UnmappableXComTypePushed, "unmappable return type 'int'"),
+ (None, XComForMappingNotPushed, "did not push XCom for task
mapping"),
+ ],
+ )
+ def test_task_group_expand_error_if_unmappable_type(
+ self,
+ dag_maker,
+ return_value,
+ exception_type,
+ error_message,
+ ):
+ """If an unmappable return value is used, fail the task that pushed
the XCom."""
+ with
dag_maker(dag_id="test_task_group_expand_error_if_unmappable_type") as dag:
+
+ @dag.task()
+ def push():
+ return return_value
+
+ @task_group
+ def tg(arg):
+ MockOperator(task_id="pull", arg1=arg)
+
+ tg.expand(arg=push())
+
+ ti = next(ti for ti in dag_maker.create_dagrun().task_instances if
ti.task_id == "push")
+ with pytest.raises(exception_type) as ctx:
+ ti.run()
+
+ assert dag_maker.session.query(TaskMap).count() == 0
+ assert ti.state == TaskInstanceState.FAILED
+ assert str(ctx.value) == error_message
+
+ @pytest.mark.parametrize(
+ "return_value, exception_type, error_message",
+ [
+ (123, UnmappableXComTypePushed, "unmappable return type 'int'"),
+ (None, XComForMappingNotPushed, "did not push XCom for task
mapping"),
+ ],
+ )
+ def test_task_group_expand_kwargs_error_if_unmappable_type(
+ self,
+ dag_maker,
+ return_value,
+ exception_type,
+ error_message,
+ ):
+ """If an unmappable return value is used, fail the task that pushed
the XCom."""
+ with
dag_maker(dag_id="test_task_group_expand_kwargs_error_if_unmappable_type") as
dag:
+
+ @dag.task()
+ def push():
+ return return_value
+
+ @task_group
+ def tg(arg):
+ MockOperator(task_id="pull", arg1=arg)
+
+ tg.expand_kwargs(push())
+
+ ti = next(ti for ti in dag_maker.create_dagrun().task_instances if
ti.task_id == "push")
+ with pytest.raises(exception_type) as ctx:
+ ti.run()
+
+ assert dag_maker.session.query(TaskMap).count() == 0
+ assert ti.state == TaskInstanceState.FAILED
+ assert str(ctx.value) == error_message
+
@pytest.mark.parametrize(
"create_upstream",
[