derrickaw commented on code in PR #35952:
URL: https://github.com/apache/beam/pull/35952#discussion_r2326215515


##########
sdks/python/apache_beam/yaml/yaml_transform.py:
##########
@@ -522,6 +549,242 @@ def expand_leaf_transform(spec, scope):
         f'{type(outputs)}')
 
 
+def expand_output_schema_transform(spec, outputs, error_handling_spec):
+  """Enforces an `output_schema` on the output(s) of another transform.
+
+  This function is called when an `output_schema` is defined on a transform.
+  It passes the transform's main output through `_enforce_schema` so that the
+  data is checked against the specified schema.
+
+  If the original transform has error handling configured, validation errors
+  are routed to the configured error output. If not, a warning is logged
+  because validation failures will cause the pipeline to fail.
+
+  Args:
+    spec (dict): The `output_schema` specification from the YAML config.
+    outputs (beam.PCollection or dict[str, beam.PCollection]): The output(s)
+      from the transform to be validated.
+    error_handling_spec (dict): The `error_handling` configuration from the
+      original transform. May be empty or None.
+
+  Returns:
+    For a single input PCollection: the validated PCollection, or, when the
+    schema enforcement produced an error output under the configured error
+    tag, a dict with the validated data under 'output' and the errors under
+    that tag. For a dict of PCollections: the dict with the validation
+    results integrated by `_integrate_validation_results`. Any other kind of
+    `outputs` is returned unchanged.
+
+  Raises:
+    ValueError: If `error_handling` is incorrectly specified within the
+      `output_schema` spec itself, or if the main output of a multi-output
+      transform cannot be determined.
+  """
+  if 'error_handling' in spec:
+    raise ValueError(
+        'error_handling config is not supported directly in '
+        'the output_schema. Please use error_handling config in '
+        'the transform, if possible, or use ValidateWithSchema transform '
+        'instead.')
+
+  # Strip metadata such as __line__ and __uuid__ as these will interfere with
+  # the validation downstream.
+  clean_schema = SafeLineLoader.strip_metadata(spec)
+
+  # If no error handling is specified for the main transform, warn the user
+  # that the pipeline may fail if any output data fails the output schema
+  # validation. Every fragment ends with a space so the implicitly
+  # concatenated message keeps its word boundaries.
+  if not error_handling_spec:
+    _LOGGER.warning(
+        "Output_schema config is attached to a transform that has no "
+        "error_handling config specified. Any failures validating on output "
+        "schema will fail the pipeline unless the user specifies an "
+        "error_handling config on a capable transform or the user can "
+        "remove the output_schema config on this transform and add a "
+        "ValidateWithSchema transform downstream of the current transform.")
+
+  # Used only for tag lookups below, so a missing (None) error_handling
+  # config cannot raise AttributeError. `_enforce_schema` still receives the
+  # caller's original value.
+  error_handling = error_handling_spec or {}
+
+  # The transform produced outputs with a single beam.PCollection
+  if isinstance(outputs, beam.PCollection):
+    outputs = _enforce_schema(
+        outputs, 'EnforceOutputSchema', error_handling_spec, clean_schema)
+    if isinstance(outputs, dict):
+      main_tag = error_handling.get('main_tag', 'good')
+      main_output = outputs.pop(main_tag)
+      if error_handling:
+        error_output_tag = error_handling.get('output')
+        if error_output_tag in outputs:
+          # Re-key the validated data as 'output' and keep the errors under
+          # the tag the user configured.
+          return {
+              'output': main_output,
+              error_output_tag: outputs.pop(error_output_tag)
+          }
+      return main_output
+
+  # The transform produced outputs with many named PCollections and need to
+  # determine which PCollection should be validated on.
+  elif isinstance(outputs, dict):
+    main_output_key = _get_main_output_key(spec, outputs)
+
+    validation_result = _enforce_schema(
+        outputs[main_output_key],
+        f'EnforceOutputSchema_{main_output_key}',
+        error_handling_spec,
+        clean_schema)
+    outputs = _integrate_validation_results(
+        outputs, validation_result, main_output_key, error_handling_spec)
+
+  return outputs
+
+
+def _get_main_output_key(spec, outputs):
+  """Determines the main output key from a dictionary of PCollections.
+
+  This is used to identify which output of a multi-output transform should be
+  validated against an `output_schema`.
+
+  The main output is determined using the following precedence:
+  1. An output with the key 'output'.
+  2. An output with the key 'good'.
+  3. The single output if there is only one.
+
+  Args:
+    spec: The transform specification, used for creating informative error
+      messages.
+    outputs: A dictionary mapping output tags to their corresponding
+      PCollections.
+
+  Returns:
+    The key of the main output PCollection.
+
+  Raises:
+    ValueError: If a main output cannot be determined because there are
+      multiple (or zero) outputs and none are named 'output' or 'good'.
+  """
+  # Well-known tags, in precedence order.
+  for candidate in ('output', 'good'):
+    if candidate in outputs:
+      return candidate
+
+  # A lone output is unambiguous whatever it is called.
+  if len(outputs) == 1:
+    return next(iter(outputs))
+
+  raise ValueError(
+      f"Transform {identify_object(spec)} has outputs "
+      f"{list(outputs.keys())}, but none are named 'output' or 'good'. To "
+      "apply an 'output_schema', please ensure the transform has exactly "
+      "one output, or that the main output is named 'output'.")
+
+
+def _integrate_validation_results(
+    outputs, validation_result, main_output_key, error_handling_spec):
+  """
+  Integrates the results of a validation transform back into the outputs of
+  the original transform.
+
+  This function handles merging the "good" and "bad" outputs from a
+  `Validate` transform with the existing outputs of the transform that was
+  validated.
+
+  Args:
+    outputs: The original dictionary of output PCollections from the transform.
+    validation_result: The output of the `Validate` transform. This can be a
+      single PCollection (if all elements passed) or a dictionary of
+      PCollections (if error handling was enabled for validation).
+    main_output_key: The key in the `outputs` dictionary corresponding to the
+      PCollection that was validated.
+    error_handling_spec: The error handling configuration of the original
+      transform.
+
+  Returns:
+    The updated dictionary of output PCollections, with validation results
+    integrated.
+
+  Raises:
+    ValueError: If the validation transform produces unexpected outputs.
+  """
+  if not isinstance(validation_result, dict):
+    outputs[main_output_key] = validation_result
+    return outputs
+
+  # The main output from validation is the good output.
+  main_tag = error_handling_spec.get('main_tag', 'good')
+  outputs[main_output_key] = validation_result.pop(main_tag)
+
+  if error_handling_spec:
+    error_output_tag = error_handling_spec['output']
+    if error_output_tag in validation_result:
+      schema_error_pcoll = validation_result.pop(error_output_tag)
+      if error_output_tag in outputs:
+        # The original transform also had an error output. Merge them.
+        outputs[error_output_tag] = (
+            (outputs[error_output_tag], schema_error_pcoll)
+            | f'FlattenErrors_{main_output_key}' >> beam.Flatten())
+      else:
+        # No error output in the original transform, so just add this one.
+        outputs[error_output_tag] = schema_error_pcoll

Review Comment:
   Not anymore.  Thanks.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to