This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-7-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit faf9de4342d8084800c2feddba0990b7e6b4a652 Author: Tzu-ping Chung <[email protected]> AuthorDate: Fri Aug 4 12:40:43 2023 +0800 Ensure DAG-level references are filled on unmap (#33083) Co-authored-by: Jed Cunningham <[email protected]> (cherry picked from commit bcfadcf6e4b2de587959594f54a9e8fef96c4a2b) --- airflow/models/mappedoperator.py | 2 + airflow/serialization/serialized_objects.py | 59 ++++++++++++++++++-------- tests/serialization/test_serialized_objects.py | 21 +++++++++ 3 files changed, 64 insertions(+), 18 deletions(-) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 0cf8852ea2..82dcc82aa0 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -659,6 +659,8 @@ class MappedOperator(AbstractOperator): op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True) SerializedBaseOperator.populate_operator(op, self.operator_class) + if self.dag is not None: # For Mypy; we only serialize tasks in a DAG so the check is always satisfied. + SerializedBaseOperator.set_task_dag_references(op, self.dag) return op def _get_specified_expand_input(self) -> ExpandInput: diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 5efa3b3da5..d89f2e22d4 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -735,6 +735,13 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): All operators are casted to SerializedBaseOperator after deserialization. Class specific attributes used by UI are move to object attributes. + + Creating a SerializedBaseOperator is a three-step process: + + 1. Instantiate a :class:`SerializedBaseOperator` object. + 2. Populate attributes with :func:`SerializedBaseOperator.populate_operator`. + 3. When the task's containing DAG is available, fix references to the DAG + with :func:`SerializedBaseOperator.set_task_dag_references`. 
""" _decorated_fields = {"executor_config"} @@ -875,6 +882,13 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): @classmethod def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None: + """Populate operator attributes with serialized values. + + This covers simple attributes that don't reference other things in the + DAG. Setting references (such as ``op.dag`` and task dependencies) is + done in ``set_task_dag_references`` instead, which is called after the + DAG is hydrated. + """ if "label" not in encoded_op: # Handle deserialization of old data before the introduction of TaskGroup encoded_op["label"] = encoded_op["task_id"] @@ -982,6 +996,32 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): # Used to determine if an Operator is inherited from EmptyOperator setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False))) + @staticmethod + def set_task_dag_references(task: Operator, dag: DAG) -> None: + """Handle DAG references on an operator. + + The operator should have been mostly populated earlier by calling + ``populate_operator``. This function further fixes object references + that were not possible before the task's containing DAG is hydrated. + """ + task.dag = dag + + for date_attr in ("start_date", "end_date"): + if getattr(task, date_attr, None) is None: + setattr(task, date_attr, getattr(dag, date_attr, None)) + + if task.subdag is not None: + task.subdag.parent_dag = dag + + # Dereference expand_input and op_kwargs_expand_input. 
+ for k in ("expand_input", "op_kwargs_expand_input"): + if isinstance(kwargs_ref := getattr(task, k, None), _ExpandInputRef): + setattr(task, k, kwargs_ref.deref(dag)) + + for task_id in task.downstream_task_ids: + # Bypass set_upstream etc here - it does more than we want + dag.task_dict[task_id].upstream_task_ids.add(task.task_id) + @classmethod def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator: """Deserializes an operator from a JSON object.""" @@ -1328,24 +1368,7 @@ class SerializedDAG(DAG, BaseSerialization): setattr(dag, k, None) for task in dag.task_dict.values(): - task.dag = dag - - for date_attr in ["start_date", "end_date"]: - if getattr(task, date_attr) is None: - setattr(task, date_attr, getattr(dag, date_attr)) - - if task.subdag is not None: - setattr(task.subdag, "parent_dag", dag) - - # Dereference expand_input and op_kwargs_expand_input. - for k in ("expand_input", "op_kwargs_expand_input"): - kwargs_ref = getattr(task, k, None) - if isinstance(kwargs_ref, _ExpandInputRef): - setattr(task, k, kwargs_ref.deref(dag)) - - for task_id in task.downstream_task_ids: - # Bypass set_upstream etc here - it does more than we want - dag.task_dict[task_id].upstream_task_ids.add(task.task_id) + SerializedBaseOperator.set_task_dag_references(task, dag) return dag diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index 1eb4214783..17f5187579 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -96,3 +96,24 @@ def test_use_pydantic_models(): deserialized = BaseSerialization.deserialize(serialized, use_pydantic_models=True) # does not raise assert isinstance(deserialized[0][0], TaskInstancePydantic) + + +def test_serialized_mapped_operator_unmap(dag_maker): + from airflow.serialization.serialized_objects import SerializedDAG + from tests.test_utils.mock_operators import MockOperator + + with dag_maker(dag_id="dag") as dag: + 
MockOperator(task_id="task1", arg1="x") + MockOperator.partial(task_id="task2").expand(arg1=["a", "b"]) + + serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + assert serialized_dag.dag_id == "dag" + + serialized_task1 = serialized_dag.get_task("task1") + assert serialized_task1.dag is serialized_dag + + serialized_task2 = serialized_dag.get_task("task2") + assert serialized_task2.dag is serialized_dag + + serialized_unmapped_task = serialized_task2.unmap(None) + assert serialized_unmapped_task.dag is serialized_dag
