This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 64a64abd79 Fix dag serialization (#34042)
64a64abd79 is described below
commit 64a64abd793c5e7eb49ef68e8724c9346c7536d4
Author: mhenc <[email protected]>
AuthorDate: Fri Oct 27 10:47:34 2023 +0200
Fix dag serialization (#34042)
---
airflow/serialization/pydantic/dag.py | 2 +-
airflow/serialization/pydantic/dag_run.py | 2 +-
airflow/serialization/serialized_objects.py | 2 +-
tests/serialization/test_dag_serialization.py | 8 +-
tests/serialization/test_serialized_objects.py | 217 ++++++++++++++++++++++---
5 files changed, 200 insertions(+), 31 deletions(-)
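The substantive change is in serialized_objects.py below: BaseSerialization.serialize() used to return the bare dict from SerializedDAG.serialize_dag() for DAG values, skipping the type envelope that every other branch applies via cls._encode(). After the fix a DAG serializes to the standard {Encoding.TYPE, Encoding.VAR} wrapper, which is why the tests now index [Encoding.VAR] before deserializing. A minimal sketch of the resulting round trip (the example DAG here is illustrative, not part of the commit):

    from datetime import datetime

    from airflow.models.dag import DAG
    from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
    from airflow.serialization.serialized_objects import BaseSerialization, SerializedDAG

    dag = DAG("example", start_date=datetime(2023, 1, 1))

    serialized = BaseSerialization.serialize(dag)
    assert serialized[Encoding.TYPE] == DAT.DAG  # the envelope this fix adds
    roundtripped = SerializedDAG.deserialize_dag(serialized[Encoding.VAR])
    assert roundtripped.dag_id == dag.dag_id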
diff --git a/airflow/serialization/pydantic/dag.py b/airflow/serialization/pydantic/dag.py
index b53c5cb190..48ce1f0c56 100644
--- a/airflow/serialization/pydantic/dag.py
+++ b/airflow/serialization/pydantic/dag.py
@@ -60,7 +60,7 @@ class DagModelPydantic(BaseModelPydantic):
is_paused_at_creation: bool = airflow_conf.getboolean("core", "dags_are_paused_at_creation")
is_paused: bool = is_paused_at_creation
is_subdag: Optional[bool] = False
- is_active: bool = False
+ is_active: Optional[bool] = False
last_parsed_time: Optional[datetime]
last_pickled: Optional[datetime]
last_expired: Optional[datetime]
diff --git a/airflow/serialization/pydantic/dag_run.py b/airflow/serialization/pydantic/dag_run.py
index d296cbf21a..ffd221e7f5 100644
--- a/airflow/serialization/pydantic/dag_run.py
+++ b/airflow/serialization/pydantic/dag_run.py
@@ -74,7 +74,7 @@ class DagRunPydantic(BaseModelPydantic):
data_interval_end: Optional[datetime]
last_scheduling_decision: Optional[datetime]
dag_hash: Optional[str]
- updated_at: datetime
+ updated_at: Optional[datetime]
dag: Optional[PydanticDag]
consumed_dataset_events: List[DatasetEventPydantic] # noqa
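Both Pydantic model tweaks above follow the same pattern: DagModel.is_active and DagRun.updated_at can legitimately be None (presumably on objects not yet persisted or refreshed from the database; that reading is an assumption, not something the commit states), and a non-Optional annotation makes Pydantic reject such values during validation. A standalone illustration of the failure mode, not code from this commit:

    from typing import Optional

    from pydantic import BaseModel, ValidationError

    class Before(BaseModel):
        is_active: bool = False  # bare bool: an explicit None fails validation

    class After(BaseModel):
        is_active: Optional[bool] = False  # Optional: None is accepted

    try:
        Before(is_active=None)
    except ValidationError:
        pass  # pydantic raises for None on a non-Optional field
    print(After(is_active=None))  # ok: is_active=None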
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 1fbed14892..31c0ec6d72 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -447,7 +447,7 @@ class BaseSerialization:
json_pod = PodGenerator.serialize_pod(var)
return cls._encode(json_pod, type_=DAT.POD)
elif isinstance(var, DAG):
- return SerializedDAG.serialize_dag(var)
+ return cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG)
elif isinstance(var, Resources):
return var.to_dict()
elif isinstance(var, MappedOperator):
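With the envelope in place, BaseSerialization.deserialize() can dispatch a DAG found anywhere inside a nested structure back through SerializedDAG.deserialize_dag(), which is what the new "recursive behavior" checks in the tests below exercise. A hedged sketch (the object is illustrative):

    from datetime import datetime

    from airflow.models.dag import DAG
    from airflow.serialization.serialized_objects import BaseSerialization

    dag = DAG("roundtrip", start_date=datetime(2023, 1, 1))
    serialized = BaseSerialization.serialize([[dag]])  # nested, as in the tests
    # The inner DAG now carries its own "__type"/"__var" tag, so deserialize()
    # can rebuild it instead of passing through an untagged plain dict.
    deserialized = BaseSerialization.deserialize(serialized)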
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 05da33e667..d2a762efa7 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -2525,7 +2525,7 @@ def test_mapped_task_group_serde():
tg.expand(a=[".", ".."])
ser_dag = SerializedBaseOperator.serialize(dag)
- assert ser_dag["_task_group"]["children"]["tg"] == (
+ assert ser_dag[Encoding.VAR]["_task_group"]["children"]["tg"] == (
"taskgroup",
{
"_group_id": "tg",
@@ -2549,7 +2549,7 @@ def test_mapped_task_group_serde():
},
)
- serde_dag = SerializedDAG.deserialize_dag(ser_dag)
+ serde_dag = SerializedDAG.deserialize_dag(ser_dag[Encoding.VAR])
serde_tg = serde_dag.task_group.children["tg"]
assert isinstance(serde_tg, MappedTaskGroup)
assert serde_tg._expand_input == DictOfListsExpandInput({"a": [".", ".."]})
@@ -2568,7 +2568,7 @@ def test_mapped_task_with_operator_extra_links_property():
with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
_DummyOperator.partial(task_id="task").expand(inputs=[1, 2, 3])
serialized_dag = SerializedBaseOperator.serialize(dag)
- assert serialized_dag["tasks"][0] == {
+ assert serialized_dag[Encoding.VAR]["tasks"][0] == {
"task_id": "task",
"expand_input": {
"type": "dict-of-lists",
@@ -2589,5 +2589,5 @@ def test_mapped_task_with_operator_extra_links_property():
"_is_empty": False,
"_is_mapped": True,
}
- deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag)
+ deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag[Encoding.VAR])
assert deserialized_dag.task_dict["task"].operator_extra_links == [AirflowLink2()]
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index 6f0a6f8775..e1ff8bace4 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -18,16 +18,34 @@
from __future__ import annotations
import json
-from datetime import datetime
+from datetime import datetime, timedelta
import pytest
+from dateutil import relativedelta
+from kubernetes.client import models as k8s
+from pendulum.tz.timezone import Timezone
+from airflow.datasets import Dataset
from airflow.exceptions import SerializationError
-from airflow.models.taskinstance import TaskInstance
+from airflow.jobs.job import Job
+from airflow.models.connection import Connection
+from airflow.models.dag import DAG, DagModel
+from airflow.models.dagrun import DagRun
+from airflow.models.param import Param
+from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
+from airflow.models.xcom_arg import XComArg
from airflow.operators.empty import EmptyOperator
+from airflow.operators.python import PythonOperator
+from airflow.serialization.enums import DagAttributeTypes as DAT
+from airflow.serialization.pydantic.dag import DagModelPydantic
+from airflow.serialization.pydantic.dag_run import DagRunPydantic
+from airflow.serialization.pydantic.job import JobPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.settings import _ENABLE_AIP_44
-from airflow.utils.state import State
+from airflow.utils.operator_resources import Resources
+from airflow.utils.state import DagRunState, State
+from airflow.utils.task_group import TaskGroup
+from airflow.utils.types import DagRunType
from tests import REPO_ROOT
@@ -82,31 +100,182 @@ def test_strict_mode():
BaseSerialization.serialize(obj, strict=True) # now raises
+TI = TaskInstance(
+ task=EmptyOperator(task_id="test-task"),
+ run_id="fake_run",
+ state=State.RUNNING,
+)
+
+TI_WITH_START_DAY = TaskInstance(
+ task=EmptyOperator(task_id="test-task"),
+ run_id="fake_run",
+ state=State.RUNNING,
+)
+TI_WITH_START_DAY.start_date = datetime.utcnow()
+
+DAG_RUN = DagRun(
+ dag_id="test_dag_id",
+ run_id="test_dag_run_id",
+ run_type=DagRunType.MANUAL,
+ execution_date=datetime.utcnow(),
+ start_date=datetime.utcnow(),
+ external_trigger=True,
+ state=DagRunState.SUCCESS,
+)
+DAG_RUN.id = 1
+
+
+def equals(a, b) -> bool:
+ return a == b
+
+
+def equal_time(a: datetime, b: datetime) -> bool:
+ return a.strftime("%s") == b.strftime("%s")
+
+
+@pytest.mark.parametrize(
+ "input, encoded_type, cmp_func",
+ [
+ ("test_str", None, equals),
+ (1, None, equals),
+ (datetime.utcnow(), DAT.DATETIME, equal_time),
+ (timedelta(minutes=2), DAT.TIMEDELTA, equals),
+ (Timezone("UTC"), DAT.TIMEZONE, lambda a, b: a.name == b.name),
+ (relativedelta.relativedelta(hours=+1), DAT.RELATIVEDELTA, lambda a, b: a.hours == b.hours),
+ ({"test": "dict", "test-1": 1}, None, equals),
+ (["array_item", 2], None, equals),
+ (("tuple_item", 3), DAT.TUPLE, equals),
+ (set(["set_item", 3]), DAT.SET, equals),
+ (
+ k8s.V1Pod(
+ metadata=k8s.V1ObjectMeta(
+ name="test", annotations={"test": "annotation"},
creation_timestamp=datetime.utcnow()
+ )
+ ),
+ DAT.POD,
+ equals,
+ ),
+ (
+ DAG(
+ "fake-dag",
+ schedule="*/10 * * * *",
+ default_args={"depends_on_past": True},
+ start_date=datetime.utcnow(),
+ catchup=False,
+ ),
+ DAT.DAG,
+ lambda a, b: a.dag_id == b.dag_id and equal_time(a.start_date, b.start_date),
+ ),
+ (Resources(cpus=0.1, ram=2048), None, None),
+ (EmptyOperator(task_id="test-task"), None, None),
+ (TaskGroup(group_id="test-group", dag=DAG(dag_id="test_dag", start_date=datetime.now())), None, None),
+ (
+ Param("test", "desc"),
+ DAT.PARAM,
+ lambda a, b: a.value == b.value and a.description == b.description,
+ ),
+ (
+ XComArg(
+ operator=PythonOperator(
+ python_callable=int,
+ task_id="test_xcom_op",
+ do_xcom_push=True,
+ )
+ ),
+ DAT.XCOM_REF,
+ None,
+ ),
+ (Dataset(uri="test"), DAT.DATASET, equals),
+ (SimpleTaskInstance.from_ti(ti=TI), DAT.SIMPLE_TASK_INSTANCE, equals),
+ (
+ Connection(conn_id="TEST_ID", uri="mysql://"),
+ DAT.CONNECTION,
+ lambda a, b: a.get_uri() == b.get_uri(),
+ ),
+ ],
+)
+def test_serialize_deserialize(input, encoded_type, cmp_func):
+ from airflow.serialization.serialized_objects import BaseSerialization
+
+ serialized = BaseSerialization.serialize(input) # does not raise
+ json.dumps(serialized) # does not raise
+ if encoded_type is not None:
+ assert serialized["__type"] == encoded_type
+ assert serialized["__var"] is not None
+ if cmp_func is not None:
+ deserialized = BaseSerialization.deserialize(serialized)
+ assert cmp_func(input, deserialized)
+
+ # Verify recursive behavior
+ obj = [[input]]
+ serialized = BaseSerialization.serialize(obj) # does not raise
+ # Verify the result is JSON-serializable
+ json.dumps(serialized) # does not raise
+
+
@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
-def test_use_pydantic_models():
- """If use_pydantic_models=True the TaskInstance object should be
serialized to TaskInstancePydantic."""
+@pytest.mark.parametrize(
+ "input, pydantic_class, encoded_type, cmp_func",
+ [
+ (
+ Job(state=State.RUNNING, latest_heartbeat=datetime.utcnow()),
+ JobPydantic,
+ DAT.BASE_JOB,
+ lambda a, b: equal_time(a.latest_heartbeat, b.latest_heartbeat),
+ ),
+ (
+ TI_WITH_START_DAY,
+ TaskInstancePydantic,
+ DAT.TASK_INSTANCE,
+ lambda a, b: equal_time(a.start_date, b.start_date),
+ ),
+ (
+ DAG_RUN,
+ DagRunPydantic,
+ DAT.DAG_RUN,
+ lambda a, b: equal_time(a.execution_date, b.execution_date)
+ and equal_time(a.start_date, b.start_date),
+ ),
+ # DataSet is already serialized by non-Pydantic serialization. Is DatasetPydantic needed then?
+ # (
+ # Dataset(
+ # uri="foo://bar",
+ # extra={"foo": "bar"},
+ # ),
+ # DatasetPydantic,
+ # DAT.DATA_SET,
+ # lambda a, b: a.uri == b.uri and a.extra == b.extra,
+ # ),
+ (
+ DagModel(
+ dag_id="TEST_DAG_1",
+ fileloc="/tmp/dag_1.py",
+ schedule_interval="2 2 * * *",
+ is_paused=True,
+ ),
+ DagModelPydantic,
+ DAT.DAG_MODEL,
+ lambda a, b: a.fileloc == b.fileloc and a.schedule_interval == b.schedule_interval,
+ ),
+ ],
+)
+def test_serialize_deserialize_pydantic(input, pydantic_class, encoded_type, cmp_func):
+ """If use_pydantic_models=True the objects should be serialized to
Pydantic objects."""
from airflow.serialization.serialized_objects import BaseSerialization
- ti = TaskInstance(
- task=EmptyOperator(task_id="task"),
- run_id="run_id",
- state=State.RUNNING,
- )
- start_date = datetime.utcnow()
- ti.start_date = start_date
- obj = [[ti]] # nested to verify recursive behavior
-
- serialized = BaseSerialization.serialize(obj, use_pydantic_models=True) # does not raise
- deserialized = BaseSerialization.deserialize(serialized, use_pydantic_models=True) # does not raise
- assert isinstance(deserialized[0][0], TaskInstancePydantic)
-
- serialized_json = json.dumps(serialized) # does not raise
- deserialized_from_json = BaseSerialization.deserialize(
- json.loads(serialized_json), use_pydantic_models=True
- ) # does not raise
- assert isinstance(deserialized_from_json[0][0], TaskInstancePydantic)
- assert deserialized_from_json[0][0].start_date == start_date
+ serialized = BaseSerialization.serialize(input, use_pydantic_models=True) # does not raise
+ # Verify the result is JSON-serializable
+ json.dumps(serialized) # does not raise
+ assert serialized["__type"] == encoded_type
+ assert serialized["__var"] is not None
+ deserialized = BaseSerialization.deserialize(serialized, use_pydantic_models=True)
+ assert isinstance(deserialized, pydantic_class)
+ assert cmp_func(input, deserialized)
+
+ # Verify recursive behavior
+ obj = [[input]]
+ BaseSerialization.serialize(obj, use_pydantic_models=True) # does not raise
def test_serialized_mapped_operator_unmap(dag_maker):