This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 63e8727e23e Port over custom serializer tests from airflow core to task sdk (#59638)
63e8727e23e is described below
commit 63e8727e23e1574e56ef969de07c2f43de201eff
Author: Amogh Desai <[email protected]>
AuthorDate: Fri Dec 19 23:24:58 2025 +0530
Port over custom serializer tests from airflow core to task sdk (#59638)
---
.../unit/serialization/test_dag_serialization.py | 12 ++++++
task-sdk/tests/conftest.py | 13 ++++++
task-sdk/tests/task_sdk/serde/test_serde.py | 11 -----
.../tests/task_sdk/serde}/test_serializers.py | 48 ++++++++--------------
4 files changed, 43 insertions(+), 41 deletions(-)
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index 81f7f398ed3..9c802bdce98 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -4102,6 +4102,18 @@ class TestSchemaDefaults:
overlap = optional_fields & required_fields
assert not overlap, f"Optional fields should not overlap with required
fields: {overlap}"
+ def test_json_schema_load_dag_schema_dict(self, monkeypatch):
+ """Test error handling when schema file is missing."""
+ from airflow.exceptions import AirflowException
+
+ monkeypatch.setattr(
+ "airflow.serialization.json_schema.pkgutil.get_data", lambda
__name__, fname: None
+ )
+
+ with pytest.raises(AirflowException) as ctx:
+ load_dag_schema_dict()
+ assert "Schema file schema.json does not exists" in str(ctx.value)
+
class TestDeserializationDefaultsResolution:
"""Test defaults resolution during deserialization."""
diff --git a/task-sdk/tests/conftest.py b/task-sdk/tests/conftest.py
index 55d0094089c..6ae329e5345 100644
--- a/task-sdk/tests/conftest.py
+++ b/task-sdk/tests/conftest.py
@@ -23,6 +23,8 @@ from typing import TYPE_CHECKING, Any, NoReturn, Protocol
import pytest
+from tests_common.test_utils.config import conf_vars
+
pytest_plugins = "tests_common.pytest_plugin"
# Task SDK does not need access to the Airflow database
@@ -335,3 +337,14 @@ def make_ti_context_dict(make_ti_context: MakeTIContextCallable) -> MakeTIContex
return context.model_dump(exclude_unset=True, mode="json")
return _make_context_dict
+
+
[email protected](scope="class", autouse=True)
+def allow_test_classes_deserialization():
+ """
+ Allow test classes and airflow SDK classes to be deserialized. In
airflow-core tests, this is provided by
+ unit_tests.cfg which sets allowed_deserialization_classes = airflow.*
tests.*
+ SDK tests may not inherit that configuration, so we explicitly allow
airflow.sdk.* and tests.* here.
+ """
+ with conf_vars({("core", "allowed_deserialization_classes"):
"airflow.sdk.* tests.*"}):
+ yield
diff --git a/task-sdk/tests/task_sdk/serde/test_serde.py b/task-sdk/tests/task_sdk/serde/test_serde.py
index 9b22191a2aa..e74d1fd5d0e 100644
--- a/task-sdk/tests/task_sdk/serde/test_serde.py
+++ b/task-sdk/tests/task_sdk/serde/test_serde.py
@@ -196,17 +196,6 @@ class C:
return None
[email protected](scope="class", autouse=True)
-def allow_test_classes_deserialization():
- """
- Allow test classes and airflow SDK classes to be deserialized. In
airflow-core tests, this is provided by
- unit_tests.cfg which sets allowed_deserialization_classes = airflow.*
tests.*
- SDK tests may not inherit that configuration, so we explicitly allow
airflow.sdk.* and tests.* here.
- """
- with conf_vars({("core", "allowed_deserialization_classes"):
"airflow.sdk.* tests.*"}):
- yield
-
-
@pytest.mark.usefixtures("recalculate_patterns")
class TestSerDe:
def test_ser_primitives(self):
diff --git a/airflow-core/tests/unit/serialization/serializers/test_serializers.py b/task-sdk/tests/task_sdk/serde/test_serializers.py
similarity index 93%
rename from airflow-core/tests/unit/serialization/serializers/test_serializers.py
rename to task-sdk/tests/task_sdk/serde/test_serializers.py
index 0873fef6d05..58a9e701866 100644
--- a/airflow-core/tests/unit/serialization/serializers/test_serializers.py
+++ b/task-sdk/tests/task_sdk/serde/test_serializers.py
@@ -36,10 +36,10 @@ from pendulum.tz.timezone import FixedTimezone, Timezone
from pydantic import BaseModel, Field
from pydantic.dataclasses import dataclass as pydantic_dataclass
-from airflow._shared.module_loading import qualname
+from airflow.sdk._shared.module_loading import qualname
from airflow.sdk.definitions.param import Param, ParamsDict
from airflow.sdk.serde import CLASSNAME, DATA, VERSION, decode, deserialize,
serialize
-from airflow.serialization.serializers import builtin
+from airflow.sdk.serde.serializers import builtin
from tests_common.test_utils.markers import
skip_if_force_lowest_dependencies_marker
@@ -170,12 +170,12 @@ class TestSerializers:
}
def test_bignum_serialize_non_decimal(self):
- from airflow.serialization.serializers.bignum import serialize
+ from airflow.sdk.serde.serializers.bignum import serialize
assert serialize(12345) == ("", "", 0, False)
def test_bignum_deserialize_decimal(self):
- from airflow.serialization.serializers.bignum import deserialize
+ from airflow.sdk.serde.serializers.bignum import deserialize
res = deserialize(decimal.Decimal, 1, decimal.Decimal(12345))
assert res == decimal.Decimal(12345)
@@ -198,7 +198,7 @@ class TestSerializers:
],
)
def test_bignum_deserialize_errors(self, klass, version, payload, msg):
- from airflow.serialization.serializers.bignum import deserialize
+ from airflow.sdk.serde.serializers.bignum import deserialize
with pytest.raises(TypeError, match=msg):
deserialize(klass, version, payload)
@@ -221,7 +221,7 @@ class TestSerializers:
assert type(value) is type(d)
def test_numpy_serializers(self):
- from airflow.serialization.serializers.numpy import serialize
+ from airflow.sdk.serde.serializers.numpy import serialize
numpy_version = metadata.version("numpy")
is_numpy_2 = version.parse(numpy_version).major == 2
@@ -240,7 +240,7 @@ class TestSerializers:
],
)
def test_numpy_deserialize_errors(self, klass, ver, value, msg):
- from airflow.serialization.serializers.numpy import deserialize
+ from airflow.sdk.serde.serializers.numpy import deserialize
with pytest.raises(TypeError, match=msg):
deserialize(klass, ver, value)
@@ -260,7 +260,7 @@ class TestSerializers:
assert i.equals(d)
def test_pandas_serializers(self):
- from airflow.serialization.serializers.pandas import serialize
+ from airflow.sdk.serde.serializers.pandas import serialize
assert serialize(123) == ("", "", 0, False)
@@ -278,7 +278,7 @@ class TestSerializers:
],
)
def test_pandas_deserialize_errors(self, klass, version, data, msg):
- from airflow.serialization.serializers.pandas import deserialize
+ from airflow.sdk.serde.serializers.pandas import deserialize
with pytest.raises(TypeError, match=msg):
deserialize(klass, version, data)
@@ -337,7 +337,7 @@ class TestSerializers:
assert d._storage_options is None
def test_deltalake_serialize_deserialize(self):
- from airflow.serialization.serializers.deltalake import serialize
+ from airflow.sdk.serde.serializers.deltalake import serialize
assert serialize(object()) == ("", "", 0, False)
@@ -359,14 +359,14 @@ class TestSerializers:
],
)
def test_deltalake_deserialize_errors(self, klass, version, payload, msg):
- from airflow.serialization.serializers.deltalake import deserialize
+ from airflow.sdk.serde.serializers.deltalake import deserialize
with pytest.raises(TypeError, match=msg):
deserialize(klass, version, payload)
def test_kubernetes_serializer(self, monkeypatch):
from airflow.providers.cncf.kubernetes.pod_generator import
PodGenerator
- from airflow.serialization.serializers.kubernetes import serialize
+ from airflow.sdk.serde.serializers.kubernetes import serialize
pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="foo"))
monkeypatch.setattr(PodGenerator, "serialize_pod", lambda o: (_ for _
in ()).throw(Exception("fail")))
@@ -394,7 +394,7 @@ class TestSerializers:
],
)
def test_pydantic_deserialize_errors(self, klass, version, data, msg):
- from airflow.serialization.serializers.pydantic import deserialize
+ from airflow.sdk.serde.serializers.pydantic import deserialize
with pytest.raises(TypeError, match=msg):
deserialize(klass, version, data)
@@ -560,17 +560,17 @@ class TestSerializers:
assert deserialize(ser_value) == expected
def test_timezone_serialize_fixed(self):
- from airflow.serialization.serializers.timezone import serialize
+ from airflow.sdk.serde.serializers.timezone import serialize
assert serialize(FixedTimezone(0)) == ("UTC",
"pendulum.tz.timezone.FixedTimezone", 1, True)
def test_timezone_serialize_no_name(self):
- from airflow.serialization.serializers.timezone import serialize
+ from airflow.sdk.serde.serializers.timezone import serialize
assert serialize(NoNameTZ()) == ("", "", 0, False)
def test_timezone_deserialize_zoneinfo(self):
- from airflow.serialization.serializers.timezone import deserialize
+ from airflow.sdk.serde.serializers.timezone import deserialize
zi = deserialize(ZoneInfo, 1, "Asia/Taipei")
assert isinstance(zi, ZoneInfo)
@@ -584,7 +584,7 @@ class TestSerializers:
],
)
def test_timezone_deserialize_errors(self, klass, version, data, msg):
- from airflow.serialization.serializers.timezone import deserialize
+ from airflow.sdk.serde.serializers.timezone import deserialize
with pytest.raises(TypeError, match=msg):
deserialize(klass, version, data)
@@ -598,22 +598,10 @@ class TestSerializers:
],
)
def test_timezone_get_tzinfo_name(self, tz_obj, expected):
- from airflow.serialization.serializers.timezone import _get_tzinfo_name
+ from airflow.sdk.serde.serializers.timezone import _get_tzinfo_name
assert _get_tzinfo_name(tz_obj) == expected
- def test_json_schema_load_dag_schema_dict(self, monkeypatch):
- from airflow.exceptions import AirflowException
- from airflow.serialization.json_schema import load_dag_schema_dict
-
- monkeypatch.setattr(
- "airflow.serialization.json_schema.pkgutil.get_data", lambda
__name__, fname: None
- )
-
- with pytest.raises(AirflowException) as ctx:
- load_dag_schema_dict()
- assert "Schema file schema.json does not exists" in str(ctx.value)
-
@pytest.mark.parametrize(
("klass", "version", "data"),
[(tuple, 1, [11, 12]), (set, 1, [11, 12]), (frozenset, 1, [11, 12])],