This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 64a64abd79 Fix dag serialization (#34042)
64a64abd79 is described below

commit 64a64abd793c5e7eb49ef68e8724c9346c7536d4
Author: mhenc <[email protected]>
AuthorDate: Fri Oct 27 10:47:34 2023 +0200

    Fix dag serialization (#34042)
---
 airflow/serialization/pydantic/dag.py          |   2 +-
 airflow/serialization/pydantic/dag_run.py      |   2 +-
 airflow/serialization/serialized_objects.py    |   2 +-
 tests/serialization/test_dag_serialization.py  |   8 +-
 tests/serialization/test_serialized_objects.py | 217 ++++++++++++++++++++++---
 5 files changed, 200 insertions(+), 31 deletions(-)

diff --git a/airflow/serialization/pydantic/dag.py b/airflow/serialization/pydantic/dag.py
index b53c5cb190..48ce1f0c56 100644
--- a/airflow/serialization/pydantic/dag.py
+++ b/airflow/serialization/pydantic/dag.py
@@ -60,7 +60,7 @@ class DagModelPydantic(BaseModelPydantic):
     is_paused_at_creation: bool = airflow_conf.getboolean("core", "dags_are_paused_at_creation")
     is_paused: bool = is_paused_at_creation
     is_subdag: Optional[bool] = False
-    is_active: bool = False
+    is_active: Optional[bool] = False
     last_parsed_time: Optional[datetime]
     last_pickled: Optional[datetime]
     last_expired: Optional[datetime]
diff --git a/airflow/serialization/pydantic/dag_run.py b/airflow/serialization/pydantic/dag_run.py
index d296cbf21a..ffd221e7f5 100644
--- a/airflow/serialization/pydantic/dag_run.py
+++ b/airflow/serialization/pydantic/dag_run.py
@@ -74,7 +74,7 @@ class DagRunPydantic(BaseModelPydantic):
     data_interval_end: Optional[datetime]
     last_scheduling_decision: Optional[datetime]
     dag_hash: Optional[str]
-    updated_at: datetime
+    updated_at: Optional[datetime]
     dag: Optional[PydanticDag]
     consumed_dataset_events: List[DatasetEventPydantic]  # noqa
 
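Both pydantic model tweaks above follow the same pattern: a column that can be NULL on the ORM row is declared Optional so validation no longer fails when the value is missing. A minimal standalone sketch (a hypothetical model for illustration, not the Airflow classes):

    from datetime import datetime
    from typing import Optional

    from pydantic import BaseModel

    class ExamplePydantic(BaseModel):
        # required field: rejects None
        dag_id: str
        # Optional fields mirror nullable DB columns and accept None
        is_active: Optional[bool] = False
        updated_at: Optional[datetime] = None

    ExamplePydantic(dag_id="demo", updated_at=None)  # validates without error
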
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 1fbed14892..31c0ec6d72 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -447,7 +447,7 @@ class BaseSerialization:
             json_pod = PodGenerator.serialize_pod(var)
             return cls._encode(json_pod, type_=DAT.POD)
         elif isinstance(var, DAG):
-            return SerializedDAG.serialize_dag(var)
+            return cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG)
         elif isinstance(var, Resources):
             return var.to_dict()
         elif isinstance(var, MappedOperator):
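
The functional change is the serialized_objects.py hunk above: a DAG handed to BaseSerialization.serialize() is now wrapped in the standard type/value envelope instead of being returned as a bare dict, which is why the tests below unwrap ser_dag[Encoding.VAR] before calling deserialize_dag(). A minimal sketch of the resulting shape (not part of this commit; the envelope keys are the Encoding.TYPE/Encoding.VAR pair visible in the updated tests):

    from datetime import datetime

    from airflow.models.dag import DAG
    from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
    from airflow.serialization.serialized_objects import BaseSerialization, SerializedDAG

    dag = DAG(dag_id="example_dag", start_date=datetime(2023, 1, 1))
    encoded = BaseSerialization.serialize(dag)
    # With this fix the result is an envelope, roughly {"__type": "dag", "__var": {...}},
    # rather than the bare serialize_dag() dict returned before.
    assert encoded[Encoding.TYPE] == DAT.DAG
    # Callers that want the DAG back unwrap the payload first, as the updated tests do.
    restored = SerializedDAG.deserialize_dag(encoded[Encoding.VAR])
    assert restored.dag_id == dag.dag_id
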
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 05da33e667..d2a762efa7 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -2525,7 +2525,7 @@ def test_mapped_task_group_serde():
         tg.expand(a=[".", ".."])
 
     ser_dag = SerializedBaseOperator.serialize(dag)
-    assert ser_dag["_task_group"]["children"]["tg"] == (
+    assert ser_dag[Encoding.VAR]["_task_group"]["children"]["tg"] == (
         "taskgroup",
         {
             "_group_id": "tg",
@@ -2549,7 +2549,7 @@ def test_mapped_task_group_serde():
         },
     )
 
-    serde_dag = SerializedDAG.deserialize_dag(ser_dag)
+    serde_dag = SerializedDAG.deserialize_dag(ser_dag[Encoding.VAR])
     serde_tg = serde_dag.task_group.children["tg"]
     assert isinstance(serde_tg, MappedTaskGroup)
     assert serde_tg._expand_input == DictOfListsExpandInput({"a": [".", ".."]})
@@ -2568,7 +2568,7 @@ def test_mapped_task_with_operator_extra_links_property():
     with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
         _DummyOperator.partial(task_id="task").expand(inputs=[1, 2, 3])
     serialized_dag = SerializedBaseOperator.serialize(dag)
-    assert serialized_dag["tasks"][0] == {
+    assert serialized_dag[Encoding.VAR]["tasks"][0] == {
         "task_id": "task",
         "expand_input": {
             "type": "dict-of-lists",
@@ -2589,5 +2589,5 @@ def test_mapped_task_with_operator_extra_links_property():
         "_is_empty": False,
         "_is_mapped": True,
     }
-    deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag)
+    deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag[Encoding.VAR])
     assert deserialized_dag.task_dict["task"].operator_extra_links == [AirflowLink2()]
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index 6f0a6f8775..e1ff8bace4 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -18,16 +18,34 @@
 from __future__ import annotations
 
 import json
-from datetime import datetime
+from datetime import datetime, timedelta
 
 import pytest
+from dateutil import relativedelta
+from kubernetes.client import models as k8s
+from pendulum.tz.timezone import Timezone
 
+from airflow.datasets import Dataset
 from airflow.exceptions import SerializationError
-from airflow.models.taskinstance import TaskInstance
+from airflow.jobs.job import Job
+from airflow.models.connection import Connection
+from airflow.models.dag import DAG, DagModel
+from airflow.models.dagrun import DagRun
+from airflow.models.param import Param
+from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
+from airflow.models.xcom_arg import XComArg
 from airflow.operators.empty import EmptyOperator
+from airflow.operators.python import PythonOperator
+from airflow.serialization.enums import DagAttributeTypes as DAT
+from airflow.serialization.pydantic.dag import DagModelPydantic
+from airflow.serialization.pydantic.dag_run import DagRunPydantic
+from airflow.serialization.pydantic.job import JobPydantic
 from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
 from airflow.settings import _ENABLE_AIP_44
-from airflow.utils.state import State
+from airflow.utils.operator_resources import Resources
+from airflow.utils.state import DagRunState, State
+from airflow.utils.task_group import TaskGroup
+from airflow.utils.types import DagRunType
 from tests import REPO_ROOT
 
 
@@ -82,31 +100,182 @@ def test_strict_mode():
         BaseSerialization.serialize(obj, strict=True)  # now raises
 
 
+TI = TaskInstance(
+    task=EmptyOperator(task_id="test-task"),
+    run_id="fake_run",
+    state=State.RUNNING,
+)
+
+TI_WITH_START_DAY = TaskInstance(
+    task=EmptyOperator(task_id="test-task"),
+    run_id="fake_run",
+    state=State.RUNNING,
+)
+TI_WITH_START_DAY.start_date = datetime.utcnow()
+
+DAG_RUN = DagRun(
+    dag_id="test_dag_id",
+    run_id="test_dag_run_id",
+    run_type=DagRunType.MANUAL,
+    execution_date=datetime.utcnow(),
+    start_date=datetime.utcnow(),
+    external_trigger=True,
+    state=DagRunState.SUCCESS,
+)
+DAG_RUN.id = 1
+
+
+def equals(a, b) -> bool:
+    return a == b
+
+
+def equal_time(a: datetime, b: datetime) -> bool:
+    return a.strftime("%s") == b.strftime("%s")
+
+
+@pytest.mark.parametrize(
+    "input, encoded_type, cmp_func",
+    [
+        ("test_str", None, equals),
+        (1, None, equals),
+        (datetime.utcnow(), DAT.DATETIME, equal_time),
+        (timedelta(minutes=2), DAT.TIMEDELTA, equals),
+        (Timezone("UTC"), DAT.TIMEZONE, lambda a, b: a.name == b.name),
+        (relativedelta.relativedelta(hours=+1), DAT.RELATIVEDELTA, lambda a, b: a.hours == b.hours),
+        ({"test": "dict", "test-1": 1}, None, equals),
+        (["array_item", 2], None, equals),
+        (("tuple_item", 3), DAT.TUPLE, equals),
+        (set(["set_item", 3]), DAT.SET, equals),
+        (
+            k8s.V1Pod(
+                metadata=k8s.V1ObjectMeta(
+                    name="test", annotations={"test": "annotation"}, creation_timestamp=datetime.utcnow()
+                )
+            ),
+            DAT.POD,
+            equals,
+        ),
+        (
+            DAG(
+                "fake-dag",
+                schedule="*/10 * * * *",
+                default_args={"depends_on_past": True},
+                start_date=datetime.utcnow(),
+                catchup=False,
+            ),
+            DAT.DAG,
+            lambda a, b: a.dag_id == b.dag_id and equal_time(a.start_date, b.start_date),
+        ),
+        (Resources(cpus=0.1, ram=2048), None, None),
+        (EmptyOperator(task_id="test-task"), None, None),
+        (TaskGroup(group_id="test-group", dag=DAG(dag_id="test_dag", start_date=datetime.now())), None, None),
+        (
+            Param("test", "desc"),
+            DAT.PARAM,
+            lambda a, b: a.value == b.value and a.description == b.description,
+        ),
+        (
+            XComArg(
+                operator=PythonOperator(
+                    python_callable=int,
+                    task_id="test_xcom_op",
+                    do_xcom_push=True,
+                )
+            ),
+            DAT.XCOM_REF,
+            None,
+        ),
+        (Dataset(uri="test"), DAT.DATASET, equals),
+        (SimpleTaskInstance.from_ti(ti=TI), DAT.SIMPLE_TASK_INSTANCE, equals),
+        (
+            Connection(conn_id="TEST_ID", uri="mysql://"),
+            DAT.CONNECTION,
+            lambda a, b: a.get_uri() == b.get_uri(),
+        ),
+    ],
+)
+def test_serialize_deserialize(input, encoded_type, cmp_func):
+    from airflow.serialization.serialized_objects import BaseSerialization
+
+    serialized = BaseSerialization.serialize(input)  # does not raise
+    json.dumps(serialized)  # does not raise
+    if encoded_type is not None:
+        assert serialized["__type"] == encoded_type
+        assert serialized["__var"] is not None
+    if cmp_func is not None:
+        deserialized = BaseSerialization.deserialize(serialized)
+        assert cmp_func(input, deserialized)
+
+    # Verify recursive behavior
+    obj = [[input]]
+    serialized = BaseSerialization.serialize(obj)  # does not raise
+    # Verify the result is JSON-serializable
+    json.dumps(serialized)  # does not raise
+
+
 @pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
-def test_use_pydantic_models():
-    """If use_pydantic_models=True the TaskInstance object should be serialized to TaskInstancePydantic."""
+@pytest.mark.parametrize(
+    "input, pydantic_class, encoded_type, cmp_func",
+    [
+        (
+            Job(state=State.RUNNING, latest_heartbeat=datetime.utcnow()),
+            JobPydantic,
+            DAT.BASE_JOB,
+            lambda a, b: equal_time(a.latest_heartbeat, b.latest_heartbeat),
+        ),
+        (
+            TI_WITH_START_DAY,
+            TaskInstancePydantic,
+            DAT.TASK_INSTANCE,
+            lambda a, b: equal_time(a.start_date, b.start_date),
+        ),
+        (
+            DAG_RUN,
+            DagRunPydantic,
+            DAT.DAG_RUN,
+            lambda a, b: equal_time(a.execution_date, b.execution_date)
+            and equal_time(a.start_date, b.start_date),
+        ),
+        # DataSet is already serialized by non-Pydantic serialization. Is DatasetPydantic needed then?
+        # (
+        #     Dataset(
+        #         uri="foo://bar",
+        #         extra={"foo": "bar"},
+        #     ),
+        #     DatasetPydantic,
+        #     DAT.DATA_SET,
+        #     lambda a, b: a.uri == b.uri and a.extra == b.extra,
+        # ),
+        (
+            DagModel(
+                dag_id="TEST_DAG_1",
+                fileloc="/tmp/dag_1.py",
+                schedule_interval="2 2 * * *",
+                is_paused=True,
+            ),
+            DagModelPydantic,
+            DAT.DAG_MODEL,
+            lambda a, b: a.fileloc == b.fileloc and a.schedule_interval == b.schedule_interval,
+        ),
+    ],
+)
+def test_serialize_deserialize_pydantic(input, pydantic_class, encoded_type, cmp_func):
+    """If use_pydantic_models=True the objects should be serialized to Pydantic objects."""
 
     from airflow.serialization.serialized_objects import BaseSerialization
 
-    ti = TaskInstance(
-        task=EmptyOperator(task_id="task"),
-        run_id="run_id",
-        state=State.RUNNING,
-    )
-    start_date = datetime.utcnow()
-    ti.start_date = start_date
-    obj = [[ti]]  # nested to verify recursive behavior
-
-    serialized = BaseSerialization.serialize(obj, use_pydantic_models=True)  # does not raise
-    deserialized = BaseSerialization.deserialize(serialized, use_pydantic_models=True)  # does not raise
-    assert isinstance(deserialized[0][0], TaskInstancePydantic)
-
-    serialized_json = json.dumps(serialized)  # does not raise
-    deserialized_from_json = BaseSerialization.deserialize(
-        json.loads(serialized_json), use_pydantic_models=True
-    )  # does not raise
-    assert isinstance(deserialized_from_json[0][0], TaskInstancePydantic)
-    assert deserialized_from_json[0][0].start_date == start_date
+    serialized = BaseSerialization.serialize(input, use_pydantic_models=True)  # does not raise
+    # Verify the result is JSON-serializable
+    json.dumps(serialized)  # does not raise
+    assert serialized["__type"] == encoded_type
+    assert serialized["__var"] is not None
+    deserialized = BaseSerialization.deserialize(serialized, use_pydantic_models=True)
+    assert isinstance(deserialized, pydantic_class)
+    assert cmp_func(input, deserialized)
+
+    # Verify recursive behavior
+    obj = [[input]]
+    BaseSerialization.serialize(obj, use_pydantic_models=True)  # does not raise
 
 
 def test_serialized_mapped_operator_unmap(dag_maker):
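
For reference, the round-trip exercised by the new parametrized tests looks like the sketch below (assuming an environment where AIP-44 is enabled, i.e. _ENABLE_AIP_44 is true; the construction mirrors the TI_WITH_START_DAY fixture above):

    from datetime import datetime

    from airflow.models.taskinstance import TaskInstance
    from airflow.operators.empty import EmptyOperator
    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
    from airflow.serialization.serialized_objects import BaseSerialization
    from airflow.utils.state import State

    ti = TaskInstance(task=EmptyOperator(task_id="demo-task"), run_id="demo_run", state=State.RUNNING)
    ti.start_date = datetime.utcnow()

    # With use_pydantic_models=True the ORM object is encoded under a DAT tag...
    encoded = BaseSerialization.serialize(ti, use_pydantic_models=True)
    # ...and deserializes to its Pydantic counterpart instead of an ORM instance.
    decoded = BaseSerialization.deserialize(encoded, use_pydantic_models=True)
    assert isinstance(decoded, TaskInstancePydantic)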
