damccorm commented on code in PR #29938:
URL: https://github.com/apache/beam/pull/29938#discussion_r1446117168


##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -314,11 +327,34 @@ def expand(
           ptransform_list[i].artifact_mode = self._artifact_mode
 
     for ptransform in ptransform_list:
-      pcoll = pcoll | ptransform
-
+      if self._with_exception_handling:
+        if hasattr(ptransform, 'with_exception_handling'):
+          ptransform = ptransform.with_exception_handling(
+              **self._exception_handling_args)
+          pcoll, bad_results = pcoll | ptransform
+          # RunInference outputs a RunInferenceDLQ instead of a PCollection.
+          # Since TFTProcessHandler and RunInference are supported, try to infer
+          # the type of bad_results and append it to the list of errors.
+          if isinstance(bad_results, RunInferenceDLQ):
+            upstream_errors.append(bad_results.failed_inferences)
+          elif isinstance(bad_results, beam.PCollection):
+            upstream_errors.append(bad_results)
+          else:
+            raise NotImplementedError(
+                f'Unexpected type for bad_results: {type(bad_results)}')
+      else:
+        pcoll = pcoll | ptransform
     _ = (
         pcoll.pipeline
         | "MLTransformMetricsUsage" >> MLTransformMetricsUsage(self))
+    if self._with_exception_handling:
+      bad_pcoll = (
+          upstream_errors
+          | beam.Flatten()
+          | beam.Map(
+              lambda x: beam.Row(
+                  element=x[0], msg=str(x[1][1]), stack=str(x[1][2]))))

Review Comment:
   I think it would be helpful to also include the operation (e.g.
`VertexAITextEmbeddings`) that each error came from, since I don't think it
will be obvious from the stack trace alone. We could probably get this by
adding a map step to `bad_results` above.
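
A rough sketch of what that map step might look like, written in plain Python so it runs without a pipeline. The record shape `(element, (exc_type, msg, stack))` is assumed from the lambda in the diff, and `TaggedError` and `tag_with_operation` are hypothetical names used only for illustration, not existing Beam APIs.

```python
from typing import Any, Callable, NamedTuple, Tuple

# Shape assumed from the lambda in the diff: (element, (exc_type, msg, stack)).
ErrorRecord = Tuple[Any, Tuple[Any, Any, Any]]


class TaggedError(NamedTuple):
  """Hypothetical stand-in for the beam.Row built at the end of expand()."""
  operation: str
  element: Any
  msg: str
  stack: str


def tag_with_operation(
    operation_name: str) -> Callable[[ErrorRecord], TaggedError]:
  """Returns a function that attaches operation_name to one error record."""
  def _tag(record: ErrorRecord) -> TaggedError:
    element, error_info = record
    return TaggedError(
        operation=operation_name,
        element=element,
        msg=str(error_info[1]),
        stack=str(error_info[2]))

  return _tag


# Possible use inside the loop (assumption, not tested against the PR):
# capture the name before with_exception_handling wraps the transform, then
#   bad_results | beam.Map(tag_with_operation(operation_name))
# so each error carries its source operation before beam.Flatten().
```

With something like this, the final `beam.Map` after `beam.Flatten()` would no longer need to index into `x[1]`, since each record would already carry `operation`, `element`, `msg`, and `stack`.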



##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -517,3 +517,8 @@ def expand(
           | "ConvertToRowType" >>
           beam.Map(lambda x: beam.Row(**x)).with_output_types(row_type))
       return transformed_dataset
+
+  def with_exception_handling(self):
+    raise NotImplementedError(
+        "with_exception_handling is not supported for tensorflow-transform "
+        "data processing transforms.")

Review Comment:
   It might be good to add a sentence to this error message along the lines
of `If you want to use exception handling with other MLTransform operations,
separate them into a different MLTransform instance`.
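
For illustration, the extended message could read roughly as follows. `TFTProcessHandlerSketch` is a hypothetical stand-in for the handler class in `handlers.py`, and the exact wording is only a suggestion.

```python
class TFTProcessHandlerSketch:
  """Hypothetical stand-in for the handler touched in handlers.py."""

  def with_exception_handling(self):
    # Same error as in the diff, plus a hint on how to work around it.
    raise NotImplementedError(
        "with_exception_handling is not supported for tensorflow-transform "
        "data processing transforms. If you want to use exception handling "
        "with other MLTransform operations, separate them into a different "
        "MLTransform instance.")
```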



