This is an automated email from the ASF dual-hosted git repository.
bolke pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 96135580047 SerDe: Check more strictly for pydantic model (#56758)
96135580047 is described below
commit 96135580047e5e0c56e9f62d7d6051c86a286821
Author: Desdroid <[email protected]>
AuthorDate: Fri Oct 31 12:22:55 2025 +0100
SerDe: Check more strictly for pydantic model (#56758)
* Check more strictly for pydantic models. Ensure that pydantic dataclasses
are not detected as pydantic models. (#56739)
* Move test to other file. (#56739)
---
airflow-core/src/airflow/serialization/typing.py | 8 +++++++-
.../serialization/serializers/test_serializers.py | 22 ++++++++++++++++++++++
2 files changed, 29 insertions(+), 1 deletion(-)
diff --git a/airflow-core/src/airflow/serialization/typing.py b/airflow-core/src/airflow/serialization/typing.py
index a6169b23a78..35166710b78 100644
--- a/airflow-core/src/airflow/serialization/typing.py
+++ b/airflow-core/src/airflow/serialization/typing.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations
+from dataclasses import is_dataclass
from typing import Any
@@ -29,4 +30,9 @@ def is_pydantic_model(cls: Any) -> bool:
"""
# __pydantic_fields__ is always present on Pydantic V2 models and is a dict[str, FieldInfo]
# __pydantic_validator__ is an internal validator object, always set after model build
- return hasattr(cls, "__pydantic_fields__") and hasattr(cls, "__pydantic_validator__")
+ # Check if it is not a dataclass to prevent detecting pydantic dataclasses as pydantic models
+ return (
+ hasattr(cls, "__pydantic_fields__")
+ and hasattr(cls, "__pydantic_validator__")
+ and not is_dataclass(cls)
+ )
diff --git a/airflow-core/tests/unit/serialization/serializers/test_serializers.py b/airflow-core/tests/unit/serialization/serializers/test_serializers.py
index deabcf0f4f1..cd206d667d5 100644
--- a/airflow-core/tests/unit/serialization/serializers/test_serializers.py
+++ b/airflow-core/tests/unit/serialization/serializers/test_serializers.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import datetime
import decimal
from importlib import metadata
+from typing import ClassVar
from unittest.mock import patch
from zoneinfo import ZoneInfo
@@ -33,10 +34,12 @@ from packaging import version
from pendulum import DateTime
from pendulum.tz.timezone import FixedTimezone, Timezone
from pydantic import BaseModel, Field
+from pydantic.dataclasses import dataclass as pydantic_dataclass
from airflow.sdk.definitions.param import Param, ParamsDict
from airflow.serialization.serde import CLASSNAME, DATA, VERSION, _stringify, decode, deserialize, serialize
from airflow.serialization.serializers import builtin
+from airflow.utils.module_loading import qualname
from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker
@@ -68,6 +71,13 @@ class FooBarModel(BaseModel):
foo: str = Field()
+@pydantic_dataclass
+class PydanticDataclass:
+ __version__: ClassVar[int] = 1
+ a: int
+ b: str
+
+
@skip_if_force_lowest_dependencies_marker
class TestSerializers:
@pytest.mark.parametrize(
@@ -389,6 +399,18 @@ class TestSerializers:
with pytest.raises(TypeError, match=msg):
deserialize(klass, version, data)
+ def test_pydantic_dataclass(self):
+ orig = PydanticDataclass(a=5, b="SerDe Pydantic Dataclass Test")
+ serialized = serialize(orig)
+ assert orig.__version__ == serialized[VERSION]
+ assert qualname(orig) == serialized[CLASSNAME]
+ assert serialized[DATA]
+
+ decoded = deserialize(serialized)
+ assert decoded.a == orig.a
+ assert decoded.b == orig.b
+ assert type(decoded) is type(orig)
+
@pytest.mark.skipif(not PENDULUM3, reason="Test case for pendulum~=3")
@pytest.mark.parametrize(
"ser_value, expected",