This is an automated email from the ASF dual-hosted git repository.
yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 9f5904b2b59 Support inferring schemas from Python dataclasses (#37728)
9f5904b2b59 is described below
commit 9f5904b2b59c6b75d9c161ee7bc1828cbd08e8f4
Author: Yi Hu <[email protected]>
AuthorDate: Wed Mar 11 12:16:49 2026 -0400
Support inferring schemas from Python dataclasses (#37728)
* Support inferring schemas from Python dataclasses
* Address comments; Revert native_type_compatibility _TypeMapEntry change
* Add unit test for named tuple and dataclasses encoded by RowCoder and passing through GBK
* Fix lint
---
sdks/python/apache_beam/coders/coder_impl.py | 6 +-
.../typehints/native_type_compatibility.py | 5 ++
sdks/python/apache_beam/typehints/row_type.py | 47 +++++------
sdks/python/apache_beam/typehints/row_type_test.py | 90 ++++++++++++++++++++++
sdks/python/apache_beam/typehints/schemas.py | 3 +-
sdks/python/apache_beam/typehints/schemas_test.py | 61 +++++++++++++++
6 files changed, 183 insertions(+), 29 deletions(-)
diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py
index 3e0b5218b16..b3e45bc7f35 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -30,6 +30,7 @@ For internal use only; no backwards-compatibility guarantees.
"""
# pytype: skip-file
+import dataclasses
import decimal
import enum
import itertools
@@ -67,11 +68,6 @@ from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import MIN_TIMESTAMP
from apache_beam.utils.timestamp import Timestamp
-try:
- import dataclasses
-except ImportError:
- dataclasses = None # type: ignore
-
try:
import dill
except ImportError:
diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py
index 345c04706d6..886b1505ffe 100644
--- a/sdks/python/apache_beam/typehints/native_type_compatibility.py
+++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py
@@ -21,6 +21,7 @@
import collections
import collections.abc
+import dataclasses
import logging
import sys
import types
@@ -175,6 +176,10 @@ def match_is_named_tuple(user_type):
hasattr(user_type, '__annotations__') and hasattr(user_type, '_fields'))
+def match_is_dataclass(user_type):
+ return dataclasses.is_dataclass(user_type) and isinstance(user_type, type)
+
+
def _match_is_optional(user_type):
return _match_is_union(user_type) and sum(
tp is type(None) for tp in _get_args(user_type)) == 1
diff --git a/sdks/python/apache_beam/typehints/row_type.py b/sdks/python/apache_beam/typehints/row_type.py
index 08838c84a05..6f96f6f64e3 100644
--- a/sdks/python/apache_beam/typehints/row_type.py
+++ b/sdks/python/apache_beam/typehints/row_type.py
@@ -19,6 +19,7 @@
from __future__ import annotations
+import dataclasses
from typing import Any
from typing import Dict
from typing import Optional
@@ -26,6 +27,7 @@ from typing import Sequence
from typing import Tuple
from apache_beam.typehints import typehints
+from apache_beam.typehints.native_type_compatibility import match_is_dataclass
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from apache_beam.typehints.schema_registry import SchemaTypeRegistry
@@ -56,18 +58,14 @@ class RowTypeConstraint(typehints.TypeConstraint):
for guidance on creating PCollections with inferred schemas.
Note RowTypeConstraint does not currently store arbitrary functions for
- converting to/from the user type. Instead, we only support ``NamedTuple``
- user types and make the follow assumptions:
+ converting to/from the user type. Instead, we support ``NamedTuple`` and
+ ``dataclasses`` user types and make the follow assumptions:
- The user type can be constructed with field values as arguments in order
(i.e. ``constructor(*field_values)``).
- Field values can be accessed from instances of the user type by attribute
(i.e. with ``getattr(obj, field_name)``).
- In the future we will add support for dataclasses
- ([#22085](https://github.com/apache/beam/issues/22085)) which also satisfy
- these assumptions.
-
The RowTypeConstraint constructor should not be called directly (even
internally to Beam). Prefer static methods ``from_user_type`` or
``from_fields``.
@@ -107,27 +105,30 @@ class RowTypeConstraint(typehints.TypeConstraint):
if match_is_named_tuple(user_type):
fields = [(name, user_type.__annotations__[name])
for name in user_type._fields]
-
- field_descriptions = getattr(user_type, '_field_descriptions', None)
-
- if _user_type_is_generated(user_type):
- return RowTypeConstraint.from_fields(
- fields,
- schema_id=getattr(user_type, _BEAM_SCHEMA_ID),
- schema_options=schema_options,
- field_options=field_options,
- field_descriptions=field_descriptions)
-
- # TODO(https://github.com/apache/beam/issues/22125): Add user API for
- # specifying schema/field options
- return RowTypeConstraint(
- fields=fields,
- user_type=user_type,
+ elif match_is_dataclass(user_type):
+ fields = [(field.name, field.type)
+ for field in dataclasses.fields(user_type)]
+ else:
+ return None
+
+ field_descriptions = getattr(user_type, '_field_descriptions', None)
+
+ if _user_type_is_generated(user_type):
+ return RowTypeConstraint.from_fields(
+ fields,
+ schema_id=getattr(user_type, _BEAM_SCHEMA_ID),
schema_options=schema_options,
field_options=field_options,
field_descriptions=field_descriptions)
- return None
+ # TODO(https://github.com/apache/beam/issues/22125): Add user API for
+ # specifying schema/field options
+ return RowTypeConstraint(
+ fields=fields,
+ user_type=user_type,
+ schema_options=schema_options,
+ field_options=field_options,
+ field_descriptions=field_descriptions)
@staticmethod
def from_fields(
diff --git a/sdks/python/apache_beam/typehints/row_type_test.py b/sdks/python/apache_beam/typehints/row_type_test.py
new file mode 100644
index 00000000000..73d76fee49c
--- /dev/null
+++ b/sdks/python/apache_beam/typehints/row_type_test.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for the Beam Row typing functionality."""
+
+import typing
+import unittest
+from dataclasses import dataclass
+
+import apache_beam as beam
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.typehints import row_type
+
+
+class RowTypeTest(unittest.TestCase):
+ @staticmethod
+ def _check_key_type_and_count(x) -> int:
+ key_type = type(x[0])
+ if not row_type._user_type_is_generated(key_type):
+ raise RuntimeError("Expect type after GBK to be generated user type")
+
+ return len(x[1])
+
+ def test_group_by_key_namedtuple(self):
+ MyNamedTuple = typing.NamedTuple(
+ "MyNamedTuple", [("id", int), ("name", str)])
+
+ beam.coders.typecoders.registry.register_coder(
+ MyNamedTuple, beam.coders.RowCoder)
+
+ def generate(num: int):
+ for i in range(100):
+ yield (MyNamedTuple(i, 'a'), num)
+
+ pipeline = TestPipeline(is_integration_test=False)
+
+ with pipeline as p:
+ result = (
+ p
+ | 'Create' >> beam.Create([i for i in range(10)])
+ | 'Generate' >> beam.ParDo(generate).with_output_types(
+ tuple[MyNamedTuple, int])
+ | 'GBK' >> beam.GroupByKey()
+ | 'Count Elements' >> beam.Map(self._check_key_type_and_count))
+ assert_that(result, equal_to([10] * 100))
+
+ def test_group_by_key_dataclass(self):
+ @dataclass
+ class MyDataClass:
+ id: int
+ name: str
+
+ beam.coders.typecoders.registry.register_coder(
+ MyDataClass, beam.coders.RowCoder)
+
+ def generate(num: int):
+ for i in range(100):
+ yield (MyDataClass(i, 'a'), num)
+
+ pipeline = TestPipeline(is_integration_test=False)
+
+ with pipeline as p:
+ result = (
+ p
+ | 'Create' >> beam.Create([i for i in range(10)])
+ | 'Generate' >> beam.ParDo(generate).with_output_types(
+ tuple[MyDataClass, int])
+ | 'GBK' >> beam.GroupByKey()
+ | 'Count Elements' >> beam.Map(self._check_key_type_and_count))
+ assert_that(result, equal_to([10] * 100))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py
index e9674fa5bc2..5dd8ff290c4 100644
--- a/sdks/python/apache_beam/typehints/schemas.py
+++ b/sdks/python/apache_beam/typehints/schemas.py
@@ -96,6 +96,7 @@ from apache_beam.typehints.native_type_compatibility import _match_is_optional
from apache_beam.typehints.native_type_compatibility import _safe_issubclass
from apache_beam.typehints.native_type_compatibility import convert_to_python_type
from apache_beam.typehints.native_type_compatibility import extract_optional_type
+from apache_beam.typehints.native_type_compatibility import match_is_dataclass
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY
from apache_beam.typehints.schema_registry import SchemaTypeRegistry
@@ -629,7 +630,7 @@ def schema_from_element_type(element_type: type) -> schema_pb2.Schema:
Returns schema as a list of (name, python_type) tuples"""
if isinstance(element_type, row_type.RowTypeConstraint):
return named_fields_to_schema(element_type._fields)
- elif match_is_named_tuple(element_type):
+ elif match_is_named_tuple(element_type) or match_is_dataclass(element_type):
if hasattr(element_type, row_type._BEAM_SCHEMA_ID):
# if the named tuple's schema is in registry, we just use it instead of
# regenerating one.
diff --git a/sdks/python/apache_beam/typehints/schemas_test.py b/sdks/python/apache_beam/typehints/schemas_test.py
index 73db06b9a8d..5a5d7396ab3 100644
--- a/sdks/python/apache_beam/typehints/schemas_test.py
+++ b/sdks/python/apache_beam/typehints/schemas_test.py
@@ -19,6 +19,7 @@
# pytype: skip-file
+import dataclasses
import itertools
import pickle
import unittest
@@ -388,6 +389,24 @@ class SchemaTest(unittest.TestCase):
self.assertIsInstance(roundtripped, row_type.RowTypeConstraint)
self.assert_namedtuple_equivalent(roundtripped.user_type, user_type)
+ def test_dataclass_roundtrip(self):
+ @dataclasses.dataclass
+ class SimpleDataclass:
+ id: np.int64
+ name: str
+
+ roundtripped = typing_from_runner_api(
+ typing_to_runner_api(
+ SimpleDataclass, schema_registry=SchemaTypeRegistry()),
+ schema_registry=SchemaTypeRegistry())
+
+ self.assertIsInstance(roundtripped, row_type.RowTypeConstraint)
+ # The roundtripped user_type is generated as a NamedTuple, so we can't test
+ # equivalence directly with the dataclass.
+ # Instead, let's verify annotations.
+ self.assertEqual(
+ roundtripped.user_type.__annotations__,
+ SimpleDataclass.__annotations__)
+
def test_row_type_constraint_to_schema(self):
result_type = typing_to_runner_api(
row_type.RowTypeConstraint.from_fields([
@@ -646,6 +665,48 @@ class SchemaTest(unittest.TestCase):
expected.row_type.schema.fields,
typing_to_runner_api(MyCuteClass).row_type.schema.fields)
+ def test_trivial_example_dataclass(self):
+ @dataclasses.dataclass
+ class MyCuteDataclass:
+ name: str
+ age: Optional[int]
+ interests: List[str]
+ height: float
+ blob: ByteString
+
+ expected = schema_pb2.FieldType(
+ row_type=schema_pb2.RowType(
+ schema=schema_pb2.Schema(
+ fields=[
+ schema_pb2.Field(
+ name='name',
+ type=schema_pb2.FieldType(
+ atomic_type=schema_pb2.STRING),
+ ),
+ schema_pb2.Field(
+ name='age',
+ type=schema_pb2.FieldType(
+ nullable=True, atomic_type=schema_pb2.INT64)),
+ schema_pb2.Field(
+ name='interests',
+ type=schema_pb2.FieldType(
+ array_type=schema_pb2.ArrayType(
+ element_type=schema_pb2.FieldType(
+ atomic_type=schema_pb2.STRING)))),
+ schema_pb2.Field(
+ name='height',
+ type=schema_pb2.FieldType(
+ atomic_type=schema_pb2.DOUBLE)),
+ schema_pb2.Field(
+ name='blob',
+ type=schema_pb2.FieldType(
+ atomic_type=schema_pb2.BYTES)),
+ ])))
+
+ self.assertEqual(
+ expected.row_type.schema.fields,
+ typing_to_runner_api(MyCuteDataclass).row_type.schema.fields)
+
def test_user_type_annotated_with_id_after_conversion(self):
MyCuteClass = NamedTuple('MyCuteClass', [
('name', str),