This is an automated email from the ASF dual-hosted git repository.
cvandermerwe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new fed1b551c35 Add support for tagged output type hints. (#37434)
fed1b551c35 is described below
commit fed1b551c353b588f764a293b09a55cce997859a
Author: claudevdm <[email protected]>
AuthorDate: Wed Feb 4 17:14:07 2026 -0500
Add support for tagged output type hints. (#37434)
* Add tagged typehint support.
* Only warn when a bare TaggedOutput is used
* Remove contains tagged output check.
* Mapped bare TaggedOutput to Any
* Extract tagged outputs after strip_iterable.
---
sdks/python/apache_beam/pvalue.py | 12 +-
sdks/python/apache_beam/transforms/core.py | 41 ++-
sdks/python/apache_beam/transforms/ptransform.py | 28 +-
sdks/python/apache_beam/typehints/decorators.py | 182 +++++++++--
.../apache_beam/typehints/decorators_test.py | 125 ++++++++
.../typehints/tagged_output_typehints_test.py | 356 +++++++++++++++++++++
.../python/apache_beam/typehints/typehints_test.py | 2 +-
7 files changed, 712 insertions(+), 34 deletions(-)
diff --git a/sdks/python/apache_beam/pvalue.py
b/sdks/python/apache_beam/pvalue.py
index 6621d96127d..d09a0040bd7 100644
--- a/sdks/python/apache_beam/pvalue.py
+++ b/sdks/python/apache_beam/pvalue.py
@@ -265,6 +265,8 @@ class DoOutputsTuple(object):
self._tags = tags
self._main_tag = main_tag
self._transform = transform
+ self._tagged_output_types = (
+ transform.get_type_hints().tagged_output_types() if transform else {})
self._allow_unknown_tags = (
not tags if allow_unknown_tags is None else allow_unknown_tags)
# The ApplyPTransform instance for the application of the multi FlatMap
@@ -322,7 +324,7 @@ class DoOutputsTuple(object):
pcoll = PCollection(
self._pipeline,
tag=tag,
- element_type=typehints.Any,
+ element_type=self._tagged_output_types.get(tag, typehints.Any),
is_bounded=is_bounded)
# Transfer the producer from the DoOutputsTuple to the resulting
# PCollection.
@@ -342,7 +344,11 @@ class DoOutputsTuple(object):
return pcoll
-class TaggedOutput(object):
+TagType = TypeVar('TagType', bound=str)
+ValueType = TypeVar('ValueType')
+
+
+class TaggedOutput(Generic[TagType, ValueType]):
"""An object representing a tagged value.
ParDo, Map, and FlatMap transforms can emit values on multiple outputs which
@@ -350,7 +356,7 @@ class TaggedOutput(object):
if it wants to emit on the main output and TaggedOutput objects
if it wants to emit a value on a specific tagged output.
"""
- def __init__(self, tag: str, value: Any) -> None:
+ def __init__(self, tag: TagType, value: ValueType) -> None:
if not isinstance(tag, str):
raise TypeError(
'Attempting to create a TaggedOutput with non-string tag %s' %
diff --git a/sdks/python/apache_beam/transforms/core.py
b/sdks/python/apache_beam/transforms/core.py
index 128a070e2ac..6d2552a2a6a 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -824,6 +824,7 @@ class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
process_type_hints = process_type_hints.strip_iterable()
except ValueError as e:
raise ValueError('Return value not iterable: %s: %s' % (self, e))
+ process_type_hints = process_type_hints.extract_tagged_outputs()
# Prefer class decorator type hints for backwards compatibility.
return get_type_hints(self.__class__).with_defaults(process_type_hints)
@@ -1039,6 +1040,7 @@ class CallableWrapperDoFn(DoFn):
raise TypeCheckError(
'Return value not iterable: %s: %s' %
(self.display_data()['fn'].value, e))
+ type_hints = type_hints.extract_tagged_outputs()
return type_hints
def infer_output_type(self, input_type):
@@ -1834,6 +1836,17 @@ class ParDo(PTransformWithSideInputs):
raise ValueError(
'Main output tag %r must be different from side output tags %r.' %
(main, tags))
+ type_hints = self.get_type_hints()
+ declared_tags = set(type_hints.tagged_output_types().keys())
+ requested_tags = set(tags)
+
+ unknown = requested_tags - declared_tags
+ if unknown and declared_tags: # Only warn if type hints exist
+ logging.warning(
+ "Tags %s requested in with_outputs() but not declared "
+ "in type hints. Declared tags: %s",
+ unknown,
+ declared_tags)
return _MultiParDo(self, tags, main, allow_unknown_tags)
def _do_fn_info(self):
@@ -2120,8 +2133,14 @@ def Map(fn, *args, **kwargs): # pylint:
disable=invalid-name
wrapper)
output_hint = type_hints.simple_output_type(label)
if output_hint:
+ tagged = {
+ k: typehints.Iterable[v]
+ for k, v in type_hints.tagged_output_types().items()
+ }
wrapper = with_output_types(
- typehints.Iterable[_strip_output_annotations(output_hint)])(
+ typehints.Iterable[_strip_output_annotations(
+ output_hint, strip_tagged_output=False)],
+ **tagged)(
wrapper)
# pylint: disable=protected-access
wrapper._argspec_fn = fn
@@ -2189,8 +2208,14 @@ def MapTuple(fn, *args, **kwargs): # pylint:
disable=invalid-name
pass
output_hint = type_hints.simple_output_type(label)
if output_hint:
+ tagged = {
+ k: typehints.Iterable[v]
+ for k, v in type_hints.tagged_output_types().items()
+ }
wrapper = with_output_types(
- typehints.Iterable[_strip_output_annotations(output_hint)])(
+ typehints.Iterable[_strip_output_annotations(
+ output_hint, strip_tagged_output=False)],
+ **tagged)(
wrapper)
# Replace the first (args) component.
@@ -2261,7 +2286,10 @@ def FlatMapTuple(fn, *args, **kwargs): # pylint:
disable=invalid-name
pass
output_hint = type_hints.simple_output_type(label)
if output_hint:
- wrapper =
with_output_types(_strip_output_annotations(output_hint))(wrapper)
+ wrapper = with_output_types(
+ _strip_output_annotations(output_hint, strip_tagged_output=False),
+ **type_hints.tagged_output_types())(
+ wrapper)
# Replace the first (args) component.
modified_arg_names = ['tuple_element'] + arg_names[-num_defaults:]
@@ -4222,12 +4250,15 @@ class Impulse(PTransform):
return Impulse()
-def _strip_output_annotations(type_hint):
+def _strip_output_annotations(type_hint, strip_tagged_output=True):
# TODO(robertwb): These should be parameterized types that the
# type inferencer understands.
# Then we can replace them with the correct element types instead of
# using Any. Refer to typehints.WindowedValue when doing this.
- annotations = (TimestampedValue, WindowedValue, pvalue.TaggedOutput)
+ annotations = [TimestampedValue, WindowedValue]
+ if strip_tagged_output:
+ annotations.append(pvalue.TaggedOutput)
+ annotations = tuple(annotations)
contains_annotation = False
diff --git a/sdks/python/apache_beam/transforms/ptransform.py
b/sdks/python/apache_beam/transforms/ptransform.py
index 94e9a0644d0..d5985b6212d 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -414,12 +414,15 @@ class PTransform(WithTypeHints, HasDisplayData,
Generic[InputT, OutputT]):
input_type_hint, 'Type hints for a PTransform')
return super().with_input_types(input_type_hint)
- def with_output_types(self, type_hint):
+ def with_output_types(self, type_hint, **tagged_type_hints):
"""Annotates the output type of a :class:`PTransform` with a type-hint.
Args:
type_hint (type): An instance of an allowed built-in type, a custom
class,
- or a :class:`~apache_beam.typehints.typehints.TypeConstraint`.
+ or a :class:`~apache_beam.typehints.typehints.TypeConstraint`. This is
+ the type hint for the main output.
+ **tagged_type_hints: Type hints for tagged outputs. Each keyword argument
+ specifies the type for a tagged output e.g., ``errors=str``.
Raises:
TypeError: If **type_hint** is not a valid type-hint. See
@@ -430,10 +433,22 @@ class PTransform(WithTypeHints, HasDisplayData,
Generic[InputT, OutputT]):
PTransform: A reference to the instance of this particular
:class:`PTransform` object. This allows chaining type-hinting related
methods.
+
+ Example::
+ result = pcoll | beam.ParDo(MyDoFn()).with_output_types(
+ int, # main output type
+ errors=str, # 'errors' tagged output type
+ warnings=str # 'warnings' tagged output type
+ ).with_outputs('errors', 'warnings', main='main')
"""
type_hint = native_type_compatibility.convert_to_beam_type(type_hint)
validate_composite_type_param(type_hint, 'Type hints for a PTransform')
- return super().with_output_types(type_hint)
+ for tag, hint in tagged_type_hints.items():
+ tagged_type_hints[tag] = native_type_compatibility.convert_to_beam_type(
+ hint)
+ validate_composite_type_param(
+ tagged_type_hints[tag], f'Tagged output type hint for {tag!r}')
+ return super().with_output_types(type_hint, **tagged_type_hints)
def with_resource_hints(self, **kwargs): # type: (...) -> PTransform
"""Adds resource hints to the :class:`PTransform`.
@@ -479,10 +494,11 @@ class PTransform(WithTypeHints, HasDisplayData,
Generic[InputT, OutputT]):
if hints is None or not any(hints):
return
arg_hints, kwarg_hints = hints
- if arg_hints and kwarg_hints:
+ # Output types can have kwargs for tagged output types.
+ if arg_hints and kwarg_hints and input_or_output != 'output':
raise TypeCheckError(
- 'PTransform cannot have both positional and keyword type hints '
- 'without overriding %s._type_check_%s()' %
+ 'PTransform cannot have both positional and keyword input type hints'
+ ' without overriding %s._type_check_%s()' %
(self.__class__, input_or_output))
root_hint = (
arg_hints[0] if len(arg_hints) == 1 else arg_hints or kwarg_hints)
diff --git a/sdks/python/apache_beam/typehints/decorators.py
b/sdks/python/apache_beam/typehints/decorators.py
index 2d2f7981dd2..e393113c002 100644
--- a/sdks/python/apache_beam/typehints/decorators.py
+++ b/sdks/python/apache_beam/typehints/decorators.py
@@ -89,12 +89,16 @@ from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
+from typing import Literal
from typing import NamedTuple
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union
+from typing import get_args
+from typing import get_origin
+from apache_beam.pvalue import TaggedOutput
from apache_beam.typehints import native_type_compatibility
from apache_beam.typehints import typehints
from apache_beam.typehints.native_type_compatibility import
convert_to_beam_type
@@ -180,6 +184,83 @@ def disable_type_annotations():
TRACEBACK_LIMIT = 5
+# Sentinel returned by _extract_tagged_from_type when the inspected type has
+# no main (untagged) component, i.e. it consisted only of TaggedOutput members.
+_NO_MAIN_TYPE = object()
+
+
+def _tag_and_type(t):
+ """Extract tag name and value type from TaggedOutput[Literal['tag'], Type].
+
+ Returns raw Python types - conversion to beam types happens in
+ _extract_output_types.
+ """
+ args = get_args(t)
+ if len(args) != 2:
+ raise TypeError(
+ f"TaggedOutput expects 2 type parameters, got {len(args)}: {t}")
+
+ literal_type, value_type = args
+
+ if get_origin(literal_type) is not Literal:
+ raise TypeError(
+ f"First type parameter of TaggedOutput must be Literal['tag_name'], "
+ f"got {literal_type}. Example: TaggedOutput[Literal['errors'], str]")
+
+ tag_string = get_args(literal_type)[0]
+ return tag_string, value_type
+
+
+def _extract_tagged_from_type(beam_type):
+ """Extract tagged output types from a Beam type (post-convert_to_beam_type).
+
+ Called after the Iterable wrapper has been removed.
+ At this point, the type has already been through convert_to_beam_type, so
+ unions are typehints.UnionConstraint (not typing.Union), but
+ TaggedOutput[Literal['tag'], T] passes through unchanged as a typing
+ generic alias.
+
+ Returns:
+ (clean_type, tagged_dict) where clean_type is the type without TaggedOutput
+ members (or _NO_MAIN_TYPE if no main type), and tagged_dict maps tag names
+ to their Beam types.
+ """
+ # Single TaggedOutput[Literal['tag'], Type]
+ if get_origin(beam_type) is TaggedOutput:
+ tag, typ = _tag_and_type(beam_type)
+ return _NO_MAIN_TYPE, {tag: convert_to_beam_type(typ)}
+
+ # Bare TaggedOutput (unparameterized)
+ if beam_type is TaggedOutput:
+ logging.warning(
+ "TaggedOutput in return type must include type parameters: "
+ "TaggedOutput[Literal['tag_name'], ValueType]. "
+ "Bare TaggedOutput will be ignored.")
+ return _NO_MAIN_TYPE, {}
+
+ if not isinstance(beam_type, typehints.UnionHint.UnionConstraint):
+ return beam_type, {}
+
+ # UnionConstraint containing TaggedOutput members
+ main_types = []
+ tagged = {}
+ for member in beam_type.union_types:
+ if get_origin(member) is TaggedOutput:
+ tag, typ = _tag_and_type(member)
+ tagged[tag] = convert_to_beam_type(typ)
+ elif member is TaggedOutput:
+ logging.warning(
+ "TaggedOutput in return type must include type parameters: "
+ "TaggedOutput[Literal['tag_name'], ValueType]. "
+ "Bare TaggedOutput will be ignored.")
+ else:
+ main_types.append(member)
+ if not tagged and len(main_types) == len(beam_type.union_types):
+ return beam_type, {}
+ if not main_types:
+ return _NO_MAIN_TYPE, tagged
+ elif len(main_types) == 1:
+ return main_types[0], tagged
+ else:
+ return typehints.Union[tuple(main_types)], tagged
class IOTypeHints(NamedTuple):
@@ -273,6 +354,7 @@ class IOTypeHints(NamedTuple):
param.VAR_POSITIONAL], \
'Unsupported Parameter kind: %s' % param.kind
input_args.append(convert_to_beam_type(param.annotation))
+
output_args = []
if signature.return_annotation != signature.empty:
output_args.append(convert_to_beam_type(signature.return_annotation))
@@ -308,18 +390,24 @@ class IOTypeHints(NamedTuple):
def simple_output_type(self, context):
if self._has_output_types():
- args, kwargs = self.output_types
- if len(args) != 1 or kwargs:
+ args, _ = self.output_types
+ # Note: kwargs may contain tagged output types, which are ignored here.
+ # Use tagged_output_types() to access those.
+ if len(args) != 1:
raise TypeError(
'Expected single output type hint for %s but got: %s' %
(context, self.output_types))
return args[0]
+ def tagged_output_types(self):
+ if not self._has_output_types():
+ return {}
+ _, tagged_output_types = self.output_types
+ return tagged_output_types
+
def has_simple_output_type(self):
"""Whether there's a single positional output type."""
- return (
- self.output_types and len(self.output_types[0]) == 1 and
- not self.output_types[1])
+ return (self.output_types and len(self.output_types[0]) == 1)
def strip_pcoll(self):
from apache_beam.pipeline import Pipeline
@@ -413,6 +501,7 @@ class IOTypeHints(NamedTuple):
if self.output_types is None or not self.has_simple_output_type():
return self
output_type = self.output_types[0][0]
+ tagged_output_types = self.output_types[1]
if output_type is None or isinstance(output_type, type(None)):
return self
# If output_type == Optional[T]: output_type = T.
@@ -427,14 +516,51 @@ class IOTypeHints(NamedTuple):
if isinstance(output_type, typehints.TypeVariable):
# We don't know what T yields, so we just assume Any.
return self._replace(
- output_types=((typehints.Any, ), {}),
+ output_types=((typehints.Any, ), tagged_output_types),
origin=self._make_origin([self], tb=False, msg=['strip_iterable()']))
yielded_type = typehints.get_yielded_type(output_type)
+
+ # Also strip Iterable from tagged output types (e.g. from Map/MapTuple
+ # which wrap both main and tagged types in Iterable).
+ stripped_tags = {
+ tag: typehints.get_yielded_type(hint)
+ for tag, hint in tagged_output_types.items()
+ }
+
return self._replace(
- output_types=((yielded_type, ), {}),
+ output_types=((yielded_type, ), stripped_tags),
origin=self._make_origin([self], tb=False, msg=['strip_iterable()']))
+ def extract_tagged_outputs(self):
+ """Extract TaggedOutput types from the main output type into kwargs.
+
+ For annotation style (e.g. -> Iterable[int | TaggedOutput[...]]),
+ TaggedOutput stays embedded in the main type through convert_to_beam_type
+ and strip_iterable. This method extracts those TaggedOutput members into
+ the tagged output kwargs dict.
+
+ Should be called after strip_iterable().
+
+ Returns:
+ A copy of this instance with TaggedOutput members moved from the main
+ output type into the output kwargs dict.
+ """
+ if self.output_types is None or not self.has_simple_output_type():
+ return self
+ output_type = self.output_types[0][0]
+
+ clean_type, extracted_tags = _extract_tagged_from_type(output_type)
+ if not extracted_tags:
+ return self
+ if clean_type is _NO_MAIN_TYPE:
+ clean_type = typehints.Any
+ return self._replace(
+ output_types=((clean_type, ), extracted_tags),
+ origin=self._make_origin([self],
+ tb=False,
+ msg=['extract_tagged_outputs()']))
+
def with_defaults(self, hints: Optional['IOTypeHints']) -> 'IOTypeHints':
if not hints:
return self
@@ -782,7 +908,7 @@ def with_input_types(*positional_hints: Any,
def with_output_types(*return_type_hint: Any,
- **kwargs: Any) -> Callable[[T], T]:
+ **tagged_type_hints: Any) -> Callable[[T], T]:
"""A decorator that type-checks defined type-hints for return values(s).
This decorator will type-check the return value(s) of the decorated function.
@@ -822,18 +948,34 @@ def with_output_types(*return_type_hint: Any,
def negate(p):
return not p if p else p
+ For DoFns with tagged outputs, you can specify type hints for each tag:
+
+ .. testcode::
+ from apache_beam.typehints import with_input_types, with_output_types
+ @with_output_types(int, errors=str, warnings=str)
+ class MyDoFn(beam.DoFn):
+ def process(self, element):
+ if element < 0:
+ yield beam.pvalue.TaggedOutput('errors', 'Negative value')
+ elif element == 0:
+ yield beam.pvalue.TaggedOutput('warnings', 'Zero value')
+ else:
+ yield element
+
Args:
*return_type_hint: A type-hint specifying the proper return type of the
function. This argument should either be a built-in Python type or an
instance of a :class:`~apache_beam.typehints.typehints.TypeConstraint`
created by 'indexing' a
:class:`~apache_beam.typehints.typehints.CompositeTypeHint`.
- **kwargs: Not used.
+ **tagged_type_hints: Type hints for tagged outputs. Each keyword argument
+ specifies the type for a tagged output, e.g., ``errors=str``.
+
Raises:
- :class:`ValueError`: If any kwarg parameters are passed in,
- or the length of **return_type_hint** is greater than ``1``. Or if the
- inner wrapper function isn't passed a function object.
+ :class:`ValueError`: If the length of **return_type_hint** is greater
+ than ``1``. Or if the inner wrapper function isn't passed a function
+ object.
:class:`TypeCheckError`: If the **return_type_hint** object is
in invalid type-hint.
@@ -841,11 +983,6 @@ def with_output_types(*return_type_hint: Any,
The original function decorated such that it enforces type-hint constraints
for all return values.
"""
- if kwargs:
- raise ValueError(
- "All arguments for the 'returns' decorator must be "
- "positional arguments.")
-
if len(return_type_hint) != 1:
raise ValueError(
"'returns' accepts only a single positional argument. In "
@@ -854,13 +991,20 @@ def with_output_types(*return_type_hint: Any,
return_type_hint = native_type_compatibility.convert_to_beam_type(
return_type_hint[0])
-
validate_composite_type_param(
return_type_hint, error_msg_prefix='All type hint arguments')
+ converted_tag_hints = {}
+ for tag, hint in tagged_type_hints.items():
+ converted_hint = native_type_compatibility.convert_to_beam_type(hint)
+ validate_composite_type_param(
+ converted_hint, 'Tagged output type hint for %r' % tag)
+ converted_tag_hints[tag] = converted_hint
+
def annotate_output_types(f):
th = getattr(f, '_type_hints', IOTypeHints.empty())
- f._type_hints = th.with_output_types(return_type_hint) # pylint:
disable=protected-access
+ f._type_hints = th.with_output_types( # pylint: disable=protected-access
+ return_type_hint, **converted_tag_hints)
return f
return annotate_output_types
diff --git a/sdks/python/apache_beam/typehints/decorators_test.py
b/sdks/python/apache_beam/typehints/decorators_test.py
index a2909b4e545..95745f4e3d8 100644
--- a/sdks/python/apache_beam/typehints/decorators_test.py
+++ b/sdks/python/apache_beam/typehints/decorators_test.py
@@ -24,6 +24,7 @@ import typing
import unittest
from apache_beam import Map
+from apache_beam.pvalue import TaggedOutput
from apache_beam.typehints import Any
from apache_beam.typehints import Dict
from apache_beam.typehints import List
@@ -33,6 +34,7 @@ from apache_beam.typehints import TypeVariable
from apache_beam.typehints import WithTypeHints
from apache_beam.typehints import decorators
from apache_beam.typehints import typehints
+from apache_beam.typehints.native_type_compatibility import
convert_to_beam_type
T = TypeVariable('T')
# Name is 'T' so it converts to a beam type with the same name.
@@ -262,6 +264,63 @@ class IOTypeHintsTest(unittest.TestCase):
th = decorators.IOTypeHints.from_callable(fn)
self.assertRegex(th.debug_str(), r'unknown')
+ def test_from_callable_no_tagged_output(self):
+ def fn(x: int) -> str:
+ return str(x)
+
+ th = decorators.IOTypeHints.from_callable(fn)
+ self.assertEqual(th.input_types, ((int, ), {}))
+ self.assertEqual(th.output_types, ((str, ), {}))
+
+ def fn2(x: int) -> typing.Iterable[str]:
+ yield str(x)
+
+ th = decorators.IOTypeHints.from_callable(fn2)
+ self.assertEqual(th.input_types, ((int, ), {}))
+ self.assertEqual(th.output_types, ((typehints.Iterable[str], ), {}))
+
+ def test_from_callable_tagged_output_union(self):
+ """Tagged types are NOT extracted in from_callable. They stay embedded
+ in the main type and are extracted later in strip_iterable()."""
+ def fn(
+ x: int
+ ) -> int | str | TaggedOutput[typing.Literal['errors'], float
+ | str] | TaggedOutput[
+ typing.Literal['warnings'], str]:
+ return x
+
+ th = decorators.IOTypeHints.from_callable(fn)
+ self.assertEqual(th.input_types, ((int, ), {}))
+ # TaggedOutput members are preserved in the union no extraction yet.
+ output_type = th.output_types[0][0]
+ self.assertIsInstance(output_type, typehints.UnionConstraint)
+ self.assertEqual(th.output_types[1], {})
+
+ def test_from_callable_tagged_output_iterable(self):
+ """Tagged types inside Iterable are preserved until strip_iterable."""
+ def fn(
+ x: int
+ ) -> typing.Iterable[int | TaggedOutput[typing.Literal['errors'], str]]:
+ yield x
+
+ th = decorators.IOTypeHints.from_callable(fn)
+ self.assertEqual(th.input_types, ((int, ), {}))
+ # The full Iterable[Union[int, TaggedOutput[...]]] is preserved.
+ output_type = th.output_types[0][0]
+ self.assertIsInstance(output_type, typehints.IterableTypeConstraint)
+ self.assertEqual(th.output_types[1], {})
+
+ def test_from_callable_tagged_output_only(self):
+ """A standalone TaggedOutput annotation passes through from_callable."""
+ def fn(x: int) -> TaggedOutput[typing.Literal['errors'], str]:
+ pass
+
+ th = decorators.IOTypeHints.from_callable(fn)
+ self.assertEqual(th.input_types, ((int, ), {}))
+ # TaggedOutput[...] passes through convert_to_beam_type unchanged.
+ self.assertIs(typing.get_origin(th.output_types[0][0]), TaggedOutput)
+ self.assertEqual(th.output_types[1], {})
+
def test_getcallargs_forhints(self):
def fn(
a: int,
@@ -426,5 +485,71 @@ class DecoratorsTest(unittest.TestCase):
_ = ['a', 'b', 'c'] | Map(fn2) # Doesn't raise - no input type hints.
+class ExtractTaggedFromTypeTest(unittest.TestCase):
+ """Tests for _extract_tagged_from_type (Beam-level type extraction)."""
+ def test_simple_type_no_extraction(self):
+ main, tagged = decorators._extract_tagged_from_type(int)
+ self.assertEqual(main, int)
+ self.assertEqual(tagged, {})
+
+ def test_beam_union_no_tagged(self):
+ t = typehints.Union[int, str]
+ main, tagged = decorators._extract_tagged_from_type(t)
+ self.assertEqual(main, t)
+ self.assertEqual(tagged, {})
+
+ def test_standalone_tagged_output(self):
+ t = TaggedOutput[typing.Literal['errors'], str]
+ main, tagged = decorators._extract_tagged_from_type(t)
+ self.assertIs(main, decorators._NO_MAIN_TYPE)
+ self.assertEqual(tagged, {'errors': str})
+
+ def test_beam_union_with_tagged(self):
+ t = convert_to_beam_type(int | TaggedOutput[typing.Literal['errors'], str])
+ main, tagged = decorators._extract_tagged_from_type(t)
+ self.assertEqual(main, int)
+ self.assertEqual(tagged, {'errors': str})
+
+ def test_beam_union_multiple_tagged(self):
+ t = convert_to_beam_type(
+ int | TaggedOutput[typing.Literal['errors'], str]
+ | TaggedOutput[typing.Literal['warnings'], str])
+ main, tagged = decorators._extract_tagged_from_type(t)
+ self.assertEqual(main, int)
+ self.assertEqual(tagged, {'errors': str, 'warnings': str})
+
+ def test_beam_union_multiple_main_types(self):
+ t = convert_to_beam_type(
+ int | str | TaggedOutput[typing.Literal['errors'], bytes])
+ main, tagged = decorators._extract_tagged_from_type(t)
+ self.assertIsInstance(main, typehints.UnionConstraint)
+ self.assertIn(int, main.union_types)
+ self.assertIn(str, main.union_types)
+ self.assertEqual(tagged, {'errors': bytes})
+
+ def test_beam_union_tagged_only(self):
+ t = convert_to_beam_type(
+ TaggedOutput[typing.Literal['errors'], str]
+ | TaggedOutput[typing.Literal['warnings'], int])
+ main, tagged = decorators._extract_tagged_from_type(t)
+ self.assertIs(main, decorators._NO_MAIN_TYPE)
+ self.assertEqual(tagged, {'errors': str, 'warnings': int})
+
+ def test_bare_tagged_output_standalone(self):
+ with self.assertLogs(level='WARNING') as cm:
+ main, tagged = decorators._extract_tagged_from_type(TaggedOutput)
+ self.assertIn('Bare TaggedOutput will be ignored', cm.output[0])
+ self.assertIs(main, decorators._NO_MAIN_TYPE)
+ self.assertEqual(tagged, {})
+
+ def test_bare_tagged_output_in_union(self):
+ with self.assertLogs(level='WARNING') as cm:
+ t = convert_to_beam_type(str | TaggedOutput)
+ main, tagged = decorators._extract_tagged_from_type(t)
+ self.assertIn('Bare TaggedOutput will be ignored', cm.output[0])
+ self.assertEqual(main, str)
+ self.assertEqual(tagged, {})
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py
b/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py
new file mode 100644
index 00000000000..c06f68fb88a
--- /dev/null
+++ b/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py
@@ -0,0 +1,356 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for tagged output type hints.
+
+This tests the implementation of type hints for tagged outputs via three
+styles:
+
+1. Decorator style:
+ @with_output_types(int, errors=str, warnings=str)
+ class MyDoFn(beam.DoFn):
+ ...
+
+2. Method chain style:
+ beam.ParDo(MyDoFn()).with_output_types(int, errors=str)
+
+3. Function annotation style:
+ def fn(element) -> int | TaggedOutput[Literal['errors'], str]:
+ ...
+"""
+
+# pytype: skip-file
+
+import unittest
+from typing import Iterable
+from typing import Literal
+from typing import Union
+
+import apache_beam as beam
+from apache_beam.pvalue import TaggedOutput
+from apache_beam.typehints import with_output_types
+from apache_beam.typehints.decorators import IOTypeHints
+
+
+class IOTypeHintsTaggedOutputTest(unittest.TestCase):
+ """Tests for IOTypeHints.tagged_output_types() accessor."""
+ def test_empty_hints_returns_empty_dict(self):
+ empty = IOTypeHints.empty()
+ self.assertEqual(empty.tagged_output_types(), {})
+
+ def test_with_tagged_types(self):
+ hints = IOTypeHints.empty().with_output_types(int, errors=str,
warnings=str)
+ self.assertEqual(
+ hints.tagged_output_types(), {
+ 'errors': str, 'warnings': str
+ })
+
+ def test_simple_output_type_with_tagged_types(self):
+ """simple_output_type() should still return main type when tags present."""
+ hints = IOTypeHints.empty().with_output_types(int, errors=str,
warnings=str)
+ self.assertEqual(hints.simple_output_type('test'), int)
+
+ hints = IOTypeHints.empty().with_output_types(
+ Union[int, str], errors=str, warnings=str)
+ self.assertEqual(hints.simple_output_type('test'), Union[int, str])
+
+ def test_without_tagged_types(self):
+ """Without tagged types, tagged_output_types() returns empty dict."""
+ hints = IOTypeHints.empty().with_output_types(int)
+ self.assertEqual(hints.tagged_output_types(), {})
+ self.assertEqual(hints.simple_output_type('test'), int)
+
+
+class DecoratorStyleTaggedOutputTest(unittest.TestCase):
+ """Tests for @with_output_types decorator style across all transforms."""
+ def test_pardo_decorator_pipeline(self):
+ """Test that tagged types propagate through ParDo pipeline."""
+ @with_output_types(int, errors=str)
+ class MyDoFn(beam.DoFn):
+ def process(self, element):
+ if element < 0:
+ yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+ else:
+ yield element * 2
+
+ with beam.Pipeline() as p:
+ results = (
+ p
+ | beam.Create([-1, 0, 1, 2])
+ | beam.ParDo(MyDoFn()).with_outputs('errors', main='main'))
+
+ self.assertEqual(results.main.element_type, int)
+ self.assertEqual(results.errors.element_type, str)
+
+ def test_map_decorator_pipeline(self):
+ """Test that tagged types propagate through Map."""
+ @with_output_types(int, errors=str)
+ def mapfn(element):
+ if element < 0:
+ return beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+ else:
+ return element * 2
+
+ with beam.Pipeline() as p:
+ results = (
+ p
+ | beam.Create([-1, 0, 1, 2])
+ | beam.Map(mapfn).with_outputs('errors', main='main'))
+
+ self.assertEqual(results.main.element_type, int)
+ self.assertEqual(results.errors.element_type, str)
+
+ def test_flatmap_decorator_pipeline(self):
+ """Test that tagged types propagate through FlatMap."""
+ @with_output_types(Iterable[int], errors=Iterable[str])
+ def flatmapfn(element):
+ if element < 0:
+ yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+ else:
+ yield element * 2
+
+ with beam.Pipeline() as p:
+ results = (
+ p
+ | beam.Create([-1, 0, 1, 2])
+ | beam.FlatMap(flatmapfn).with_outputs('errors', main='main'))
+
+ self.assertEqual(results.main.element_type, int)
+ self.assertEqual(results.errors.element_type, str)
+
+ def test_maptuple_decorator_pipeline(self):
+ """Test that tagged types propagate through MapTuple."""
+ @with_output_types(int, errors=str)
+ def maptuplefn(key, value):
+ if value < 0:
+ return beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}')
+ else:
+ return value * 2
+
+ with beam.Pipeline() as p:
+ results = (
+ p
+ | beam.Create([('a', -1), ('b', 2), ('c', 3)])
+ | beam.MapTuple(maptuplefn).with_outputs('errors', main='main'))
+
+ self.assertEqual(results.main.element_type, int)
+ self.assertEqual(results.errors.element_type, str)
+
+ def test_flatmaptuple_decorator_pipeline(self):
+ """Test that tagged types propagate through FlatMapTuple."""
+ @with_output_types(Iterable[int], errors=Iterable[str])
+ def flatmaptuplefn(key, value):
+ if value < 0:
+ yield beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}')
+ else:
+ yield value * 2
+
+ with beam.Pipeline() as p:
+ results = (
+ p
+ | beam.Create([('a', -1), ('b', 2), ('c', 3)])
+ | beam.FlatMapTuple(flatmaptuplefn).with_outputs(
+ 'errors', main='main'))
+
+ self.assertEqual(results.main.element_type, int)
+ self.assertEqual(results.errors.element_type, str)
+
+
+class ChainStyleTaggedOutputTest(unittest.TestCase):
+ """Tests for .with_output_types() method chain style across all
+  transforms."""
+ def test_pardo_chain_pipeline(self):
+ """Test ParDo with chained type hints."""
+ class SimpleDoFn(beam.DoFn):
+ def process(self, element):
+ if element < 0:
+ yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+ else:
+ yield element * 2
+
+ with beam.Pipeline() as p:
+ results = (
+ p
+ | beam.Create([-1, 0, 1, 2])
+ | beam.ParDo(SimpleDoFn()).with_output_types(
+ int, errors=str).with_outputs('errors', main='main'))
+
+ self.assertEqual(results.main.element_type, int)
+ self.assertEqual(results.errors.element_type, str)
+
+ def test_map_chain_pipeline(self):
+ """Test Map with chained type hints."""
+ def mapfn(element):
+ if element < 0:
+ return beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+ else:
+ return element * 2
+
+ with beam.Pipeline() as p:
+ results = (
+ p
+ | beam.Create([-1, 0, 1, 2])
+ | beam.Map(mapfn).with_output_types(int, errors=str).with_outputs(
+ 'errors', main='main'))
+
+ self.assertEqual(results.main.element_type, int)
+ self.assertEqual(results.errors.element_type, str)
+
+ def test_flatmap_chain_pipeline(self):
+ """Test FlatMap with chained type hints.
+
+ Note: For FlatMap.with_output_types(), specify the element type directly
+ (int), not wrapped in Iterable. The transform handles iteration internally.
+ """
+ def flatmapfn(element):
+ if element < 0:
+ yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+ else:
+ yield element * 2
+
+ with beam.Pipeline() as p:
+ results = (
+ p
+ | beam.Create([-1, 0, 1, 2])
+ | beam.FlatMap(flatmapfn).with_output_types(
+ int, errors=str).with_outputs('errors', main='main'))
+
+ self.assertEqual(results.main.element_type, int)
+ self.assertEqual(results.errors.element_type, str)
+
+ def test_maptuple_chain_pipeline(self):
+ """Test MapTuple with chained type hints."""
+ def maptuplefn(key, value):
+ if value < 0:
+ return beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}')
+ else:
+ return value * 2
+
+ with beam.Pipeline() as p:
+ results = (
+ p
+ | beam.Create([('a', -1), ('b', 2), ('c', 3)])
+ | beam.MapTuple(maptuplefn).with_output_types(
+ int, errors=str).with_outputs('errors', main='main'))
+
+ self.assertEqual(results.main.element_type, int)
+ self.assertEqual(results.errors.element_type, str)
+
+ def test_flatmaptuple_chain_pipeline(self):
+ """Test FlatMapTuple with chained type hints.
+
+ Note: For FlatMapTuple.with_output_types(), specify the element type
+ directly (int), not wrapped in Iterable.
+ """
+ def flatmaptuplefn(key, value):
+ if value < 0:
+ yield beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}')
+ else:
+ yield value * 2
+
+ with beam.Pipeline() as p:
+ results = (
+ p
+ | beam.Create([('a', -1), ('b', 2), ('c', 3)])
+ | beam.FlatMapTuple(flatmaptuplefn).with_output_types(
+ int, errors=str).with_outputs('errors', main='main'))
+
+ self.assertEqual(results.main.element_type, int)
+ self.assertEqual(results.errors.element_type, str)
+
+
+class AnnotationStyleTaggedOutputTest(unittest.TestCase):
+ """Tests for function annotation style across all transforms."""
+ def test_map_annotation_union(self):
+ """Test Map with Union[int, TaggedOutput[...]] annotation."""
+ def mapfn(element: int) -> int | TaggedOutput[Literal['errors'], str]:
+ if element < 0:
+ return beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+ else:
+ return element * 2
+
+ with beam.Pipeline() as p:
+ results = (
+ p
+ | beam.Create([-1, 0, 1, 2])
+ | beam.Map(mapfn).with_outputs('errors', main='main'))
+
+ self.assertEqual(results.main.element_type, int)
+ self.assertEqual(results.errors.element_type, str)
+
+ def test_map_annotation_multiple_tags(self):
+ """Test Map with multiple TaggedOutput types in annotation."""
+ def mapfn(
+ element: int
+ ) -> int | TaggedOutput[Literal['errors'],
+ str] | TaggedOutput[Literal['warnings'], str]:
+ if element < 0:
+ return beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+ elif element == 0:
+ return beam.pvalue.TaggedOutput('warnings', 'Zero value')
+ else:
+ return element * 2
+
+ with beam.Pipeline() as p:
+ results = (
+ p
+ | beam.Create([-1, 0, 1, 2])
+ | beam.Map(mapfn).with_outputs('errors', 'warnings', main='main'))
+
+ self.assertEqual(results.main.element_type, int)
+ self.assertEqual(results.errors.element_type, str)
+ self.assertEqual(results.warnings.element_type, str)
+
+ def test_flatmap_annotation_iterable(self):
+ """Test FlatMap with Iterable[int | TaggedOutput[...]] annotation."""
+ def flatmapfn(
+ element: int) -> Iterable[int | TaggedOutput[Literal['errors'], str]]:
+ if element < 0:
+ yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+ else:
+ yield element * 2
+
+ with beam.Pipeline() as p:
+ results = (
+ p
+ | beam.Create([-1, 0, 1, 2])
+ | beam.FlatMap(flatmapfn).with_outputs('errors', main='main'))
+
+ self.assertEqual(results.main.element_type, int)
+ self.assertEqual(results.errors.element_type, str)
+
+ def test_pardo_annotation_process_method(self):
+ """Test DoFn with process method annotation."""
+ class AnnotatedDoFn(beam.DoFn):
+ def process(
+ self,
+ element: int) -> Iterable[int | TaggedOutput[Literal['errors'], str]]:
+ if element < 0:
+ yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+ else:
+ yield element * 2
+
+ with beam.Pipeline() as p:
+ results = (
+ p
+ | beam.Create([-1, 0, 1, 2])
+ | beam.ParDo(AnnotatedDoFn()).with_outputs('errors', main='main'))
+
+ self.assertEqual(results.main.element_type, int)
+ self.assertEqual(results.errors.element_type, str)
+
+
+if __name__ == '__main__':  # Run this module's tests when executed directly.
+ unittest.main()
diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py
index 0bbc21f6739..cec83038008 100644
--- a/sdks/python/apache_beam/typehints/typehints_test.py
+++ b/sdks/python/apache_beam/typehints/typehints_test.py
@@ -1421,7 +1421,7 @@ class OutputDecoratorTestCase(TypeHintTestCase):
return 5, 'bar'
def test_no_kwargs_accepted(self):
- with self.assertRaisesRegex(ValueError, r'must be positional'):
+ with self.assertRaisesRegex(ValueError, r'single positional argument'):
@with_output_types(m=int)
def unused_foo():