This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 63e8727e23e Port over custom serializer tests from airflow core to task sdk (#59638)
63e8727e23e is described below

commit 63e8727e23e1574e56ef969de07c2f43de201eff
Author: Amogh Desai <[email protected]>
AuthorDate: Fri Dec 19 23:24:58 2025 +0530

    Port over custom serializer tests from airflow core to task sdk (#59638)
---
 .../unit/serialization/test_dag_serialization.py   | 12 ++++++
 task-sdk/tests/conftest.py                         | 13 ++++++
 task-sdk/tests/task_sdk/serde/test_serde.py        | 11 -----
 .../tests/task_sdk/serde}/test_serializers.py      | 48 ++++++++--------------
 4 files changed, 43 insertions(+), 41 deletions(-)

diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index 81f7f398ed3..9c802bdce98 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -4102,6 +4102,18 @@ class TestSchemaDefaults:
         overlap = optional_fields & required_fields
         assert not overlap, f"Optional fields should not overlap with required fields: {overlap}"
 
+    def test_json_schema_load_dag_schema_dict(self, monkeypatch):
+        """Test error handling when schema file is missing."""
+        from airflow.exceptions import AirflowException
+
+        monkeypatch.setattr(
+            "airflow.serialization.json_schema.pkgutil.get_data", lambda __name__, fname: None
+        )
+
+        with pytest.raises(AirflowException) as ctx:
+            load_dag_schema_dict()
+        assert "Schema file schema.json does not exists" in str(ctx.value)
+
 
 class TestDeserializationDefaultsResolution:
     """Test defaults resolution during deserialization."""
diff --git a/task-sdk/tests/conftest.py b/task-sdk/tests/conftest.py
index 55d0094089c..6ae329e5345 100644
--- a/task-sdk/tests/conftest.py
+++ b/task-sdk/tests/conftest.py
@@ -23,6 +23,8 @@ from typing import TYPE_CHECKING, Any, NoReturn, Protocol
 
 import pytest
 
+from tests_common.test_utils.config import conf_vars
+
 pytest_plugins = "tests_common.pytest_plugin"
 
 # Task SDK does not need access to the Airflow database
@@ -335,3 +337,14 @@ def make_ti_context_dict(make_ti_context: MakeTIContextCallable) -> MakeTIContex
         return context.model_dump(exclude_unset=True, mode="json")
 
     return _make_context_dict
+
+
+@pytest.fixture(scope="class", autouse=True)
+def allow_test_classes_deserialization():
+    """
+    Allow test classes and airflow SDK classes to be deserialized. In airflow-core tests, this is provided by
+    unit_tests.cfg which sets allowed_deserialization_classes = airflow.* tests.*
+    SDK tests may not inherit that configuration, so we explicitly allow airflow.sdk.* and tests.* here.
+    """
+    with conf_vars({("core", "allowed_deserialization_classes"): "airflow.sdk.* tests.*"}):
+        yield
diff --git a/task-sdk/tests/task_sdk/serde/test_serde.py b/task-sdk/tests/task_sdk/serde/test_serde.py
index 9b22191a2aa..e74d1fd5d0e 100644
--- a/task-sdk/tests/task_sdk/serde/test_serde.py
+++ b/task-sdk/tests/task_sdk/serde/test_serde.py
@@ -196,17 +196,6 @@ class C:
         return None
 
 
-@pytest.fixture(scope="class", autouse=True)
-def allow_test_classes_deserialization():
-    """
-    Allow test classes and airflow SDK classes to be deserialized. In airflow-core tests, this is provided by
-    unit_tests.cfg which sets allowed_deserialization_classes = airflow.* tests.*
-    SDK tests may not inherit that configuration, so we explicitly allow airflow.sdk.* and tests.* here.
-    """
-    with conf_vars({("core", "allowed_deserialization_classes"): "airflow.sdk.* tests.*"}):
-        yield
-
-
 @pytest.mark.usefixtures("recalculate_patterns")
 class TestSerDe:
     def test_ser_primitives(self):
diff --git a/airflow-core/tests/unit/serialization/serializers/test_serializers.py b/task-sdk/tests/task_sdk/serde/test_serializers.py
similarity index 93%
rename from airflow-core/tests/unit/serialization/serializers/test_serializers.py
rename to task-sdk/tests/task_sdk/serde/test_serializers.py
index 0873fef6d05..58a9e701866 100644
--- a/airflow-core/tests/unit/serialization/serializers/test_serializers.py
+++ b/task-sdk/tests/task_sdk/serde/test_serializers.py
@@ -36,10 +36,10 @@ from pendulum.tz.timezone import FixedTimezone, Timezone
 from pydantic import BaseModel, Field
 from pydantic.dataclasses import dataclass as pydantic_dataclass
 
-from airflow._shared.module_loading import qualname
+from airflow.sdk._shared.module_loading import qualname
 from airflow.sdk.definitions.param import Param, ParamsDict
 from airflow.sdk.serde import CLASSNAME, DATA, VERSION, decode, deserialize, serialize
-from airflow.serialization.serializers import builtin
+from airflow.sdk.serde.serializers import builtin
 
 from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker
 
@@ -170,12 +170,12 @@ class TestSerializers:
         }
 
     def test_bignum_serialize_non_decimal(self):
-        from airflow.serialization.serializers.bignum import serialize
+        from airflow.sdk.serde.serializers.bignum import serialize
 
         assert serialize(12345) == ("", "", 0, False)
 
     def test_bignum_deserialize_decimal(self):
-        from airflow.serialization.serializers.bignum import deserialize
+        from airflow.sdk.serde.serializers.bignum import deserialize
 
         res = deserialize(decimal.Decimal, 1, decimal.Decimal(12345))
         assert res == decimal.Decimal(12345)
@@ -198,7 +198,7 @@ class TestSerializers:
         ],
     )
     def test_bignum_deserialize_errors(self, klass, version, payload, msg):
-        from airflow.serialization.serializers.bignum import deserialize
+        from airflow.sdk.serde.serializers.bignum import deserialize
 
         with pytest.raises(TypeError, match=msg):
             deserialize(klass, version, payload)
@@ -221,7 +221,7 @@ class TestSerializers:
         assert type(value) is type(d)
 
     def test_numpy_serializers(self):
-        from airflow.serialization.serializers.numpy import serialize
+        from airflow.sdk.serde.serializers.numpy import serialize
 
         numpy_version = metadata.version("numpy")
         is_numpy_2 = version.parse(numpy_version).major == 2
@@ -240,7 +240,7 @@ class TestSerializers:
         ],
     )
     def test_numpy_deserialize_errors(self, klass, ver, value, msg):
-        from airflow.serialization.serializers.numpy import deserialize
+        from airflow.sdk.serde.serializers.numpy import deserialize
 
         with pytest.raises(TypeError, match=msg):
             deserialize(klass, ver, value)
@@ -260,7 +260,7 @@ class TestSerializers:
         assert i.equals(d)
 
     def test_pandas_serializers(self):
-        from airflow.serialization.serializers.pandas import serialize
+        from airflow.sdk.serde.serializers.pandas import serialize
 
         assert serialize(123) == ("", "", 0, False)
 
@@ -278,7 +278,7 @@ class TestSerializers:
         ],
     )
     def test_pandas_deserialize_errors(self, klass, version, data, msg):
-        from airflow.serialization.serializers.pandas import deserialize
+        from airflow.sdk.serde.serializers.pandas import deserialize
 
         with pytest.raises(TypeError, match=msg):
             deserialize(klass, version, data)
@@ -337,7 +337,7 @@ class TestSerializers:
             assert d._storage_options is None
 
     def test_deltalake_serialize_deserialize(self):
-        from airflow.serialization.serializers.deltalake import serialize
+        from airflow.sdk.serde.serializers.deltalake import serialize
 
         assert serialize(object()) == ("", "", 0, False)
 
@@ -359,14 +359,14 @@ class TestSerializers:
         ],
     )
     def test_deltalake_deserialize_errors(self, klass, version, payload, msg):
-        from airflow.serialization.serializers.deltalake import deserialize
+        from airflow.sdk.serde.serializers.deltalake import deserialize
 
         with pytest.raises(TypeError, match=msg):
             deserialize(klass, version, payload)
 
     def test_kubernetes_serializer(self, monkeypatch):
         from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
-        from airflow.serialization.serializers.kubernetes import serialize
+        from airflow.sdk.serde.serializers.kubernetes import serialize
 
         pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="foo"))
         monkeypatch.setattr(PodGenerator, "serialize_pod", lambda o: (_ for _ in ()).throw(Exception("fail")))
@@ -394,7 +394,7 @@ class TestSerializers:
         ],
     )
     def test_pydantic_deserialize_errors(self, klass, version, data, msg):
-        from airflow.serialization.serializers.pydantic import deserialize
+        from airflow.sdk.serde.serializers.pydantic import deserialize
 
         with pytest.raises(TypeError, match=msg):
             deserialize(klass, version, data)
@@ -560,17 +560,17 @@ class TestSerializers:
         assert deserialize(ser_value) == expected
 
     def test_timezone_serialize_fixed(self):
-        from airflow.serialization.serializers.timezone import serialize
+        from airflow.sdk.serde.serializers.timezone import serialize
 
         assert serialize(FixedTimezone(0)) == ("UTC", "pendulum.tz.timezone.FixedTimezone", 1, True)
 
     def test_timezone_serialize_no_name(self):
-        from airflow.serialization.serializers.timezone import serialize
+        from airflow.sdk.serde.serializers.timezone import serialize
 
         assert serialize(NoNameTZ()) == ("", "", 0, False)
 
     def test_timezone_deserialize_zoneinfo(self):
-        from airflow.serialization.serializers.timezone import deserialize
+        from airflow.sdk.serde.serializers.timezone import deserialize
 
         zi = deserialize(ZoneInfo, 1, "Asia/Taipei")
         assert isinstance(zi, ZoneInfo)
@@ -584,7 +584,7 @@ class TestSerializers:
         ],
     )
     def test_timezone_deserialize_errors(self, klass, version, data, msg):
-        from airflow.serialization.serializers.timezone import deserialize
+        from airflow.sdk.serde.serializers.timezone import deserialize
 
         with pytest.raises(TypeError, match=msg):
             deserialize(klass, version, data)
@@ -598,22 +598,10 @@ class TestSerializers:
         ],
     )
     def test_timezone_get_tzinfo_name(self, tz_obj, expected):
-        from airflow.serialization.serializers.timezone import _get_tzinfo_name
+        from airflow.sdk.serde.serializers.timezone import _get_tzinfo_name
 
         assert _get_tzinfo_name(tz_obj) == expected
 
-    def test_json_schema_load_dag_schema_dict(self, monkeypatch):
-        from airflow.exceptions import AirflowException
-        from airflow.serialization.json_schema import load_dag_schema_dict
-
-        monkeypatch.setattr(
-            "airflow.serialization.json_schema.pkgutil.get_data", lambda __name__, fname: None
-        )
-
-        with pytest.raises(AirflowException) as ctx:
-            load_dag_schema_dict()
-        assert "Schema file schema.json does not exists" in str(ctx.value)
-
     @pytest.mark.parametrize(
         ("klass", "version", "data"),
         [(tuple, 1, [11, 12]), (set, 1, [11, 12]), (frozenset, 1, [11, 12])],

Reply via email to