robertwb commented on code in PR #37434:
URL: https://github.com/apache/beam/pull/37434#discussion_r2743369949


##########
sdks/python/apache_beam/transforms/ptransform.py:
##########
@@ -430,10 +433,22 @@ def with_output_types(self, type_hint):
       PTransform: A reference to the instance of this particular
       :class:`PTransform` object. This allows chaining type-hinting related
       methods.
+
+    Example::
+      result = pcoll | beam.ParDo(MyDoFn()).with_output_types(
+          int,  # main output type
+          errors=str,  # 'errors' tagged output type
+          warnings=str  # 'warnings' tagged output type
+      ).with_outputs('errors', 'warnings', main='main')

Review Comment:
   The redundancy here is unfortunate... I wonder if we could combine these
into a single `with_outputs_and_types` method or something. (If it's not easy,
this is fine to leave for future work.)



##########
sdks/python/apache_beam/transforms/core.py:
##########
@@ -1834,6 +1834,17 @@ def with_outputs(self, *tags, main=None, allow_unknown_tags=None):
       raise ValueError(
           'Main output tag %r must be different from side output tags %r.' %
           (main, tags))
+    type_hints = self.get_type_hints()
+    declared_tags = set(type_hints.tagged_output_types().keys())
+    requested_tags = set(tags)
+
+    unknown = requested_tags - declared_tags
+    if unknown and declared_tags:  # Only warn if type hints exist
+      logging.warning(
+          "Tags %s requested in with_outputs() but not declared "

Review Comment:
   f-string?



##########
sdks/python/apache_beam/typehints/tagged_output_typehints_test.py:
##########
@@ -0,0 +1,356 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for tagged output type hints.
+
+This tests the implementation of type hints for tagged outputs via three styles:
+
+1. Decorator style:
+   @with_output_types(int, errors=str, warnings=str)
+   class MyDoFn(beam.DoFn):
+     ...
+
+2. Method chain style:
+   beam.ParDo(MyDoFn()).with_output_types(int, errors=str)
+
+3. Function annotation style:
+   def fn(element) -> int | TaggedOutput[Literal['errors'], str]:
+     ...
+"""
+
+# pytype: skip-file
+
+import unittest
+from typing import Iterable
+from typing import Literal
+from typing import Union
+
+import apache_beam as beam
+from apache_beam.pvalue import TaggedOutput
+from apache_beam.typehints import with_output_types
+from apache_beam.typehints.decorators import IOTypeHints
+
+
+class IOTypeHintsTaggedOutputTest(unittest.TestCase):
+  """Tests for IOTypeHints.tagged_output_types() accessor."""
+  def test_empty_hints_returns_empty_dict(self):
+    empty = IOTypeHints.empty()
+    self.assertEqual(empty.tagged_output_types(), {})
+
+  def test_with_tagged_types(self):
+    hints = IOTypeHints.empty().with_output_types(int, errors=str, warnings=str)
+    self.assertEqual(
+        hints.tagged_output_types(), {
+            'errors': str, 'warnings': str
+        })
+
+  def test_simple_output_type_with_tagged_types(self):
+    """simple_output_type() should still return main type when tags present."""
+    hints = IOTypeHints.empty().with_output_types(int, errors=str, warnings=str)
+    self.assertEqual(hints.simple_output_type('test'), int)
+
+    hints = IOTypeHints.empty().with_output_types(
+        Union[int, str], errors=str, warnings=str)
+    self.assertEqual(hints.simple_output_type('test'), Union[int, str])
+
+  def test_without_tagged_types(self):
+    """Without tagged types, tagged_output_types() returns empty dict."""
+    hints = IOTypeHints.empty().with_output_types(int)
+    self.assertEqual(hints.tagged_output_types(), {})
+    self.assertEqual(hints.simple_output_type('test'), int)
+
+
+class DecoratorStyleTaggedOutputTest(unittest.TestCase):
+  """Tests for @with_output_types decorator style across all transforms."""
+  def test_pardo_decorator_pipeline(self):
+    """Test that tagged types propagate through ParDo pipeline."""
+    @with_output_types(int, errors=str)
+    class MyDoFn(beam.DoFn):
+      def process(self, element):
+        if element < 0:
+          yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+        else:
+          yield element * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([-1, 0, 1, 2])
+          | beam.ParDo(MyDoFn()).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+  def test_map_decorator_pipeline(self):
+    """Test that tagged types propagate through Map."""
+    @with_output_types(int, errors=str)
+    def mapfn(element):
+      if element < 0:
+        return beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+      else:
+        return element * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([-1, 0, 1, 2])
+          | beam.Map(mapfn).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+  def test_flatmap_decorator_pipeline(self):
+    """Test that tagged types propagate through FlatMap."""
+    @with_output_types(Iterable[int], errors=str)
+    def flatmapfn(element):
+      if element < 0:
+        yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+      else:
+        yield element * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([-1, 0, 1, 2])
+          | beam.FlatMap(flatmapfn).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+  def test_maptuple_decorator_pipeline(self):
+    """Test that tagged types propagate through MapTuple."""
+    @with_output_types(int, errors=str)
+    def maptuplefn(key, value):
+      if value < 0:
+        return beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}')
+      else:
+        return value * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([('a', -1), ('b', 2), ('c', 3)])
+          | beam.MapTuple(maptuplefn).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+  def test_flatmaptuple_decorator_pipeline(self):
+    """Test that tagged types propagate through FlatMapTuple."""
+    @with_output_types(Iterable[int], errors=str)
+    def flatmaptuplefn(key, value):
+      if value < 0:
+        yield beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}')
+      else:
+        yield value * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([('a', -1), ('b', 2), ('c', 3)])
+          | beam.FlatMapTuple(flatmaptuplefn).with_outputs(
+              'errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+
+class ChainStyleTaggedOutputTest(unittest.TestCase):
+  """Tests for .with_output_types() method chain style across all 
transforms."""
+  def test_pardo_chain_pipeline(self):
+    """Test ParDo with chained type hints."""
+    class SimpleDoFn(beam.DoFn):
+      def process(self, element):
+        if element < 0:
+          yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+        else:
+          yield element * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([-1, 0, 1, 2])
+          | beam.ParDo(SimpleDoFn()).with_output_types(
+              int, errors=str).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+  def test_map_chain_pipeline(self):
+    """Test Map with chained type hints."""
+    def mapfn(element):
+      if element < 0:
+        return beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+      else:
+        return element * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([-1, 0, 1, 2])
+          | beam.Map(mapfn).with_output_types(int, errors=str).with_outputs(
+              'errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+  def test_flatmap_chain_pipeline(self):
+    """Test FlatMap with chained type hints.
+
+    Note: For FlatMap.with_output_types(), specify the element type directly
+    (int), not wrapped in Iterable. The transform handles iteration internally.
+    """
+    def flatmapfn(element):
+      if element < 0:
+        yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+      else:
+        yield element * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([-1, 0, 1, 2])
+          | beam.FlatMap(flatmapfn).with_output_types(
+              int, errors=str).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+  def test_maptuple_chain_pipeline(self):
+    """Test MapTuple with chained type hints."""
+    def maptuplefn(key, value):
+      if value < 0:
+        return beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}')
+      else:
+        return value * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([('a', -1), ('b', 2), ('c', 3)])
+          | beam.MapTuple(maptuplefn).with_output_types(
+              int, errors=str).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+  def test_flatmaptuple_chain_pipeline(self):
+    """Test FlatMapTuple with chained type hints.
+
+    Note: For FlatMapTuple.with_output_types(), specify the element type
+    directly (int), not wrapped in Iterable.
+    """
+    def flatmaptuplefn(key, value):
+      if value < 0:
+        yield beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}')
+      else:
+        yield value * 2
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([('a', -1), ('b', 2), ('c', 3)])
+          | beam.FlatMapTuple(flatmaptuplefn).with_output_types(
+              int, errors=str).with_outputs('errors', main='main'))
+
+      self.assertEqual(results.main.element_type, int)
+      self.assertEqual(results.errors.element_type, str)
+
+
+class AnnotationStyleTaggedOutputTest(unittest.TestCase):
+  """Tests for function annotation style across all transforms."""
+  def test_map_annotation_union(self):

Review Comment:
   Woo hoo!



##########
sdks/python/apache_beam/typehints/decorators.py:
##########
@@ -182,6 +187,140 @@ def disable_type_annotations():
 TRACEBACK_LIMIT = 5
 
 
+def _tag_and_type(t):
+  """Extract tag name and value type from TaggedOutput[Literal['tag'], Type].
+
+  Returns raw Python types - conversion to beam types happens in
+  _extract_output_types.
+  """
+  args = get_args(t)
+  if len(args) != 2:
+    raise TypeError(
+        f"TaggedOutput expects 2 type parameters, got {len(args)}: {t}")
+
+  literal_type, value_type = args
+
+  if get_origin(literal_type) is not Literal:
+    raise TypeError(
+        f"First type parameter of TaggedOutput must be Literal['tag_name'], "
+        f"got {literal_type}. Example: TaggedOutput[Literal['errors'], str]")
+
+  tag_string = get_args(literal_type)[0]
+  return tag_string, value_type
+
+
+def _contains_tagged_output(t):
+  """Check if type contains TaggedOutput at a meaningful position.
+
+  TaggedOutput only makes sense in these patterns:
+  - TaggedOutput[...]
+  - X | TaggedOutput[...]
+  - Iterable[TaggedOutput[...]]
+  - Iterable[X | TaggedOutput[...]]
+  """
+  def _is_tagged(typ):
+    return get_origin(typ) is TaggedOutput or typ is TaggedOutput
+
+  # TaggedOutput[...]
+  if _is_tagged(t):
+    return True
+
+  origin = get_origin(t)
+  args = get_args(t)
+
+  # X | TaggedOutput[...]
+  if origin is Union:
+    return any(_is_tagged(arg) for arg in args)
+
+  # Iterable[...]
+  if origin is collections.abc.Iterable and len(args) == 1:
+    inner = args[0]
+    # Iterable[TaggedOutput[...]]
+    if _is_tagged(inner):
+      return True
+    # Iterable[X | TaggedOutput[...]]
+    if get_origin(inner) is Union:
+      return any(_is_tagged(arg) for arg in get_args(inner))
+
+  return False
+
+
+def _extract_main_and_tagged(t):
+  """Extract main type and tagged types from a type annotation.
+
+  Returns:
+    (main_type, tagged_dict) where main_type is the type without TaggedOutput
+    annotations (or None if no main type), and tagged_dict maps tag names to
+    their types.
+  """
+  if get_origin(t) is TaggedOutput:
+    tag, typ = _tag_and_type(t)
+    return None, {tag: typ}
+
+  if t is TaggedOutput:
+    raise TypeError(
+        "TaggedOutput in return type must include type parameters: "
+        "TaggedOutput[Literal['tag_name'], ValueType]")
+
+  if get_origin(t) is not Union:
+    return t, {}
+
+  main_types = []
+  tagged_types = {}
+  for arg in get_args(t):
+    if get_origin(arg) is TaggedOutput:
+      tag, typ = _tag_and_type(arg)
+      tagged_types[tag] = typ
+    elif arg is TaggedOutput:
+      raise TypeError(
+          "TaggedOutput in return type must include type parameters: "
+          "TaggedOutput[Literal['tag_name'], ValueType]")
+    else:
+      main_types.append(arg)
+
+  if len(main_types) == 0:
+    main_type = None
+  elif len(main_types) == 1:
+    main_type = main_types[0]
+  else:
+    main_type = Union[tuple(main_types)]
+
+  return main_type, tagged_types
+
+
+def _extract_output_types(return_annotation):
+  """Parse return annotation into (main_types, tagged_types).
+
+  For tagged outputs to be extracted from generator/iterator functions,
+  users must explicitly use Iterable[T | TaggedOutput[...]] as return type.
+
+  Returns raw Python types. Conversion to beam types happens in from_callable.
+  """
+  if return_annotation == inspect.Signature.empty:
+    return [Any], {}
+
+  # Early return if no TaggedOutput
+  if not _contains_tagged_output(return_annotation):

Review Comment:
   IMHO, `_extract_main_and_tagged` should be able to handle the degenerate case
of no tagged outputs, so we wouldn't have to check for this and exit early.



##########
sdks/python/apache_beam/typehints/decorators.py:
##########
@@ -182,6 +187,140 @@ def disable_type_annotations():
 TRACEBACK_LIMIT = 5
 
 
+def _tag_and_type(t):
+  """Extract tag name and value type from TaggedOutput[Literal['tag'], Type].
+
+  Returns raw Python types - conversion to beam types happens in
+  _extract_output_types.
+  """
+  args = get_args(t)
+  if len(args) != 2:
+    raise TypeError(
+        f"TaggedOutput expects 2 type parameters, got {len(args)}: {t}")
+
+  literal_type, value_type = args
+
+  if get_origin(literal_type) is not Literal:
+    raise TypeError(
+        f"First type parameter of TaggedOutput must be Literal['tag_name'], "
+        f"got {literal_type}. Example: TaggedOutput[Literal['errors'], str]")
+
+  tag_string = get_args(literal_type)[0]
+  return tag_string, value_type
+
+
+def _contains_tagged_output(t):
+  """Check if type contains TaggedOutput at a meaningful position.
+
+  TaggedOutput only makes sense in these patterns:
+  - TaggedOutput[...]
+  - X | TaggedOutput[...]
+  - Iterable[TaggedOutput[...]]
+  - Iterable[X | TaggedOutput[...]]
+  """
+  def _is_tagged(typ):
+    return get_origin(typ) is TaggedOutput or typ is TaggedOutput
+
+  # TaggedOutput[...]
+  if _is_tagged(t):
+    return True
+
+  origin = get_origin(t)
+  args = get_args(t)
+
+  # X | TaggedOutput[...]
+  if origin is Union:
+    return any(_is_tagged(arg) for arg in args)
+
+  # Iterable[...]
+  if origin is collections.abc.Iterable and len(args) == 1:
+    inner = args[0]
+    # Iterable[TaggedOutput[...]]
+    if _is_tagged(inner):
+      return True
+    # Iterable[X | TaggedOutput[...]]
+    if get_origin(inner) is Union:
+      return any(_is_tagged(arg) for arg in get_args(inner))
+
+  return False
+
+
+def _extract_main_and_tagged(t):
+  """Extract main type and tagged types from a type annotation.
+
+  Returns:
+    (main_type, tagged_dict) where main_type is the type without TaggedOutput
+    annotations (or None if no main type), and tagged_dict maps tag names to
+    their types.
+  """
+  if get_origin(t) is TaggedOutput:
+    tag, typ = _tag_and_type(t)
+    return None, {tag: typ}
+
+  if t is TaggedOutput:
+    raise TypeError(

Review Comment:
   We should probably allow this as `pvalue.TaggedOutput[*, Any]`, though 
perhaps with a warning. 



##########
sdks/python/apache_beam/typehints/decorators.py:
##########
@@ -182,6 +187,140 @@ def disable_type_annotations():
 TRACEBACK_LIMIT = 5
 
 
+def _tag_and_type(t):
+  """Extract tag name and value type from TaggedOutput[Literal['tag'], Type].
+
+  Returns raw Python types - conversion to beam types happens in
+  _extract_output_types.
+  """
+  args = get_args(t)
+  if len(args) != 2:
+    raise TypeError(
+        f"TaggedOutput expects 2 type parameters, got {len(args)}: {t}")
+
+  literal_type, value_type = args
+
+  if get_origin(literal_type) is not Literal:
+    raise TypeError(
+        f"First type parameter of TaggedOutput must be Literal['tag_name'], "
+        f"got {literal_type}. Example: TaggedOutput[Literal['errors'], str]")
+
+  tag_string = get_args(literal_type)[0]
+  return tag_string, value_type
+
+
+def _contains_tagged_output(t):
+  """Check if type contains TaggedOutput at a meaningful position.
+
+  TaggedOutput only makes sense in these patterns:
+  - TaggedOutput[...]
+  - X | TaggedOutput[...]
+  - Iterable[TaggedOutput[...]]
+  - Iterable[X | TaggedOutput[...]]
+  """
+  def _is_tagged(typ):
+    return get_origin(typ) is TaggedOutput or typ is TaggedOutput
+
+  # TaggedOutput[...]
+  if _is_tagged(t):
+    return True
+
+  origin = get_origin(t)
+  args = get_args(t)
+
+  # X | TaggedOutput[...]
+  if origin is Union:
+    return any(_is_tagged(arg) for arg in args)
+
+  # Iterable[...]
+  if origin is collections.abc.Iterable and len(args) == 1:
+    inner = args[0]
+    # Iterable[TaggedOutput[...]]
+    if _is_tagged(inner):
+      return True
+    # Iterable[X | TaggedOutput[...]]
+    if get_origin(inner) is Union:
+      return any(_is_tagged(arg) for arg in get_args(inner))
+
+  return False
+
+
+def _extract_main_and_tagged(t):
+  """Extract main type and tagged types from a type annotation.
+
+  Returns:
+    (main_type, tagged_dict) where main_type is the type without TaggedOutput
+    annotations (or None if no main type), and tagged_dict maps tag names to
+    their types.
+  """
+  if get_origin(t) is TaggedOutput:
+    tag, typ = _tag_and_type(t)
+    return None, {tag: typ}
+
+  if t is TaggedOutput:
+    raise TypeError(
+        "TaggedOutput in return type must include type parameters: "
+        "TaggedOutput[Literal['tag_name'], ValueType]")
+
+  if get_origin(t) is not Union:
+    return t, {}
+
+  main_types = []
+  tagged_types = {}
+  for arg in get_args(t):
+    if get_origin(arg) is TaggedOutput:
+      tag, typ = _tag_and_type(arg)
+      tagged_types[tag] = typ
+    elif arg is TaggedOutput:
+      raise TypeError(
+          "TaggedOutput in return type must include type parameters: "
+          "TaggedOutput[Literal['tag_name'], ValueType]")
+    else:
+      main_types.append(arg)
+
+  if len(main_types) == 0:
+    main_type = None
+  elif len(main_types) == 1:
+    main_type = main_types[0]
+  else:
+    main_type = Union[tuple(main_types)]
+
+  return main_type, tagged_types
+
+
+def _extract_output_types(return_annotation):
+  """Parse return annotation into (main_types, tagged_types).
+
+  For tagged outputs to be extracted from generator/iterator functions,
+  users must explicitly use Iterable[T | TaggedOutput[...]] as return type.
+
+  Returns raw Python types. Conversion to beam types happens in from_callable.
+  """
+  if return_annotation == inspect.Signature.empty:
+    return [Any], {}
+
+  # Early return if no TaggedOutput
+  if not _contains_tagged_output(return_annotation):
+    return [return_annotation], {}
+
+  # Iterable[T | TaggedOutput[...]]
+  if get_origin(return_annotation) is collections.abc.Iterable:
+    yield_type = get_args(return_annotation)[0]
+    clean_yield, tagged_types = _extract_main_and_tagged(yield_type)
+    clean_main = clean_yield if clean_yield else Any
+    return [Iterable[clean_main]], tagged_types

Review Comment:
   It seems a bit asymmetric to wrap the main type in an Iterable here but
not the tagged types. Should we instead call this after the (higher-level)
Iterable unwrapping?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to