This is an automated email from the ASF dual-hosted git repository.
xqhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new c7edbb3dda1 Catch union-of-iterables case in get_yielded_type() (#34186)
c7edbb3dda1 is described below
commit c7edbb3dda13807c88a4397574b9b016d9dff19f
Author: Jack McCluskey <[email protected]>
AuthorDate: Wed Mar 5 17:24:56 2025 -0500
Catch union-of-iterables case in get_yielded_type() (#34186)
* Catch union-of-iterables case in get_yielded_type()
* add test case for mixed union
---
sdks/python/apache_beam/typehints/typehints.py | 5 +++++
sdks/python/apache_beam/typehints/typehints_test.py | 8 ++++++++
2 files changed, 13 insertions(+)
diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py
index d113f3bfa6b..51b1b1ca68d 100644
--- a/sdks/python/apache_beam/typehints/typehints.py
+++ b/sdks/python/apache_beam/typehints/typehints.py
@@ -1401,6 +1401,11 @@ def get_yielded_type(type_hint):
else: # TupleSequenceConstraint
return type_hint.inner_type
if is_consistent_with(type_hint, Iterable[Any]):
+ if isinstance(type_hint, UnionConstraint):
+ yielded_types = set()
+ for typ in type_hint.inner_types():
+ yielded_types.add(get_yielded_type(typ))
+ return Union[yielded_types]
return type_hint.inner_type
raise ValueError('%s is not iterable' % type_hint)
diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py
index 3104c406f15..a81da5abec4 100644
--- a/sdks/python/apache_beam/typehints/typehints_test.py
+++ b/sdks/python/apache_beam/typehints/typehints_test.py
@@ -1450,11 +1450,19 @@ class TestGetYieldedType(unittest.TestCase):
typehints.get_yielded_type(typehints.Tuple[int, str]))
self.assertEqual(int, typehints.get_yielded_type(typehints.Set[int]))
self.assertEqual(int, typehints.get_yielded_type(typehints.FrozenSet[int]))
+ self.assertEqual(
+ typehints.Union[int, str],
+ typehints.get_yielded_type(
+ typehints.Union[typehints.List[int], typehints.List[str]]))
def test_not_iterable(self):
with self.assertRaisesRegex(ValueError, r'not iterable'):
typehints.get_yielded_type(int)
+ def test_union_not_iterable(self):
+ with self.assertRaisesRegex(ValueError, r'not iterable'):
+ typehints.get_yielded_type(typehints.Union[int, typehints.List[int]])
+
class TestCoerceToKvType(TypeHintTestCase):
def test_coercion_success(self):