This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3e288abd0b Remove is_mapped attribute (#27881)
3e288abd0b is described below

commit 3e288abd0bc3e5788dcd7f6d9f6bef26ec4c7281
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Fri Nov 25 09:21:01 2022 +0800

    Remove is_mapped attribute (#27881)
---
 .../endpoints/task_instance_endpoint.py            |  3 +-
 airflow/api_connexion/schemas/task_schema.py       | 17 ++--
 airflow/cli/commands/task_command.py               |  3 +-
 airflow/models/baseoperator.py                     |  2 -
 airflow/models/mappedoperator.py                   |  2 -
 airflow/models/operator.py                         | 23 ++++-
 airflow/models/taskinstance.py                     |  5 +-
 airflow/models/xcom_arg.py                         |  3 +-
 airflow/ti_deps/deps/ready_to_reschedule.py        |  4 +-
 airflow/ti_deps/deps/trigger_rule_dep.py           |  3 +-
 airflow/www/views.py                               |  7 +-
 tests/decorators/test_python.py                    |  7 +-
 tests/models/test_taskinstance.py                  | 99 +++++++++++++++++++++-
 13 files changed, 151 insertions(+), 27 deletions(-)

diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index 4e9d6cb9a1..9d5d54ba58 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -45,6 +45,7 @@ from airflow.api_connexion.schemas.task_instance_schema import (
 from airflow.api_connexion.types import APIResponse
 from airflow.models import SlaMiss
 from airflow.models.dagrun import DagRun as DR
+from airflow.models.operator import needs_expansion
 from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances
 from airflow.security import permissions
 from airflow.utils.airflow_flask_app import get_airflow_app
@@ -202,7 +203,7 @@ def get_mapped_task_instances(
         if not task:
             error_message = f"Task id {task_id} not found"
             raise NotFound(error_message)
-        if not task.is_mapped:
+        if not needs_expansion(task):
             error_message = f"Task id {task_id} is not mapped"
             raise NotFound(error_message)
 
diff --git a/airflow/api_connexion/schemas/task_schema.py b/airflow/api_connexion/schemas/task_schema.py
index 0fcb9ff18f..5715ca2ea0 100644
--- a/airflow/api_connexion/schemas/task_schema.py
+++ b/airflow/api_connexion/schemas/task_schema.py
@@ -27,6 +27,7 @@ from airflow.api_connexion.schemas.common_schema import (
     WeightRuleField,
 )
 from airflow.api_connexion.schemas.dag_schema import DAGSchema
+from airflow.models.mappedoperator import MappedOperator
 from airflow.models.operator import Operator
 
 
@@ -59,22 +60,28 @@ class TaskSchema(Schema):
     template_fields = fields.List(fields.String(), dump_only=True)
     sub_dag = fields.Nested(DAGSchema, dump_only=True)
     downstream_task_ids = fields.List(fields.String(), dump_only=True)
-    params = fields.Method("get_params", dump_only=True)
-    is_mapped = fields.Boolean(dump_only=True)
+    params = fields.Method("_get_params", dump_only=True)
+    is_mapped = fields.Method("_get_is_mapped", dump_only=True)
 
-    def _get_class_reference(self, obj):
+    @staticmethod
+    def _get_class_reference(obj):
         result = ClassReferenceSchema().dump(obj)
         return result.data if hasattr(result, "data") else result
 
-    def _get_operator_name(self, obj):
+    @staticmethod
+    def _get_operator_name(obj):
         return obj.operator_name
 
     @staticmethod
-    def get_params(obj):
+    def _get_params(obj):
         """Get the Params defined in a Task."""
         params = obj.params
         return {k: v.dump() for k, v in params.items()}
 
+    @staticmethod
+    def _get_is_mapped(obj):
+        return isinstance(obj, MappedOperator)
+
 
 class TaskCollection(NamedTuple):
     """List of Tasks with metadata."""
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index a217d2c78d..078565dc38 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -42,6 +42,7 @@ from airflow.models import DagPickle, TaskInstance
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.dag import DAG
 from airflow.models.dagrun import DagRun
+from airflow.models.operator import needs_expansion
 from airflow.ti_deps.dep_context import DepContext
 from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS
 from airflow.typing_compat import Literal
@@ -150,7 +151,7 @@ def _get_ti(
     """Get the task instance through DagRun.run_id, if that fails, get the TI the old way."""
     if not exec_date_or_run_id and not create_if_necessary:
         raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.")
-    if task.is_mapped:
+    if needs_expansion(task):
         if map_index < 0:
             raise RuntimeError("No map_index passed to mapped task")
     elif map_index >= 0:
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 9264c92322..a1b49abf66 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -1475,8 +1475,6 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         """Required by DAGNode."""
         return DagAttributeTypes.OP, self.task_id
 
-    is_mapped: ClassVar[bool] = False
-
     @property
     def inherits_from_empty_operator(self):
         """Used to determine if an Operator is inherited from EmptyOperator"""
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 29dc977780..9bdfd932de 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -286,7 +286,6 @@ class MappedOperator(AbstractOperator):
     This should be a name to call ``getattr()`` on.
     """
 
-    is_mapped: ClassVar[bool] = True
     subdag: None = None  # Since we don't support SubDagOperator, this is always None.
 
     HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
@@ -329,7 +328,6 @@ class MappedOperator(AbstractOperator):
         return frozenset(attr.fields_dict(MappedOperator)) - {
             "dag",
             "deps",
-            "is_mapped",
             "expand_input",  # This is needed to be able to accept XComArg.
             "subdag",
             "task_group",
diff --git a/airflow/models/operator.py b/airflow/models/operator.py
index 9c6493a7fa..7352ecbdcb 100644
--- a/airflow/models/operator.py
+++ b/airflow/models/operator.py
@@ -19,9 +19,30 @@ from __future__ import annotations
 
 from typing import Union
 
+from airflow.models.abstractoperator import AbstractOperator
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.mappedoperator import MappedOperator
+from airflow.typing_compat import TypeGuard
 
 Operator = Union[BaseOperator, MappedOperator]
 
-__all__ = ["Operator"]
+
+def needs_expansion(task: AbstractOperator) -> TypeGuard[Operator]:
+    """Whether a task needs expansion at runtime.
+
+    A task needs expansion if it either
+
+    * Is a mapped operator, or
+    * Is in a mapped task group.
+
+    This is implemented as a free function (instead of a property) so we can
+    make it a type guard.
+    """
+    if isinstance(task, MappedOperator):
+        return True
+    if task.get_closest_mapped_task_group() is not None:
+        return True
+    return False
+
+
+__all__ = ["Operator", "needs_expansion"]
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 8b6197fe85..8f08f5105b 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -85,6 +85,7 @@ from airflow.exceptions import (
 )
 from airflow.models.base import Base, StringID
 from airflow.models.log import Log
+from airflow.models.mappedoperator import MappedOperator
 from airflow.models.param import process_params
 from airflow.models.taskfail import TaskFail
 from airflow.models.taskmap import TaskMap
@@ -2250,7 +2251,7 @@ class TaskInstance(Base, LoggingMixin):
         # currently possible for a downstream to depend on one individual mapped
         # task instance. This will change when we implement task mapping inside
         # a mapped task group, and we'll need to further analyze the case.
-        if task.is_mapped:
+        if isinstance(task, MappedOperator):
             return
         if value is None:
             raise XComForMappingNotPushed()
@@ -2679,7 +2680,7 @@ def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> Mapp
 
 def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool:
     """Whether given operator is *further* mapped inside a task group."""
-    if operator.is_mapped:
+    if isinstance(operator, MappedOperator):
         return True
     task_group = operator.task_group
     while task_group is not None and task_group.group_id != container.group_id:
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index fff7ca2112..d2b80474a9 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -26,6 +26,7 @@ from sqlalchemy.orm import Session
 
 from airflow.exceptions import XComNotFound
 from airflow.models.abstractoperator import AbstractOperator
+from airflow.models.mappedoperator import MappedOperator
 from airflow.models.taskmixin import DAGNode, DependencyMixin
 from airflow.models.xcom import XCOM_RETURN_KEY
 from airflow.utils.context import Context
@@ -301,7 +302,7 @@ class PlainXComArg(XComArg):
         from airflow.models.xcom import XCom
 
         task = self.operator
-        if task.is_mapped:
+        if isinstance(task, MappedOperator):
             query = session.query(func.count(XCom.map_index)).filter(
                 XCom.dag_id == task.dag_id,
                 XCom.run_id == run_id,
diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py b/airflow/ti_deps/deps/ready_to_reschedule.py
index ca57f6e29d..6ac9f492f0 100644
--- a/airflow/ti_deps/deps/ready_to_reschedule.py
+++ b/airflow/ti_deps/deps/ready_to_reschedule.py
@@ -41,7 +41,9 @@ class ReadyToRescheduleDep(BaseTIDep):
         considered as passed. This dependency fails if the latest reschedule
         request's reschedule date is still in future.
         """
-        is_mapped = ti.task.is_mapped
+        from airflow.models.mappedoperator import MappedOperator
+
+        is_mapped = isinstance(ti.task, MappedOperator)
         if not is_mapped and not getattr(ti.task, "reschedule", False):
             # Mapped sensors don't have the reschedule property (it can only
             # be calculated after unmapping), so we don't check them here.
diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py
index 29fd479485..d932a6dd21 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -105,6 +105,7 @@ class TriggerRuleDep(BaseTIDep):
         :param dep_context: The current dependency context.
         :param session: Database session.
         """
+        from airflow.models.operator import needs_expansion
         from airflow.models.taskinstance import TaskInstance
 
         task = ti.task
@@ -203,7 +204,7 @@ class TriggerRuleDep(BaseTIDep):
 
         # Optimization: Don't need to hit the database if all upstreams are
         # "simple" tasks (no task or task group mapping involved).
-        if not any(t.is_mapped or t.get_closest_mapped_task_group() for t in upstream_tasks.values()):
+        if not any(needs_expansion(t) for t in upstream_tasks.values()):
             upstream = len(upstream_tasks)
         else:
             upstream = (
diff --git a/airflow/www/views.py b/airflow/www/views.py
index e334dcce7b..f3f84bb851 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -99,6 +99,7 @@ from airflow.models.dag import DAG, get_dataset_triggered_next_run_info
 from airflow.models.dagcode import DagCode
 from airflow.models.dagrun import DagRun, DagRunType
 from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel
+from airflow.models.mappedoperator import MappedOperator
 from airflow.models.operator import Operator
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskinstance import TaskInstance, TaskInstanceNote
@@ -340,7 +341,7 @@ def dag_to_grid(dag, dag_runs, session):
                     set_overall_state(record)
                     yield record
 
-            if item.is_mapped:
+            if isinstance(item, MappedOperator):
                 instances = list(_mapped_summary(grouped_tis.get(item.task_id, [])))
             else:
                 instances = list(map(_get_summary, grouped_tis.get(item.task_id, [])))
@@ -350,7 +351,7 @@ def dag_to_grid(dag, dag_runs, session):
                 "instances": instances,
                 "label": item.label,
                 "extra_links": item.extra_links,
-                "is_mapped": item.is_mapped,
+                "is_mapped": isinstance(item, MappedOperator),
                 "has_outlet_datasets": any(isinstance(i, Dataset) for i in (item.outlets or [])),
                 "operator": item.operator_name,
             }
@@ -2848,7 +2849,7 @@ class Airflow(AirflowBaseView):
                 "dag_id": t.dag_id,
                 "task_type": t.task_type,
                 "extra_links": t.extra_links,
-                "is_mapped": t.is_mapped,
+                "is_mapped": isinstance(t, MappedOperator),
                 "trigger_rule": t.trigger_rule,
             }
             for t in dag.tasks
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 3ded6c80fa..47a908db77 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -28,6 +28,7 @@ from airflow.exceptions import AirflowException
 from airflow.models import DAG
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.expandinput import DictOfListsExpandInput
+from airflow.models.mappedoperator import MappedOperator
 from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskmap import TaskMap
 from airflow.models.xcom import XCOM_RETURN_KEY
@@ -716,7 +717,7 @@ def test_mapped_decorator_converts_partial_kwargs(dag_maker, session):
     assert [ti.task_id for ti in dec.schedulable_tis] == ["task1", "task1"]
     for ti in dec.schedulable_tis:
         ti.run(session=session)
-        assert not ti.task.is_mapped
+        assert not isinstance(ti.task, MappedOperator)
         assert ti.task.retry_delay == timedelta(seconds=300)  # Operator default.
 
     # Expand task2.
@@ -756,9 +757,9 @@ def test_mapped_render_template_fields(dag_maker, session):
     mapped_ti: TaskInstance = dr.get_task_instance(mapped.operator.task_id, session=session)
     mapped_ti.map_index = 0
 
-    assert mapped_ti.task.is_mapped
+    assert isinstance(mapped_ti.task, MappedOperator)
     
mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session))
-    assert not mapped_ti.task.is_mapped
+    assert isinstance(mapped_ti.task, BaseOperator)
 
     assert mapped_ti.task.op_kwargs["arg1"] == "{{ ds }}"
     assert mapped_ti.task.op_kwargs["arg2"] == "fn"
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 8f4fb47881..dd542666cd 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -35,7 +35,7 @@ import pytest
 from freezegun import freeze_time
 
 from airflow import models, settings
-from airflow.decorators import task
+from airflow.decorators import task, task_group
 from airflow.example_dags.plugins.workday import AfterWorkdayTimetable
 from airflow.exceptions import (
     AirflowException,
@@ -3002,24 +3002,39 @@ class TestTaskInstanceRecordTaskMapXComPush:
 
         assert dag_maker.session.query(TaskMap).count() == 0
 
-    @pytest.mark.parametrize("xcom_value", [[1, 2, 3], {"a": 1, "b": 2}, "abc"])
-    def test_not_recorded_if_irrelevant(self, dag_maker, xcom_value):
+    @pytest.mark.parametrize("xcom_1", [[1, 2, 3], {"a": 1, "b": 2}, "abc"])
+    @pytest.mark.parametrize("xcom_4", [[1, 2, 3], {"a": 1, "b": 2}])
+    def test_not_recorded_if_irrelevant(self, dag_maker, xcom_1, xcom_4):
         """Return value should only be recorded if a mapped downstream uses the it."""
         with dag_maker(dag_id="test_not_recorded_for_unused") as dag:
 
             @dag.task()
             def push_1():
-                return xcom_value
+                return xcom_1
 
             @dag.task()
             def push_2():
                 return [-1, -2]
 
+            @dag.task()
+            def push_3():
+                return ["x", "y"]
+
+            @dag.task()
+            def push_4():
+                return xcom_4
+
             @dag.task()
             def show(arg1, arg2):
                 print(arg1, arg2)
 
+            @task_group()
+            def tg(arg):
+                show(arg1=task_3, arg2=arg)
+
+            task_3 = push_3()
             show.partial(arg1=push_1()).expand(arg2=push_2())
+            tg.expand(arg=push_4())
 
         tis = {ti.task_id: ti for ti in dag_maker.create_dagrun().task_instances}
 
@@ -3029,6 +3044,12 @@ class TestTaskInstanceRecordTaskMapXComPush:
         tis["push_2"].run()
         assert dag_maker.session.query(TaskMap).count() == 1
 
+        tis["push_3"].run()
+        assert dag_maker.session.query(TaskMap).count() == 1
+
+        tis["push_4"].run()
+        assert dag_maker.session.query(TaskMap).count() == 2
+
     @pytest.mark.parametrize(
         "return_value, exception_type, error_message",
         [
@@ -3089,6 +3110,76 @@ class TestTaskInstanceRecordTaskMapXComPush:
         assert ti.state == TaskInstanceState.FAILED
         assert str(ctx.value) == error_message
 
+    @pytest.mark.parametrize(
+        "return_value, exception_type, error_message",
+        [
+            (123, UnmappableXComTypePushed, "unmappable return type 'int'"),
+            (None, XComForMappingNotPushed, "did not push XCom for task mapping"),
+        ],
+    )
+    def test_task_group_expand_error_if_unmappable_type(
+        self,
+        dag_maker,
+        return_value,
+        exception_type,
+        error_message,
+    ):
+        """If an unmappable return value is used , fail the task that pushed the XCom."""
+        with dag_maker(dag_id="test_task_group_expand_error_if_unmappable_type") as dag:
+
+            @dag.task()
+            def push():
+                return return_value
+
+            @task_group
+            def tg(arg):
+                MockOperator(task_id="pull", arg1=arg)
+
+            tg.expand(arg=push())
+
+        ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push")
+        with pytest.raises(exception_type) as ctx:
+            ti.run()
+
+        assert dag_maker.session.query(TaskMap).count() == 0
+        assert ti.state == TaskInstanceState.FAILED
+        assert str(ctx.value) == error_message
+
+    @pytest.mark.parametrize(
+        "return_value, exception_type, error_message",
+        [
+            (123, UnmappableXComTypePushed, "unmappable return type 'int'"),
+            (None, XComForMappingNotPushed, "did not push XCom for task mapping"),
+        ],
+    )
+    def test_task_group_expand_kwargs_error_if_unmappable_type(
+        self,
+        dag_maker,
+        return_value,
+        exception_type,
+        error_message,
+    ):
+        """If an unmappable return value is used, fail the task that pushed the XCom."""
+        with dag_maker(dag_id="test_task_group_expand_kwargs_error_if_unmappable_type") as dag:
+
+            @dag.task()
+            def push():
+                return return_value
+
+            @task_group
+            def tg(arg):
+                MockOperator(task_id="pull", arg1=arg)
+
+            tg.expand_kwargs(push())
+
+        ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push")
+        with pytest.raises(exception_type) as ctx:
+            ti.run()
+
+        assert dag_maker.session.query(TaskMap).count() == 0
+        assert ti.state == TaskInstanceState.FAILED
+        assert str(ctx.value) == error_message
+
     @pytest.mark.parametrize(
         "create_upstream",
         [

Reply via email to