This is an automated email from the ASF dual-hosted git repository.

cvandermerwe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new fed1b551c35 Add support for tagged output type hints. (#37434)
fed1b551c35 is described below

commit fed1b551c353b588f764a293b09a55cce997859a
Author: claudevdm <[email protected]>
AuthorDate: Wed Feb 4 17:14:07 2026 -0500

    Add support for tagged output type hints. (#37434)
    
    * Add tagged typehint support.
    
    * Just warn when a bare TaggedOutput is used
    
    * Remove contains tagged output check.
    
    * Mapped bare TaggedOutput to Any
    
    * Extract tagged outputs after strip_iterable.
---
 sdks/python/apache_beam/pvalue.py                  |  12 +-
 sdks/python/apache_beam/transforms/core.py         |  41 ++-
 sdks/python/apache_beam/transforms/ptransform.py   |  28 +-
 sdks/python/apache_beam/typehints/decorators.py    | 182 +++++++++--
 .../apache_beam/typehints/decorators_test.py       | 125 ++++++++
 .../typehints/tagged_output_typehints_test.py      | 356 +++++++++++++++++++++
 .../python/apache_beam/typehints/typehints_test.py |   2 +-
 7 files changed, 712 insertions(+), 34 deletions(-)

diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py
index 6621d96127d..d09a0040bd7 100644
--- a/sdks/python/apache_beam/pvalue.py
+++ b/sdks/python/apache_beam/pvalue.py
@@ -265,6 +265,8 @@ class DoOutputsTuple(object):
     self._tags = tags
     self._main_tag = main_tag
     self._transform = transform
+    self._tagged_output_types = (
+        transform.get_type_hints().tagged_output_types() if transform else {})
     self._allow_unknown_tags = (
         not tags if allow_unknown_tags is None else allow_unknown_tags)
     # The ApplyPTransform instance for the application of the multi FlatMap
@@ -322,7 +324,7 @@ class DoOutputsTuple(object):
       pcoll = PCollection(
           self._pipeline,
           tag=tag,
-          element_type=typehints.Any,
+          element_type=self._tagged_output_types.get(tag, typehints.Any),
           is_bounded=is_bounded)
       # Transfer the producer from the DoOutputsTuple to the resulting
       # PCollection.
@@ -342,7 +344,11 @@ class DoOutputsTuple(object):
     return pcoll
 
 
-class TaggedOutput(object):
+TagType = TypeVar('TagType', bound=str)
+ValueType = TypeVar('ValueType')
+
+
+class TaggedOutput(Generic[TagType, ValueType]):
   """An object representing a tagged value.
 
   ParDo, Map, and FlatMap transforms can emit values on multiple outputs which
@@ -350,7 +356,7 @@ class TaggedOutput(object):
   if it wants to emit on the main output and TaggedOutput objects
   if it wants to emit a value on a specific tagged output.
   """
-  def __init__(self, tag: str, value: Any) -> None:
+  def __init__(self, tag: TagType, value: ValueType) -> None:
     if not isinstance(tag, str):
       raise TypeError(
           'Attempting to create a TaggedOutput with non-string tag %s' %
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 128a070e2ac..6d2552a2a6a 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -824,6 +824,7 @@ class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
       process_type_hints = process_type_hints.strip_iterable()
     except ValueError as e:
       raise ValueError('Return value not iterable: %s: %s' % (self, e))
+    process_type_hints = process_type_hints.extract_tagged_outputs()
 
     # Prefer class decorator type hints for backwards compatibility.
     return get_type_hints(self.__class__).with_defaults(process_type_hints)
@@ -1039,6 +1040,7 @@ class CallableWrapperDoFn(DoFn):
       raise TypeCheckError(
           'Return value not iterable: %s: %s' %
           (self.display_data()['fn'].value, e))
+    type_hints = type_hints.extract_tagged_outputs()
     return type_hints
 
   def infer_output_type(self, input_type):
@@ -1834,6 +1836,17 @@ class ParDo(PTransformWithSideInputs):
       raise ValueError(
           'Main output tag %r must be different from side output tags %r.' %
           (main, tags))
+    type_hints = self.get_type_hints()
+    declared_tags = set(type_hints.tagged_output_types().keys())
+    requested_tags = set(tags)
+
+    unknown = requested_tags - declared_tags
+    if unknown and declared_tags:  # Only warn if type hints exist
+      logging.warning(
+          "Tags %s requested in with_outputs() but not declared "
+          "in type hints. Declared tags: %s",
+          unknown,
+          declared_tags)
     return _MultiParDo(self, tags, main, allow_unknown_tags)
 
   def _do_fn_info(self):
@@ -2120,8 +2133,14 @@ def Map(fn, *args, **kwargs):  # pylint: disable=invalid-name
             wrapper)
   output_hint = type_hints.simple_output_type(label)
   if output_hint:
+    tagged = {
+        k: typehints.Iterable[v]
+        for k, v in type_hints.tagged_output_types().items()
+    }
     wrapper = with_output_types(
-        typehints.Iterable[_strip_output_annotations(output_hint)])(
+        typehints.Iterable[_strip_output_annotations(
+            output_hint, strip_tagged_output=False)],
+        **tagged)(
             wrapper)
   # pylint: disable=protected-access
   wrapper._argspec_fn = fn
@@ -2189,8 +2208,14 @@ def MapTuple(fn, *args, **kwargs):  # pylint: disable=invalid-name
     pass
   output_hint = type_hints.simple_output_type(label)
   if output_hint:
+    tagged = {
+        k: typehints.Iterable[v]
+        for k, v in type_hints.tagged_output_types().items()
+    }
     wrapper = with_output_types(
-        typehints.Iterable[_strip_output_annotations(output_hint)])(
+        typehints.Iterable[_strip_output_annotations(
+            output_hint, strip_tagged_output=False)],
+        **tagged)(
             wrapper)
 
   # Replace the first (args) component.
@@ -2261,7 +2286,10 @@ def FlatMapTuple(fn, *args, **kwargs):  # pylint: disable=invalid-name
     pass
   output_hint = type_hints.simple_output_type(label)
   if output_hint:
-    wrapper = with_output_types(_strip_output_annotations(output_hint))(wrapper)
+    wrapper = with_output_types(
+        _strip_output_annotations(output_hint, strip_tagged_output=False),
+        **type_hints.tagged_output_types())(
+            wrapper)
 
   # Replace the first (args) component.
   modified_arg_names = ['tuple_element'] + arg_names[-num_defaults:]
@@ -4222,12 +4250,15 @@ class Impulse(PTransform):
     return Impulse()
 
 
-def _strip_output_annotations(type_hint):
+def _strip_output_annotations(type_hint, strip_tagged_output=True):
   # TODO(robertwb): These should be parameterized types that the
   # type inferencer understands.
   # Then we can replace them with the correct element types instead of
   # using Any. Refer to typehints.WindowedValue when doing this.
-  annotations = (TimestampedValue, WindowedValue, pvalue.TaggedOutput)
+  annotations = [TimestampedValue, WindowedValue]
+  if strip_tagged_output:
+    annotations.append(pvalue.TaggedOutput)
+  annotations = tuple(annotations)
 
   contains_annotation = False
 
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index 94e9a0644d0..d5985b6212d 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -414,12 +414,15 @@ class PTransform(WithTypeHints, HasDisplayData, Generic[InputT, OutputT]):
         input_type_hint, 'Type hints for a PTransform')
     return super().with_input_types(input_type_hint)
 
-  def with_output_types(self, type_hint):
+  def with_output_types(self, type_hint, **tagged_type_hints):
     """Annotates the output type of a :class:`PTransform` with a type-hint.
 
     Args:
       type_hint (type): An instance of an allowed built-in type, a custom class,
-        or a :class:`~apache_beam.typehints.typehints.TypeConstraint`.
+        or a :class:`~apache_beam.typehints.typehints.TypeConstraint`. This is
+        the type hint for the main output.
+      **tagged_type_hints: Type hints for tagged outputs. Each keyword argument
+        specifies the type for a tagged output e.g., ``errors=str``.
 
     Raises:
       TypeError: If **type_hint** is not a valid type-hint. See
@@ -430,10 +433,22 @@ class PTransform(WithTypeHints, HasDisplayData, Generic[InputT, OutputT]):
       PTransform: A reference to the instance of this particular
       :class:`PTransform` object. This allows chaining type-hinting related
       methods.
+
+    Example::
+      result = pcoll | beam.ParDo(MyDoFn()).with_output_types(
+          int,  # main output type
+          errors=str,  # 'errors' tagged output type
+          warnings=str  # 'warnings' tagged output type
+      ).with_outputs('errors', 'warnings', main='main')
     """
     type_hint = native_type_compatibility.convert_to_beam_type(type_hint)
     validate_composite_type_param(type_hint, 'Type hints for a PTransform')
-    return super().with_output_types(type_hint)
+    for tag, hint in tagged_type_hints.items():
+      tagged_type_hints[tag] = native_type_compatibility.convert_to_beam_type(
+          hint)
+      validate_composite_type_param(
+          tagged_type_hints[tag], f'Tagged output type hint for {tag!r}')
+    return super().with_output_types(type_hint, **tagged_type_hints)
 
   def with_resource_hints(self, **kwargs):  # type: (...) -> PTransform
     """Adds resource hints to the :class:`PTransform`.
@@ -479,10 +494,11 @@ class PTransform(WithTypeHints, HasDisplayData, Generic[InputT, OutputT]):
     if hints is None or not any(hints):
       return
     arg_hints, kwarg_hints = hints
-    if arg_hints and kwarg_hints:
+    # Output types can have kwargs for tagged output types.
+    if arg_hints and kwarg_hints and input_or_output != 'output':
       raise TypeCheckError(
-          'PTransform cannot have both positional and keyword type hints '
-          'without overriding %s._type_check_%s()' %
+          'PTransform cannot have both positional and keyword input type hints'
+          ' without overriding %s._type_check_%s()' %
           (self.__class__, input_or_output))
     root_hint = (
         arg_hints[0] if len(arg_hints) == 1 else arg_hints or kwarg_hints)
diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py
index 2d2f7981dd2..e393113c002 100644
--- a/sdks/python/apache_beam/typehints/decorators.py
+++ b/sdks/python/apache_beam/typehints/decorators.py
@@ -89,12 +89,16 @@ from typing import Callable
 from typing import Dict
 from typing import Iterable
 from typing import List
+from typing import Literal
 from typing import NamedTuple
 from typing import Optional
 from typing import Tuple
 from typing import TypeVar
 from typing import Union
+from typing import get_args
+from typing import get_origin
 
+from apache_beam.pvalue import TaggedOutput
 from apache_beam.typehints import native_type_compatibility
 from apache_beam.typehints import typehints
 from apache_beam.typehints.native_type_compatibility import convert_to_beam_type
@@ -180,6 +184,83 @@ def disable_type_annotations():
 
 
 TRACEBACK_LIMIT = 5
+_NO_MAIN_TYPE = object()
+
+
+def _tag_and_type(t):
+  """Extract tag name and value type from TaggedOutput[Literal['tag'], Type].
+
+  Returns raw Python types - conversion to beam types happens in
+  _extract_output_types.
+  """
+  args = get_args(t)
+  if len(args) != 2:
+    raise TypeError(
+        f"TaggedOutput expects 2 type parameters, got {len(args)}: {t}")
+
+  literal_type, value_type = args
+
+  if get_origin(literal_type) is not Literal:
+    raise TypeError(
+        f"First type parameter of TaggedOutput must be Literal['tag_name'], "
+        f"got {literal_type}. Example: TaggedOutput[Literal['errors'], str]")
+
+  tag_string = get_args(literal_type)[0]
+  return tag_string, value_type
+
+
+def _extract_tagged_from_type(beam_type):
+  """Extract tagged output types from a Beam type (post-convert_to_beam_type).
+
+  Called after the Iterable wrapper has been removed.
+  At this point, the type has already been through convert_to_beam_type, so
+  unions are typehints.UnionConstraint (not typing.Union), but
+  TaggedOutput[Literal['tag'], T] passes through unchanged as a typing
+  generic alias.
+
+  Returns:
+    (clean_type, tagged_dict) where clean_type is the type without TaggedOutput
+    members (or _NO_MAIN_TYPE if no main type), and tagged_dict maps tag names
+    to their Beam types.
+  """
+  # Single TaggedOutput[Literal['tag'], Type]
+  if get_origin(beam_type) is TaggedOutput:
+    tag, typ = _tag_and_type(beam_type)
+    return _NO_MAIN_TYPE, {tag: convert_to_beam_type(typ)}
+
+  # Bare TaggedOutput (unparameterized)
+  if beam_type is TaggedOutput:
+    logging.warning(
+        "TaggedOutput in return type must include type parameters: "
+        "TaggedOutput[Literal['tag_name'], ValueType]. "
+        "Bare TaggedOutput will be ignored.")
+    return _NO_MAIN_TYPE, {}
+
+  if not isinstance(beam_type, typehints.UnionHint.UnionConstraint):
+    return beam_type, {}
+
+  # UnionConstraint containing TaggedOutput members
+  main_types = []
+  tagged = {}
+  for member in beam_type.union_types:
+    if get_origin(member) is TaggedOutput:
+      tag, typ = _tag_and_type(member)
+      tagged[tag] = convert_to_beam_type(typ)
+    elif member is TaggedOutput:
+      logging.warning(
+          "TaggedOutput in return type must include type parameters: "
+          "TaggedOutput[Literal['tag_name'], ValueType]. "
+          "Bare TaggedOutput will be ignored.")
+    else:
+      main_types.append(member)
+  if not tagged and len(main_types) == len(beam_type.union_types):
+    return beam_type, {}
+  if not main_types:
+    return _NO_MAIN_TYPE, tagged
+  elif len(main_types) == 1:
+    return main_types[0], tagged
+  else:
+    return typehints.Union[tuple(main_types)], tagged
 
 
 class IOTypeHints(NamedTuple):
@@ -273,6 +354,7 @@ class IOTypeHints(NamedTuple):
                                 param.VAR_POSITIONAL], \
               'Unsupported Parameter kind: %s' % param.kind
           input_args.append(convert_to_beam_type(param.annotation))
+
     output_args = []
     if signature.return_annotation != signature.empty:
       output_args.append(convert_to_beam_type(signature.return_annotation))
@@ -308,18 +390,24 @@ class IOTypeHints(NamedTuple):
 
   def simple_output_type(self, context):
     if self._has_output_types():
-      args, kwargs = self.output_types
-      if len(args) != 1 or kwargs:
+      args, _ = self.output_types
+      # Note: kwargs may contain tagged output types, which are ignored here.
+      # Use tagged_output_types() to access those.
+      if len(args) != 1:
         raise TypeError(
             'Expected single output type hint for %s but got: %s' %
             (context, self.output_types))
       return args[0]
 
+  def tagged_output_types(self):
+    if not self._has_output_types():
+      return {}
+    _, tagged_output_types = self.output_types
+    return tagged_output_types
+
   def has_simple_output_type(self):
     """Whether there's a single positional output type."""
-    return (
-        self.output_types and len(self.output_types[0]) == 1 and
-        not self.output_types[1])
+    return (self.output_types and len(self.output_types[0]) == 1)
 
   def strip_pcoll(self):
     from apache_beam.pipeline import Pipeline
@@ -413,6 +501,7 @@ class IOTypeHints(NamedTuple):
     if self.output_types is None or not self.has_simple_output_type():
       return self
     output_type = self.output_types[0][0]
+    tagged_output_types = self.output_types[1]
     if output_type is None or isinstance(output_type, type(None)):
       return self
     # If output_type == Optional[T]: output_type = T.
@@ -427,14 +516,51 @@ class IOTypeHints(NamedTuple):
     if isinstance(output_type, typehints.TypeVariable):
       # We don't know what T yields, so we just assume Any.
       return self._replace(
-          output_types=((typehints.Any, ), {}),
+          output_types=((typehints.Any, ), tagged_output_types),
           origin=self._make_origin([self], tb=False, msg=['strip_iterable()']))
 
     yielded_type = typehints.get_yielded_type(output_type)
+
+    # Also strip Iterable from tagged output types (e.g. from Map/MapTuple
+    # which wrap both main and tagged types in Iterable).
+    stripped_tags = {
+        tag: typehints.get_yielded_type(hint)
+        for tag, hint in tagged_output_types.items()
+    }
+
     return self._replace(
-        output_types=((yielded_type, ), {}),
+        output_types=((yielded_type, ), stripped_tags),
         origin=self._make_origin([self], tb=False, msg=['strip_iterable()']))
 
+  def extract_tagged_outputs(self):
+    """Extract TaggedOutput types from the main output type into kwargs.
+
+    For annotation style (e.g. -> Iterable[int | TaggedOutput[...]]),
+    TaggedOutput stays embedded in the main type through convert_to_beam_type
+    and strip_iterable. This method extracts those TaggedOutput members into
+    the tagged output kwargs dict.
+
+    Should be called after strip_iterable().
+
+    Returns:
+      A copy of this instance with TaggedOutput members moved from the main
+      output type into the output kwargs dict.
+    """
+    if self.output_types is None or not self.has_simple_output_type():
+      return self
+    output_type = self.output_types[0][0]
+
+    clean_type, extracted_tags = _extract_tagged_from_type(output_type)
+    if not extracted_tags:
+      return self
+    if clean_type is _NO_MAIN_TYPE:
+      clean_type = typehints.Any
+    return self._replace(
+        output_types=((clean_type, ), extracted_tags),
+        origin=self._make_origin([self],
+                                 tb=False,
+                                 msg=['extract_tagged_outputs()']))
+
   def with_defaults(self, hints: Optional['IOTypeHints']) -> 'IOTypeHints':
     if not hints:
       return self
@@ -782,7 +908,7 @@ def with_input_types(*positional_hints: Any,
 
 
 def with_output_types(*return_type_hint: Any,
-                      **kwargs: Any) -> Callable[[T], T]:
+                      **tagged_type_hints: Any) -> Callable[[T], T]:
   """A decorator that type-checks defined type-hints for return values(s).
 
   This decorator will type-check the return value(s) of the decorated function.
@@ -822,18 +948,34 @@ def with_output_types(*return_type_hint: Any,
     def negate(p):
       return not p if p else p
 
+  For DoFns with tagged outputs, you can specify type hints for each tag:
+
+  .. testcode::
+    from apache_beam.typehints import with_input_types, with_output_types
+    @with_output_types(int, errors=str, warnings=str)
+    class MyDoFn(beam.DoFn):
+      def process(self, element):
+        if element < 0:
+          yield beam.pvalue.TaggedOutput('errors', 'Negative value')
+        elif element == 0:
+          yield beam.pvalue.TaggedOutput('warnings', 'Zero value')
+        else:
+          yield element
+
   Args:
     *return_type_hint: A type-hint specifying the proper return type of the
       function. This argument should either be a built-in Python type or an
       instance of a :class:`~apache_beam.typehints.typehints.TypeConstraint`
       created by 'indexing' a
       :class:`~apache_beam.typehints.typehints.CompositeTypeHint`.
-    **kwargs: Not used.
+    **tagged_type_hints: Type hints for tagged outputs. Each keyword argument
+      specifies the type for a tagged output, e.g., ``errors=str``.
+
 
   Raises:
-    :class:`ValueError`: If any kwarg parameters are passed in,
-      or the length of **return_type_hint** is greater than ``1``. Or if the
-      inner wrapper function isn't passed a function object.
+    :class:`ValueError`: If the length of **return_type_hint** is greater
+      than ``1``. Or if the inner wrapper function isn't passed a function
+      object.
     :class:`TypeCheckError`: If the **return_type_hint** object is
       in invalid type-hint.
 
@@ -841,11 +983,6 @@ def with_output_types(*return_type_hint: Any,
     The original function decorated such that it enforces type-hint constraints
     for all return values.
   """
-  if kwargs:
-    raise ValueError(
-        "All arguments for the 'returns' decorator must be "
-        "positional arguments.")
-
   if len(return_type_hint) != 1:
     raise ValueError(
         "'returns' accepts only a single positional argument. In "
@@ -854,13 +991,20 @@ def with_output_types(*return_type_hint: Any,
 
   return_type_hint = native_type_compatibility.convert_to_beam_type(
       return_type_hint[0])
-
   validate_composite_type_param(
       return_type_hint, error_msg_prefix='All type hint arguments')
 
+  converted_tag_hints = {}
+  for tag, hint in tagged_type_hints.items():
+    converted_hint = native_type_compatibility.convert_to_beam_type(hint)
+    validate_composite_type_param(
+        converted_hint, 'Tagged output type hint for %r' % tag)
+    converted_tag_hints[tag] = converted_hint
+
   def annotate_output_types(f):
     th = getattr(f, '_type_hints', IOTypeHints.empty())
-    f._type_hints = th.with_output_types(return_type_hint)  # pylint: disable=protected-access
+    f._type_hints = th.with_output_types( # pylint: disable=protected-access
+        return_type_hint, **converted_tag_hints)
     return f
 
   return annotate_output_types
diff --git a/sdks/python/apache_beam/typehints/decorators_test.py b/sdks/python/apache_beam/typehints/decorators_test.py
index a2909b4e545..95745f4e3d8 100644
--- a/sdks/python/apache_beam/typehints/decorators_test.py
+++ b/sdks/python/apache_beam/typehints/decorators_test.py
@@ -24,6 +24,7 @@ import typing
 import unittest
 
 from apache_beam import Map
+from apache_beam.pvalue import TaggedOutput
 from apache_beam.typehints import Any
 from apache_beam.typehints import Dict
 from apache_beam.typehints import List
@@ -33,6 +34,7 @@ from apache_beam.typehints import TypeVariable
 from apache_beam.typehints import WithTypeHints
 from apache_beam.typehints import decorators
 from apache_beam.typehints import typehints
+from apache_beam.typehints.native_type_compatibility import convert_to_beam_type
 
 T = TypeVariable('T')
 # Name is 'T' so it converts to a beam type with the same name.
@@ -262,6 +264,63 @@ class IOTypeHintsTest(unittest.TestCase):
     th = decorators.IOTypeHints.from_callable(fn)
     self.assertRegex(th.debug_str(), r'unknown')
 
+  def test_from_callable_no_tagged_output(self):
+    def fn(x: int) -> str:
+      return str(x)
+
+    th = decorators.IOTypeHints.from_callable(fn)
+    self.assertEqual(th.input_types, ((int, ), {}))
+    self.assertEqual(th.output_types, ((str, ), {}))
+
+    def fn2(x: int) -> typing.Iterable[str]:
+      yield str(x)
+
+    th = decorators.IOTypeHints.from_callable(fn2)
+    self.assertEqual(th.input_types, ((int, ), {}))
+    self.assertEqual(th.output_types, ((typehints.Iterable[str], ), {}))
+
+  def test_from_callable_tagged_output_union(self):
+    """Tagged types are NOT extracted in from_callable. They stay embedded
+    in the main type and are extracted later by extract_tagged_outputs()."""
+    def fn(
+        x: int
+    ) -> int | str | TaggedOutput[typing.Literal['errors'], float
+                                  | str] | TaggedOutput[
+                                      typing.Literal['warnings'], str]:
+      return x
+
+    th = decorators.IOTypeHints.from_callable(fn)
+    self.assertEqual(th.input_types, ((int, ), {}))
+    # TaggedOutput members are preserved in the union — no extraction yet.
+    output_type = th.output_types[0][0]
+    self.assertIsInstance(output_type, typehints.UnionConstraint)
+    self.assertEqual(th.output_types[1], {})
+
+  def test_from_callable_tagged_output_iterable(self):
+    """Tagged types inside Iterable are preserved until strip_iterable."""
+    def fn(
+        x: int
+    ) -> typing.Iterable[int | TaggedOutput[typing.Literal['errors'], str]]:
+      yield x
+
+    th = decorators.IOTypeHints.from_callable(fn)
+    self.assertEqual(th.input_types, ((int, ), {}))
+    # The full Iterable[Union[int, TaggedOutput[...]]] is preserved.
+    output_type = th.output_types[0][0]
+    self.assertIsInstance(output_type, typehints.IterableTypeConstraint)
+    self.assertEqual(th.output_types[1], {})
+
+  def test_from_callable_tagged_output_only(self):
+    """A standalone TaggedOutput annotation passes through from_callable."""
+    def fn(x: int) -> TaggedOutput[typing.Literal['errors'], str]:
+      pass
+
+    th = decorators.IOTypeHints.from_callable(fn)
+    self.assertEqual(th.input_types, ((int, ), {}))
+    # TaggedOutput[...] passes through convert_to_beam_type unchanged.
+    self.assertIs(typing.get_origin(th.output_types[0][0]), TaggedOutput)
+    self.assertEqual(th.output_types[1], {})
+
   def test_getcallargs_forhints(self):
     def fn(
         a: int,
@@ -426,5 +485,71 @@ class DecoratorsTest(unittest.TestCase):
     _ = ['a', 'b', 'c'] | Map(fn2)  # Doesn't raise - no input type hints.
 
 
+class ExtractTaggedFromTypeTest(unittest.TestCase):
+  """Tests for _extract_tagged_from_type (Beam-level type extraction)."""
+  def test_simple_type_no_extraction(self):
+    main, tagged = decorators._extract_tagged_from_type(int)
+    self.assertEqual(main, int)
+    self.assertEqual(tagged, {})
+
+  def test_beam_union_no_tagged(self):
+    t = typehints.Union[int, str]
+    main, tagged = decorators._extract_tagged_from_type(t)
+    self.assertEqual(main, t)
+    self.assertEqual(tagged, {})
+
+  def test_standalone_tagged_output(self):
+    t = TaggedOutput[typing.Literal['errors'], str]
+    main, tagged = decorators._extract_tagged_from_type(t)
+    self.assertIs(main, decorators._NO_MAIN_TYPE)
+    self.assertEqual(tagged, {'errors': str})
+
+  def test_beam_union_with_tagged(self):
+    t = convert_to_beam_type(int | TaggedOutput[typing.Literal['errors'], str])
+    main, tagged = decorators._extract_tagged_from_type(t)
+    self.assertEqual(main, int)
+    self.assertEqual(tagged, {'errors': str})
+
+  def test_beam_union_multiple_tagged(self):
+    t = convert_to_beam_type(
+        int | TaggedOutput[typing.Literal['errors'], str]
+        | TaggedOutput[typing.Literal['warnings'], str])
+    main, tagged = decorators._extract_tagged_from_type(t)
+    self.assertEqual(main, int)
+    self.assertEqual(tagged, {'errors': str, 'warnings': str})
+
+  def test_beam_union_multiple_main_types(self):
+    t = convert_to_beam_type(
+        int | str | TaggedOutput[typing.Literal['errors'], bytes])
+    main, tagged = decorators._extract_tagged_from_type(t)
+    self.assertIsInstance(main, typehints.UnionConstraint)
+    self.assertIn(int, main.union_types)
+    self.assertIn(str, main.union_types)
+    self.assertEqual(tagged, {'errors': bytes})
+
+  def test_beam_union_tagged_only(self):
+    t = convert_to_beam_type(
+        TaggedOutput[typing.Literal['errors'], str]
+        | TaggedOutput[typing.Literal['warnings'], int])
+    main, tagged = decorators._extract_tagged_from_type(t)
+    self.assertIs(main, decorators._NO_MAIN_TYPE)
+    self.assertEqual(tagged, {'errors': str, 'warnings': int})
+
+  def test_bare_tagged_output_standalone(self):
+    with self.assertLogs(level='WARNING') as cm:
+      main, tagged = decorators._extract_tagged_from_type(TaggedOutput)
+    self.assertIn('Bare TaggedOutput will be ignored', cm.output[0])
+    self.assertIs(main, decorators._NO_MAIN_TYPE)
+    self.assertEqual(tagged, {})
+
+  def test_bare_tagged_output_in_union(self):
+    with self.assertLogs(level='WARNING') as cm:
+      t = convert_to_beam_type(str | TaggedOutput)
+      main, tagged = decorators._extract_tagged_from_type(t)
+    self.assertIn('Bare TaggedOutput will be ignored', cm.output[0])
+    self.assertEqual(main, str)
+    self.assertEqual(tagged, {})
+
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py b/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py
new file mode 100644
index 00000000000..c06f68fb88a
--- /dev/null
+++ b/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py
@@ -0,0 +1,356 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for tagged output type hints.
+
+Covers type hints for tagged outputs, which can be declared in three styles:
+
+1. Decorator style:
+   @with_output_types(int, errors=str, warnings=str)
+   class MyDoFn(beam.DoFn):
+     ...
+
+2. Method chain style:
+   beam.ParDo(MyDoFn()).with_output_types(int, errors=str)
+
+3. Function annotation style:
+   def fn(element) -> int | TaggedOutput[Literal['errors'], str]:
+     ...
+"""
+
+# pytype: skip-file
+
+import unittest
+from typing import Iterable
+from typing import Literal
+from typing import Union
+
+import apache_beam as beam
+from apache_beam.pvalue import TaggedOutput
+from apache_beam.typehints import with_output_types
+from apache_beam.typehints.decorators import IOTypeHints
+
+
+class IOTypeHintsTaggedOutputTest(unittest.TestCase):
+  """Tests for IOTypeHints.tagged_output_types() accessor."""
+  def test_empty_hints_returns_empty_dict(self):
+    empty = IOTypeHints.empty()
+    self.assertEqual(empty.tagged_output_types(), {})
+
+  def test_with_tagged_types(self):
+    hints = IOTypeHints.empty().with_output_types(int, errors=str, 
warnings=str)
+    self.assertEqual(
+        hints.tagged_output_types(), {
+            'errors': str, 'warnings': str
+        })
+
+  def test_simple_output_type_with_tagged_types(self):
+    """simple_output_type() should still return main type when tags present."""
+    hints = IOTypeHints.empty().with_output_types(int, errors=str, 
warnings=str)
+    self.assertEqual(hints.simple_output_type('test'), int)
+
+    hints = IOTypeHints.empty().with_output_types(
+        Union[int, str], errors=str, warnings=str)
+    self.assertEqual(hints.simple_output_type('test'), Union[int, str])
+
+  def test_without_tagged_types(self):
+    """Without tagged types, tagged_output_types() returns empty dict."""
+    hints = IOTypeHints.empty().with_output_types(int)
+    self.assertEqual(hints.tagged_output_types(), {})
+    self.assertEqual(hints.simple_output_type('test'), int)
+
+
+class DecoratorStyleTaggedOutputTest(unittest.TestCase):
+  """Tests for @with_output_types decorator style across all transforms."""
+  def test_pardo_decorator_pipeline(self):
+    """Test that tagged types propagate through ParDo pipeline."""
+    @with_output_types(int, errors=str)
+    class MyDoFn(beam.DoFn):
+      def process(self, element):
+        if element < 0:
+          yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+        else:
+          yield element * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([-1, 0, 1, 2])
+          | beam.ParDo(MyDoFn()).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+  def test_map_decorator_pipeline(self):
+    """Test that tagged types propagate through Map."""
+    @with_output_types(int, errors=str)
+    def mapfn(element):
+      if element < 0:
+        return beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+      else:
+        return element * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([-1, 0, 1, 2])
+          | beam.Map(mapfn).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+  def test_flatmap_decorator_pipeline(self):
+    """Test that tagged types propagate through FlatMap."""
+    @with_output_types(Iterable[int], errors=Iterable[str])
+    def flatmapfn(element):
+      if element < 0:
+        yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+      else:
+        yield element * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([-1, 0, 1, 2])
+          | beam.FlatMap(flatmapfn).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+  def test_maptuple_decorator_pipeline(self):
+    """Test that tagged types propagate through MapTuple."""
+    @with_output_types(int, errors=str)
+    def maptuplefn(key, value):
+      if value < 0:
+        return beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}')
+      else:
+        return value * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([('a', -1), ('b', 2), ('c', 3)])
+          | beam.MapTuple(maptuplefn).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+  def test_flatmaptuple_decorator_pipeline(self):
+    """Test that tagged types propagate through FlatMapTuple."""
+    @with_output_types(Iterable[int], errors=Iterable[str])
+    def flatmaptuplefn(key, value):
+      if value < 0:
+        yield beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}')
+      else:
+        yield value * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([('a', -1), ('b', 2), ('c', 3)])
+          | beam.FlatMapTuple(flatmaptuplefn).with_outputs(
+              'errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+
+class ChainStyleTaggedOutputTest(unittest.TestCase):
+  """Tests for .with_output_types() method chain style across all 
transforms."""
+  def test_pardo_chain_pipeline(self):
+    """Test ParDo with chained type hints."""
+    class SimpleDoFn(beam.DoFn):
+      def process(self, element):
+        if element < 0:
+          yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+        else:
+          yield element * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([-1, 0, 1, 2])
+          | beam.ParDo(SimpleDoFn()).with_output_types(
+              int, errors=str).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+  def test_map_chain_pipeline(self):
+    """Test Map with chained type hints."""
+    def mapfn(element):
+      if element < 0:
+        return beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+      else:
+        return element * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([-1, 0, 1, 2])
+          | beam.Map(mapfn).with_output_types(int, errors=str).with_outputs(
+              'errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+  def test_flatmap_chain_pipeline(self):
+    """Test FlatMap with chained type hints.
+
+    Note: For FlatMap.with_output_types(), specify the element type directly
+    (int), not wrapped in Iterable. The transform handles iteration internally.
+    """
+    def flatmapfn(element):
+      if element < 0:
+        yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+      else:
+        yield element * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([-1, 0, 1, 2])
+          | beam.FlatMap(flatmapfn).with_output_types(
+              int, errors=str).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+  def test_maptuple_chain_pipeline(self):
+    """Test MapTuple with chained type hints."""
+    def maptuplefn(key, value):
+      if value < 0:
+        return beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}')
+      else:
+        return value * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([('a', -1), ('b', 2), ('c', 3)])
+          | beam.MapTuple(maptuplefn).with_output_types(
+              int, errors=str).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+  def test_flatmaptuple_chain_pipeline(self):
+    """Test FlatMapTuple with chained type hints.
+
+    Note: For FlatMapTuple.with_output_types(), specify the element type
+    directly (int), not wrapped in Iterable.
+    """
+    def flatmaptuplefn(key, value):
+      if value < 0:
+        yield beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}')
+      else:
+        yield value * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([('a', -1), ('b', 2), ('c', 3)])
+          | beam.FlatMapTuple(flatmaptuplefn).with_output_types(
+              int, errors=str).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+
+class AnnotationStyleTaggedOutputTest(unittest.TestCase):
+  """Tests for function annotation style across all transforms."""
+  def test_map_annotation_union(self):
+    """Test Map with Union[int, TaggedOutput[...]] annotation."""
+    def mapfn(element: int) -> int | TaggedOutput[Literal['errors'], str]:
+      if element < 0:
+        return beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+      else:
+        return element * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([-1, 0, 1, 2])
+          | beam.Map(mapfn).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+  def test_map_annotation_multiple_tags(self):
+    """Test Map with multiple TaggedOutput types in annotation."""
+    def mapfn(
+        element: int
+    ) -> int | TaggedOutput[Literal['errors'],
+                            str] | TaggedOutput[Literal['warnings'], str]:
+      if element < 0:
+        return beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+      elif element == 0:
+        return beam.pvalue.TaggedOutput('warnings', 'Zero value')
+      else:
+        return element * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([-1, 0, 1, 2])
+          | beam.Map(mapfn).with_outputs('errors', 'warnings', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+      self.assertEqual(results.warnings.element_type, str)
+
+  def test_flatmap_annotation_iterable(self):
+    """Test FlatMap with Iterable[int | TaggedOutput[...]] annotation."""
+    def flatmapfn(
+        element: int) -> Iterable[int | TaggedOutput[Literal['errors'], str]]:
+      if element < 0:
+        yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+      else:
+        yield element * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([-1, 0, 1, 2])
+          | beam.FlatMap(flatmapfn).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+  def test_pardo_annotation_process_method(self):
+    """Test DoFn with process method annotation."""
+    class AnnotatedDoFn(beam.DoFn):
+      def process(
+          self,
+          element: int) -> Iterable[int | TaggedOutput[Literal['errors'], 
str]]:
+        if element < 0:
+          yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+        else:
+          yield element * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([-1, 0, 1, 2])
+          | beam.ParDo(AnnotatedDoFn()).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+
+if __name__ == '__main__':
+  # Run this module's test cases when the file is executed directly.
+  unittest.main()
diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py
index 0bbc21f6739..cec83038008 100644
--- a/sdks/python/apache_beam/typehints/typehints_test.py
+++ b/sdks/python/apache_beam/typehints/typehints_test.py
@@ -1421,7 +1421,7 @@ class OutputDecoratorTestCase(TypeHintTestCase):
       return 5, 'bar'
 
   def test_no_kwargs_accepted(self):
-    with self.assertRaisesRegex(ValueError, r'must be positional'):
+    with self.assertRaisesRegex(ValueError, r'single positional argument'):
 
       @with_output_types(m=int)
       def unused_foo():


Reply via email to