This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new c0de849c1ce [#24024] Stop wrapping light weight functions with Contextful as they add a lot of overhead for functions that are meant to do almost no work. (#24025)
c0de849c1ce is described below

commit c0de849c1ce9b56e70a99863ff6f9d57fbbd7c0c
Author: Luke Cwik <lc...@google.com>
AuthorDate: Thu Dec 8 12:48:49 2022 -0800

    [#24024] Stop wrapping light weight functions with Contextful as they add a lot of overhead for functions that are meant to do almost no work. (#24025)
    
    * Stop wrapping light weight functions with Contextful as they add a lot of overhead for functions that are meant to do almost no work.
    
    Fixes #24024
    
    * Fix Spark test expectations
---
 .../runners/spark/SparkRunnerDebuggerTest.java     |  11 +-
 .../beam/sdk/transforms/FlatMapElements.java       | 257 +++++++++++++--------
 .../apache/beam/sdk/transforms/MapElements.java    | 250 ++++++++++++--------
 3 files changed, 325 insertions(+), 193 deletions(-)
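
For context, a minimal sketch of the kind of lightweight mapping this change targets (illustrative only, not part of the commit; the class name and sample data are hypothetical). A plain lambda passed to via() needs no side inputs, so after this change it is stored and invoked directly as a ProcessFunction in expand() instead of being wrapped in Contextful; the via(Contextful) overload remains the path for functions that genuinely need context such as side inputs, which is why expand() now dispatches on the runtime type of the stored fn.

    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;
    import org.apache.beam.sdk.transforms.Create;
    import org.apache.beam.sdk.transforms.MapElements;
    import org.apache.beam.sdk.values.PCollection;
    import org.apache.beam.sdk.values.TypeDescriptors;

    public class MapElementsExample {
      public static void main(String[] args) {
        Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

        // A trivial per-element function supplied as a plain lambda. It is meant to do
        // almost no work, so the per-element cost of the Contextful wrapper mattered.
        PCollection<Integer> lengths =
            p.apply(Create.of("to", "be", "or", "not", "to", "be"))
                .apply(
                    "WordLengths",
                    MapElements.into(TypeDescriptors.integers())
                        .via((String word) -> word.length()));

        p.run().waitUntilFinish();
      }
    }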

diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java
index 843d303bda8..157b3cb946e 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java
@@ -86,14 +86,15 @@ public class SparkRunnerDebuggerTest {
         "sparkContext.<readFrom(org.apache.beam.sdk.transforms.Create$Values$CreateSource)>()\n"
             + "_.mapPartitions("
             + "new org.apache.beam.runners.spark.examples.WordCount$ExtractWordsFn())\n"
-            + "_.mapPartitions(new org.apache.beam.sdk.transforms.Contextful())\n"
+            + "_.mapPartitions(new org.apache.beam.sdk.transforms.Count$PerElement$1())\n"
             + "_.combineByKey(..., new org.apache.beam.sdk.transforms.Count$CountFn(), ...)\n"
             + "_.groupByKey()\n"
             + "_.map(new org.apache.beam.sdk.transforms.Sum$SumLongFn())\n"
-            + "_.mapPartitions(new org.apache.beam.sdk.transforms.Contextful())\n"
+            + "_.mapPartitions("
+            + "new org.apache.beam.runners.spark.SparkRunnerDebuggerTest$PlusOne())\n"
             + "sparkContext.union(...)\n"
             + "_.mapPartitions("
-            + "new org.apache.beam.sdk.transforms.Contextful())\n"
+            + "new org.apache.beam.runners.spark.examples.WordCount$FormatAsTextFn())\n"
             + "_.<org.apache.beam.sdk.io.TextIO$Write>";
 
     SparkRunnerDebugger.DebugSparkPipelineResult result =
@@ -142,11 +143,11 @@ public class SparkRunnerDebuggerTest {
             + "_.map(new org.apache.beam.sdk.transforms.windowing.FixedWindows())\n"
             + "_.mapPartitions(new org.apache.beam.runners.spark."
             + "SparkRunnerDebuggerTest$FormatKVFn())\n"
-            + "_.mapPartitions(new org.apache.beam.sdk.transforms.Contextful())\n"
+            + "_.mapPartitions(new org.apache.beam.sdk.transforms.Distinct$1())\n"
             + "_.groupByKey()\n"
             + "_.map(new org.apache.beam.sdk.transforms.Combine$IterableCombineFn())\n"
             + "_.mapPartitions(new org.apache.beam.sdk.transforms.Distinct$3())\n"
-            + "_.mapPartitions(new org.apache.beam.sdk.transforms.Contextful())\n"
+            + "_.mapPartitions(new org.apache.beam.sdk.transforms.WithKeys$1())\n"
             + "_.<org.apache.beam.sdk.io.kafka.AutoValue_KafkaIO_Write>";
 
     SparkRunnerDebugger.DebugSparkPipelineResult result =
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java
index 19e6f6465b5..9a87aba766f 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java
@@ -45,16 +45,13 @@ public class FlatMapElements<InputT, OutputT>
     extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> {
   private final transient @Nullable TypeDescriptor<InputT> inputType;
   private final transient @Nullable TypeDescriptor<OutputT> outputType;
-  private final transient @Nullable Object originalFnForDisplayData;
-  private final @Nullable Contextful<Fn<InputT, Iterable<OutputT>>> fn;
+  private final @Nullable Object fn;
 
   private FlatMapElements(
-      @Nullable Contextful<Fn<InputT, Iterable<OutputT>>> fn,
-      @Nullable Object originalFnForDisplayData,
+      @Nullable Object fn,
       @Nullable TypeDescriptor<InputT> inputType,
       TypeDescriptor<OutputT> outputType) {
     this.fn = fn;
-    this.originalFnForDisplayData = originalFnForDisplayData;
     this.inputType = inputType;
     this.outputType = outputType;
   }
@@ -83,14 +80,13 @@ public class FlatMapElements<InputT, OutputT>
    */
   public static <InputT, OutputT> FlatMapElements<InputT, OutputT> via(
       InferableFunction<? super InputT, ? extends Iterable<OutputT>> fn) {
-    Contextful<Fn<InputT, Iterable<OutputT>>> wrapped = (Contextful) Contextful.fn(fn);
     TypeDescriptor<OutputT> outputType =
         TypeDescriptors.extractFromTypeParameters(
             (TypeDescriptor<Iterable<OutputT>>) fn.getOutputTypeDescriptor(),
             Iterable.class,
             new TypeDescriptors.TypeVariableExtractor<Iterable<OutputT>, OutputT>() {});
     TypeDescriptor<InputT> inputType = (TypeDescriptor<InputT>) fn.getInputTypeDescriptor();
-    return new FlatMapElements<>(wrapped, fn, inputType, outputType);
+    return new FlatMapElements<>(fn, inputType, outputType);
   }
 
   /** Binary compatibility adapter for {@link #via(ProcessFunction)}. */
@@ -105,7 +101,7 @@ public class FlatMapElements<InputT, OutputT>
    */
   public static <OutputT> FlatMapElements<?, OutputT> into(
       final TypeDescriptor<OutputT> outputType) {
-    return new FlatMapElements<>(null, null, null, outputType);
+    return new FlatMapElements<>(null, null, outputType);
   }
 
   /**
@@ -123,8 +119,7 @@ public class FlatMapElements<InputT, OutputT>
    */
   public <NewInputT> FlatMapElements<NewInputT, OutputT> via(
       ProcessFunction<NewInputT, ? extends Iterable<OutputT>> fn) {
-    return new FlatMapElements<>(
-        (Contextful) Contextful.fn(fn), fn, TypeDescriptors.inputOf(fn), outputType);
+    return new FlatMapElements<>(fn, TypeDescriptors.inputOf(fn), outputType);
   }
 
   /** Binary compatibility adapter for {@link #via(ProcessFunction)}. */
@@ -137,57 +132,90 @@ public class FlatMapElements<InputT, OutputT>
   @Experimental(Kind.CONTEXTFUL)
   public <NewInputT> FlatMapElements<NewInputT, OutputT> via(
       Contextful<Fn<NewInputT, Iterable<OutputT>>> fn) {
-    return new FlatMapElements<>(
-        fn, fn.getClosure(), TypeDescriptors.inputOf(fn.getClosure()), outputType);
+    return new FlatMapElements<>(fn, TypeDescriptors.inputOf(fn.getClosure()), outputType);
   }
 
   @Override
   public PCollection<OutputT> expand(PCollection<? extends InputT> input) {
     checkArgument(fn != null, ".via() is required");
-    return input.apply(
-        "FlatMap",
-        ParDo.of(
-                new DoFn<InputT, OutputT>() {
-                  @ProcessElement
-                  public void processElement(ProcessContext c) throws Exception {
-                    Iterable<OutputT> res =
-                        fn.getClosure().apply(c.element(), Fn.Context.wrapProcessContext(c));
-                    for (OutputT output : res) {
-                      c.output(output);
+    if (fn instanceof Contextful) {
+      return input.apply(
+          "FlatMap",
+          ParDo.of(
+                  new FlatMapDoFn() {
+                    @ProcessElement
+                    public void processElement(ProcessContext c) throws Exception {
+                      Iterable<OutputT> res =
+                          ((Contextful<Fn<InputT, Iterable<OutputT>>>) fn)
+                              .getClosure()
+                              .apply(c.element(), Fn.Context.wrapProcessContext(c));
+                      for (OutputT output : res) {
+                        c.output(output);
+                      }
                     }
+                  })
+              .withSideInputs(
+                  ((Contextful<Fn<InputT, Iterable<OutputT>>>) fn)
+                      .getRequirements()
+                      .getSideInputs()));
+    } else if (fn instanceof ProcessFunction) {
+      return input.apply(
+          "FlatMap",
+          ParDo.of(
+              new FlatMapDoFn() {
+                @ProcessElement
+                public void processElement(
+                    @Element InputT element, OutputReceiver<OutputT> receiver) throws Exception {
+                  Iterable<OutputT> res =
+                      ((ProcessFunction<InputT, Iterable<OutputT>>) fn).apply(element);
+                  for (OutputT output : res) {
+                    receiver.output(output);
                   }
+                }
+              }));
+    } else {
+      throw new IllegalArgumentException(
+          String.format("Unknown type of fn class %s", fn.getClass()));
+    }
+  }
 
-                  @Override
-                  public TypeDescriptor<InputT> getInputTypeDescriptor() {
-                    return inputType;
-                  }
+  private abstract class FlatMapDoFn extends DoFn<InputT, OutputT> {
 
-                  @Override
-                  public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
-                    checkState(
-                        outputType != null,
-                        "%s output type descriptor was null; "
-                            + "this probably means that getOutputTypeDescriptor() was called after "
-                            + "serialization/deserialization, but it is only available prior to "
-                            + "serialization, for constructing a pipeline and inferring coders",
-                        FlatMapElements.class.getSimpleName());
-                    return outputType;
-                  }
+    @Override
+    public TypeDescriptor<InputT> getInputTypeDescriptor() {
+      return inputType;
+    }
 
-                  @Override
-                  public void populateDisplayData(DisplayData.Builder builder) {
-                    builder.delegate(FlatMapElements.this);
-                  }
-                })
-            .withSideInputs(fn.getRequirements().getSideInputs()));
+    @Override
+    public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
+      checkState(
+          outputType != null,
+          "%s output type descriptor was null; "
+              + "this probably means that getOutputTypeDescriptor() was called after "
+              + "serialization/deserialization, but it is only available prior to "
+              + "serialization, for constructing a pipeline and inferring coders",
+          FlatMapElements.class.getSimpleName());
+      return outputType;
+    }
+
+    @Override
+    public void populateDisplayData(DisplayData.Builder builder) {
+      builder.delegate(FlatMapElements.this);
+    }
   }
 
   @Override
   public void populateDisplayData(DisplayData.Builder builder) {
     super.populateDisplayData(builder);
-    builder.add(DisplayData.item("class", originalFnForDisplayData.getClass()));
-    if (originalFnForDisplayData instanceof HasDisplayData) {
-      builder.include("fn", (HasDisplayData) originalFnForDisplayData);
+    Object fnForDisplayData;
+    if (fn instanceof Contextful) {
+      fnForDisplayData = ((Contextful<Fn<InputT, OutputT>>) fn).getClosure();
+    } else {
+      fnForDisplayData = fn;
+    }
+    builder.add(DisplayData.item("class", fnForDisplayData.getClass()));
+    if (fnForDisplayData instanceof HasDisplayData) {
+      builder.include("fn", (HasDisplayData) fnForDisplayData);
     }
   }
 
@@ -203,8 +231,7 @@ public class FlatMapElements<InputT, OutputT>
   @Experimental(Kind.WITH_EXCEPTIONS)
   public <NewFailureT> FlatMapWithFailures<InputT, OutputT, NewFailureT> exceptionsInto(
       TypeDescriptor<NewFailureT> failureTypeDescriptor) {
-    return new FlatMapWithFailures<>(
-        fn, originalFnForDisplayData, inputType, outputType, null, failureTypeDescriptor);
+    return new FlatMapWithFailures<>(fn, inputType, outputType, null, failureTypeDescriptor);
   }
 
   /**
@@ -235,12 +262,7 @@ public class FlatMapElements<InputT, OutputT>
   public <FailureT> FlatMapWithFailures<InputT, OutputT, FailureT> exceptionsVia(
       InferableFunction<ExceptionElement<InputT>, FailureT> exceptionHandler) {
     return new FlatMapWithFailures<>(
-        fn,
-        originalFnForDisplayData,
-        inputType,
-        outputType,
-        exceptionHandler,
-        exceptionHandler.getOutputTypeDescriptor());
+        fn, inputType, outputType, exceptionHandler, exceptionHandler.getOutputTypeDescriptor());
   }
 
   /** A {@code PTransform} that adds exception handling to {@link FlatMapElements}. */
@@ -251,19 +273,16 @@ public class FlatMapElements<InputT, OutputT>
     private final transient TypeDescriptor<InputT> inputType;
     private final transient TypeDescriptor<OutputT> outputType;
     private final transient @Nullable TypeDescriptor<FailureT> failureType;
-    private final transient Object originalFnForDisplayData;
-    private final @Nullable Contextful<Fn<InputT, Iterable<OutputT>>> fn;
+    private final @Nullable Object fn;
     private final @Nullable ProcessFunction<ExceptionElement<InputT>, FailureT> exceptionHandler;
 
     FlatMapWithFailures(
-        @Nullable Contextful<Fn<InputT, Iterable<OutputT>>> fn,
-        Object originalFnForDisplayData,
+        Object fn,
         TypeDescriptor<InputT> inputType,
         TypeDescriptor<OutputT> outputType,
         @Nullable ProcessFunction<ExceptionElement<InputT>, FailureT> exceptionHandler,
         @Nullable TypeDescriptor<FailureT> failureType) {
       this.fn = fn;
-      this.originalFnForDisplayData = originalFnForDisplayData;
       this.inputType = inputType;
       this.outputType = outputType;
       this.exceptionHandler = exceptionHandler;
@@ -291,29 +310,101 @@ public class FlatMapElements<InputT, OutputT>
      */
     public FlatMapWithFailures<InputT, OutputT, FailureT> exceptionsVia(
         ProcessFunction<ExceptionElement<InputT>, FailureT> exceptionHandler) {
-      return new FlatMapWithFailures<>(
-          fn, originalFnForDisplayData, inputType, outputType, exceptionHandler, failureType);
+      return new FlatMapWithFailures<>(fn, inputType, outputType, exceptionHandler, failureType);
     }
 
     @Override
     public WithFailures.Result<PCollection<OutputT>, FailureT> expand(PCollection<InputT> input) {
       checkArgument(exceptionHandler != null, ".exceptionsVia() is required");
-      MapFn doFn = new MapFn();
-      PCollectionTuple tuple =
-          input.apply(
-              FlatMapWithFailures.class.getSimpleName(),
-              ParDo.of(doFn)
-                  .withOutputTags(doFn.outputTag, TupleTagList.of(doFn.failureTag))
-                  .withSideInputs(this.fn.getRequirements().getSideInputs()));
+      MapWithFailuresDoFn doFn;
+      PCollectionTuple tuple;
+      if (fn instanceof Contextful) {
+        doFn =
+            new MapWithFailuresDoFn() {
+              @ProcessElement
+              public void processElement(@Element InputT element, ProcessContext c)
+                  throws Exception {
+                boolean exceptionWasThrown = false;
+                Iterable<OutputT> res = null;
+                try {
+                  res =
+                      ((Contextful<Fn<InputT, Iterable<OutputT>>>) fn)
+                          .getClosure()
+                          .apply(element, Fn.Context.wrapProcessContext(c));
+                } catch (Exception e) {
+                  exceptionWasThrown = true;
+                  ExceptionElement<InputT> exceptionElement = ExceptionElement.of(element, e);
+                  c.output(failureTag, exceptionHandler.apply(exceptionElement));
+                }
+                // We make sure our outputs occur outside the try block, since runners may implement
+                // fusion by having output() directly call the body of another DoFn, potentially
+                // catching
+                // exceptions unrelated to this transform.
+                if (!exceptionWasThrown) {
+                  for (OutputT output : res) {
+                    c.output(output);
+                  }
+                }
+              }
+            };
+        tuple =
+            input.apply(
+                FlatMapWithFailures.class.getSimpleName(),
+                ParDo.of(doFn)
+                    .withOutputTags(doFn.outputTag, TupleTagList.of(doFn.failureTag))
+                    .withSideInputs(
+                        ((Contextful<Fn<InputT, Iterable<OutputT>>>) fn)
+                            .getRequirements()
+                            .getSideInputs()));
+      } else if (fn instanceof ProcessFunction) {
+        doFn =
+            new MapWithFailuresDoFn() {
+              @ProcessElement
+              public void processElement(@Element InputT element, ProcessContext c)
+                  throws Exception {
+                boolean exceptionWasThrown = false;
+                Iterable<OutputT> res = null;
+                try {
+                  res = ((ProcessFunction<InputT, Iterable<OutputT>>) fn).apply(element);
+                } catch (Exception e) {
+                  exceptionWasThrown = true;
+                  ExceptionElement<InputT> exceptionElement = ExceptionElement.of(element, e);
+                  c.output(failureTag, exceptionHandler.apply(exceptionElement));
+                }
+                // We make sure our outputs occur outside the try block, since runners may implement
+                // fusion by having output() directly call the body of another DoFn, potentially
+                // catching
+                // exceptions unrelated to this transform.
+                if (!exceptionWasThrown) {
+                  for (OutputT output : res) {
+                    c.output(output);
+                  }
+                }
+              }
+            };
+        tuple =
+            input.apply(
+                FlatMapWithFailures.class.getSimpleName(),
+                ParDo.of(doFn).withOutputTags(doFn.outputTag, TupleTagList.of(doFn.failureTag)));
+      } else {
+        throw new IllegalArgumentException(
+            String.format("Unknown type of fn class %s", fn.getClass()));
+      }
       return WithFailures.Result.of(tuple, doFn.outputTag, doFn.failureTag);
     }
 
     @Override
     public void populateDisplayData(DisplayData.Builder builder) {
       super.populateDisplayData(builder);
-      builder.add(DisplayData.item("class", originalFnForDisplayData.getClass()));
-      if (originalFnForDisplayData instanceof HasDisplayData) {
-        builder.include("fn", (HasDisplayData) originalFnForDisplayData);
+      Object fnForDisplayData;
+      if (fn instanceof Contextful) {
+        fnForDisplayData = ((Contextful<?>) fn).getClosure();
+      } else {
+        fnForDisplayData = fn;
+      }
+      builder.add(DisplayData.item("class", fnForDisplayData.getClass()));
+      if (fnForDisplayData instanceof HasDisplayData) {
+        builder.include("fn", (HasDisplayData) fnForDisplayData);
       }
       builder.add(DisplayData.item("exceptionHandler.class", exceptionHandler.getClass()));
       if (exceptionHandler instanceof HasDisplayData) {
@@ -330,33 +421,11 @@ public class FlatMapElements<InputT, OutputT>
     }
 
     /** A DoFn implementation that handles exceptions and outputs a secondary failure collection. */
-    private class MapFn extends DoFn<InputT, OutputT> {
+    private abstract class MapWithFailuresDoFn extends DoFn<InputT, OutputT> {
 
       final TupleTag<OutputT> outputTag = new TupleTag<OutputT>() {};
       final TupleTag<FailureT> failureTag = new FailureTag();
 
-      @ProcessElement
-      public void processElement(@Element InputT element, MultiOutputReceiver r, ProcessContext c)
-          throws Exception {
-        boolean exceptionWasThrown = false;
-        Iterable<OutputT> res = null;
-        try {
-          res = fn.getClosure().apply(c.element(), Fn.Context.wrapProcessContext(c));
-        } catch (Exception e) {
-          exceptionWasThrown = true;
-          ExceptionElement<InputT> exceptionElement = ExceptionElement.of(element, e);
-          r.get(failureTag).output(exceptionHandler.apply(exceptionElement));
-        }
-        // We make sure our outputs occur outside the try block, since runners may implement
-        // fusion by having output() directly call the body of another DoFn, potentially catching
-        // exceptions unrelated to this transform.
-        if (!exceptionWasThrown) {
-          for (OutputT output : res) {
-            r.get(outputTag).output(output);
-          }
-        }
-      }
-
       @Override
       public void populateDisplayData(DisplayData.Builder builder) {
         builder.delegate(FlatMapWithFailures.this);
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java
index e7de5a14754..2159ff44367 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java
@@ -41,18 +41,17 @@ import org.checkerframework.checker.nullness.qual.Nullable;
 })
 public class MapElements<InputT, OutputT>
     extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> {
+
   private final transient @Nullable TypeDescriptor<InputT> inputType;
-  private final transient @Nullable TypeDescriptor<OutputT> outputType;
-  private final transient @Nullable Object originalFnForDisplayData;
-  private final @Nullable Contextful<Fn<InputT, OutputT>> fn;
+  private final transient TypeDescriptor<OutputT> outputType;
+
+  private final @Nullable Object fn;
 
   private MapElements(
-      @Nullable Contextful<Fn<InputT, OutputT>> fn,
-      @Nullable Object originalFnForDisplayData,
+      @Nullable Object fn,
       @Nullable TypeDescriptor<InputT> inputType,
       TypeDescriptor<OutputT> outputType) {
     this.fn = fn;
-    this.originalFnForDisplayData = originalFnForDisplayData;
     this.inputType = inputType;
     this.outputType = outputType;
   }
@@ -80,8 +79,7 @@ public class MapElements<InputT, OutputT>
    */
   public static <InputT, OutputT> MapElements<InputT, OutputT> via(
       final InferableFunction<InputT, OutputT> fn) {
-    return new MapElements<>(
-        Contextful.fn(fn), fn, fn.getInputTypeDescriptor(), fn.getOutputTypeDescriptor());
+    return new MapElements<>(fn, fn.getInputTypeDescriptor(), fn.getOutputTypeDescriptor());
   }
 
   /** Binary compatibility adapter for {@link #via(InferableFunction)}. */
@@ -95,7 +93,7 @@ public class MapElements<InputT, OutputT>
    * but the mapping function yet to be specified using {@link #via(ProcessFunction)}.
    */
   public static <OutputT> MapElements<?, OutputT> into(final TypeDescriptor<OutputT> outputType) {
-    return new MapElements<>(null, null, null, outputType);
+    return new MapElements<>(null, null, outputType);
   }
 
   /**
@@ -112,7 +110,7 @@ public class MapElements<InputT, OutputT>
    * }</pre>
    */
   public <NewInputT> MapElements<NewInputT, OutputT> via(ProcessFunction<NewInputT, OutputT> fn) {
-    return new MapElements<>(Contextful.fn(fn), fn, TypeDescriptors.inputOf(fn), outputType);
+    return new MapElements<>(fn, TypeDescriptors.inputOf(fn), outputType);
   }
 
   /** Binary compatibility adapter for {@link #via(ProcessFunction)}. */
@@ -124,56 +122,81 @@ public class MapElements<InputT, OutputT>
   /** Like {@link #via(ProcessFunction)}, but supports access to context, such as side inputs. */
   @Experimental(Kind.CONTEXTFUL)
   public <NewInputT> MapElements<NewInputT, OutputT> via(Contextful<Fn<NewInputT, OutputT>> fn) {
-    return new MapElements<>(
-        fn, fn.getClosure(), TypeDescriptors.inputOf(fn.getClosure()), outputType);
+    return new MapElements<>(fn, TypeDescriptors.inputOf(fn.getClosure()), outputType);
   }
 
   @Override
   public PCollection<OutputT> expand(PCollection<? extends InputT> input) {
     checkNotNull(fn, "Must specify a function on MapElements using .via()");
-    return input.apply(
-        "Map",
-        ParDo.of(
-                new DoFn<InputT, OutputT>() {
-                  @ProcessElement
-                  public void processElement(
-                      @Element InputT element, OutputReceiver<OutputT> receiver, ProcessContext c)
-                      throws Exception {
-                    receiver.output(
-                        fn.getClosure().apply(element, Fn.Context.wrapProcessContext(c)));
-                  }
+    if (fn instanceof Contextful) {
+      return input.apply(
+          "Map",
+          ParDo.of(
+                  new MapDoFn() {
+                    @ProcessElement
+                    public void processElement(ProcessContext c) throws Exception {
+                      c.output(
+                          ((Contextful<Fn<InputT, OutputT>>) fn)
+                              .getClosure()
+                              .apply(c.element(), Fn.Context.wrapProcessContext(c)));
+                    }
+                  })
+              .withSideInputs(
+                  ((Contextful<Fn<InputT, OutputT>>) fn).getRequirements().getSideInputs()));
+    } else if (fn instanceof ProcessFunction) {
+      return input.apply(
+          "Map",
+          ParDo.of(
+              new MapDoFn() {
+                @ProcessElement
+                public void processElement(
+                    @Element InputT element, OutputReceiver<OutputT> receiver) throws Exception {
+                  receiver.output(((ProcessFunction<InputT, OutputT>) fn).apply(element));
+                }
+              }));
+    } else {
+      throw new IllegalArgumentException(
+          String.format("Unknown type of fn class %s", fn.getClass()));
+    }
+  }
 
-                  @Override
-                  public void populateDisplayData(DisplayData.Builder builder) {
-                    builder.delegate(MapElements.this);
-                  }
+  /** A DoFn implementation that handles a trivial map call. */
+  private abstract class MapDoFn extends DoFn<InputT, OutputT> {
+    @Override
+    public void populateDisplayData(DisplayData.Builder builder) {
+      builder.delegate(MapElements.this);
+    }
 
-                  @Override
-                  public TypeDescriptor<InputT> getInputTypeDescriptor() {
-                    return inputType;
-                  }
+    @Override
+    public TypeDescriptor<InputT> getInputTypeDescriptor() {
+      return inputType;
+    }
 
-                  @Override
-                  public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
-                    checkState(
-                        outputType != null,
-                        "%s output type descriptor was null; "
-                            + "this probably means that getOutputTypeDescriptor() was called after "
-                            + "serialization/deserialization, but it is only available prior to "
-                            + "serialization, for constructing a pipeline and inferring coders",
-                        MapElements.class.getSimpleName());
-                    return outputType;
-                  }
-                })
-            .withSideInputs(fn.getRequirements().getSideInputs()));
+    @Override
+    public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
+      checkState(
+          outputType != null,
+          "%s output type descriptor was null; "
+              + "this probably means that getOutputTypeDescriptor() was called after "
+              + "serialization/deserialization, but it is only available prior to "
+              + "serialization, for constructing a pipeline and inferring coders",
+          MapElements.class.getSimpleName());
+      return outputType;
+    }
   }
 
   @Override
   public void populateDisplayData(DisplayData.Builder builder) {
     super.populateDisplayData(builder);
-    builder.add(DisplayData.item("class", originalFnForDisplayData.getClass()));
-    if (originalFnForDisplayData instanceof HasDisplayData) {
-      builder.include("fn", (HasDisplayData) originalFnForDisplayData);
+    Object fnForDisplayData;
+    if (fn instanceof Contextful) {
+      fnForDisplayData = ((Contextful<?>) fn).getClosure();
+    } else {
+      fnForDisplayData = fn;
+    }
+    builder.add(DisplayData.item("class", fnForDisplayData.getClass()));
+    if (fnForDisplayData instanceof HasDisplayData) {
+      builder.include("fn", (HasDisplayData) fnForDisplayData);
     }
   }
 
@@ -188,8 +211,7 @@ public class MapElements<InputT, OutputT>
   @Experimental(Kind.WITH_EXCEPTIONS)
   public <NewFailureT> MapWithFailures<InputT, OutputT, NewFailureT> exceptionsInto(
       TypeDescriptor<NewFailureT> failureTypeDescriptor) {
-    return new MapWithFailures<>(
-        fn, originalFnForDisplayData, inputType, outputType, null, failureTypeDescriptor);
+    return new MapWithFailures<>(fn, inputType, outputType, null, failureTypeDescriptor);
   }
 
   /**
@@ -219,12 +241,7 @@ public class MapElements<InputT, OutputT>
   public <FailureT> MapWithFailures<InputT, OutputT, FailureT> exceptionsVia(
       InferableFunction<ExceptionElement<InputT>, FailureT> exceptionHandler) {
     return new MapWithFailures<>(
-        fn,
-        originalFnForDisplayData,
-        inputType,
-        outputType,
-        exceptionHandler,
-        exceptionHandler.getOutputTypeDescriptor());
+        fn, inputType, outputType, exceptionHandler, exceptionHandler.getOutputTypeDescriptor());
   }
 
   /** A {@code PTransform} that adds exception handling to {@link MapElements}. */
@@ -235,19 +252,16 @@ public class MapElements<InputT, OutputT>
     private final transient TypeDescriptor<InputT> inputType;
     private final transient TypeDescriptor<OutputT> outputType;
     private final transient @Nullable TypeDescriptor<FailureT> failureType;
-    private final transient Object originalFnForDisplayData;
-    private final Contextful<Fn<InputT, OutputT>> fn;
+    private final Object fn;
     private final @Nullable ProcessFunction<ExceptionElement<InputT>, FailureT> exceptionHandler;
 
     MapWithFailures(
-        Contextful<Fn<InputT, OutputT>> fn,
-        Object originalFnForDisplayData,
+        Object fn,
         TypeDescriptor<InputT> inputType,
         TypeDescriptor<OutputT> outputType,
         @Nullable ProcessFunction<ExceptionElement<InputT>, FailureT> exceptionHandler,
         @Nullable TypeDescriptor<FailureT> failureType) {
       this.fn = fn;
-      this.originalFnForDisplayData = originalFnForDisplayData;
       this.inputType = inputType;
       this.outputType = outputType;
       this.exceptionHandler = exceptionHandler;
@@ -274,29 +288,98 @@ public class MapElements<InputT, OutputT>
      */
     public MapWithFailures<InputT, OutputT, FailureT> exceptionsVia(
         ProcessFunction<ExceptionElement<InputT>, FailureT> exceptionHandler) {
-      return new MapWithFailures<>(
-          fn, originalFnForDisplayData, inputType, outputType, exceptionHandler, failureType);
+      return new MapWithFailures<>(fn, inputType, outputType, exceptionHandler, failureType);
     }
 
     @Override
     public WithFailures.Result<PCollection<OutputT>, FailureT> expand(PCollection<InputT> input) {
       checkArgument(exceptionHandler != null, ".exceptionsVia() is required");
-      MapFn doFn = new MapFn();
-      PCollectionTuple tuple =
-          input.apply(
-              MapWithFailures.class.getSimpleName(),
-              ParDo.of(doFn)
-                  .withOutputTags(doFn.outputTag, TupleTagList.of(doFn.failureTag))
-                  .withSideInputs(this.fn.getRequirements().getSideInputs()));
+      MapWithFailuresDoFn doFn;
+      PCollectionTuple tuple;
+      if (fn instanceof Contextful) {
+        doFn =
+            new MapWithFailuresDoFn() {
+              @ProcessElement
+              public void processElement(@Element InputT element, ProcessContext c)
+                  throws Exception {
+                boolean exceptionWasThrown = false;
+                OutputT result = null;
+                try {
+                  result =
+                      ((Contextful<Fn<InputT, OutputT>>) fn)
+                          .getClosure()
+                          .apply(element, Fn.Context.wrapProcessContext(c));
+                } catch (Exception e) {
+                  exceptionWasThrown = true;
+                  ExceptionElement<InputT> exceptionElement = ExceptionElement.of(element, e);
+                  c.output(failureTag, exceptionHandler.apply(exceptionElement));
+                }
+                // We make sure our output occurs outside the try block, since runners may implement
+                // fusion by having output() directly call the body of another DoFn, potentially
+                // catching
+                // exceptions unrelated to this transform.
+                if (!exceptionWasThrown) {
+                  c.output(result);
+                }
+              }
+            };
+        tuple =
+            input.apply(
+                MapWithFailures.class.getSimpleName(),
+                ParDo.of(doFn)
+                    .withOutputTags(doFn.outputTag, TupleTagList.of(doFn.failureTag))
+                    .withSideInputs(
+                        ((Contextful<Fn<InputT, OutputT>>) fn).getRequirements().getSideInputs()));
+
+      } else if (fn instanceof ProcessFunction) {
+        ProcessFunction<InputT, OutputT> closure = (ProcessFunction<InputT, OutputT>) fn;
+        doFn =
+            new MapWithFailuresDoFn() {
+              @ProcessElement
+              public void processElement(@Element InputT element, ProcessContext c)
+                  throws Exception {
+                boolean exceptionWasThrown = false;
+                OutputT result;
+                try {
+                  result = closure.apply(element);
+                } catch (Exception e) {
+                  result = null;
+                  exceptionWasThrown = true;
+                  ExceptionElement<InputT> exceptionElement = ExceptionElement.of(element, e);
+                  c.output(failureTag, exceptionHandler.apply(exceptionElement));
+                }
+                // We make sure our output occurs outside the try block, since runners may implement
+                // fusion by having output() directly call the body of another DoFn, potentially
+                // catching
+                // exceptions unrelated to this transform.
+                if (!exceptionWasThrown) {
+                  c.output(result);
+                }
+              }
+            };
+        tuple =
+            input.apply(
+                MapWithFailures.class.getSimpleName(),
+                ParDo.of(doFn).withOutputTags(doFn.outputTag, TupleTagList.of(doFn.failureTag)));
+      } else {
+        throw new IllegalArgumentException(
+            String.format("Unknown type of fn class %s", fn.getClass()));
+      }
       return WithFailures.Result.of(tuple, doFn.outputTag, doFn.failureTag);
     }
 
     @Override
     public void populateDisplayData(DisplayData.Builder builder) {
       super.populateDisplayData(builder);
-      builder.add(DisplayData.item("class", originalFnForDisplayData.getClass()));
-      if (originalFnForDisplayData instanceof HasDisplayData) {
-        builder.include("fn", (HasDisplayData) originalFnForDisplayData);
+      Object fnForDisplayData;
+      if (fn instanceof Contextful) {
+        fnForDisplayData = ((Contextful<?>) fn).getClosure();
+      } else {
+        fnForDisplayData = fn;
+      }
+      builder.add(DisplayData.item("class", fnForDisplayData.getClass()));
+      if (fnForDisplayData instanceof HasDisplayData) {
+        builder.include("fn", (HasDisplayData) fnForDisplayData);
       }
       builder.add(DisplayData.item("exceptionHandler.class", exceptionHandler.getClass()));
       if (exceptionHandler instanceof HasDisplayData) {
@@ -313,31 +396,10 @@ public class MapElements<InputT, OutputT>
     }
 
     /** A DoFn implementation that handles exceptions and outputs a secondary failure collection. */
-    private class MapFn extends DoFn<InputT, OutputT> {
-
+    private abstract class MapWithFailuresDoFn extends DoFn<InputT, OutputT> {
       final TupleTag<OutputT> outputTag = new TupleTag<OutputT>() {};
       final TupleTag<FailureT> failureTag = new FailureTag();
 
-      @ProcessElement
-      public void processElement(@Element InputT element, MultiOutputReceiver r, ProcessContext c)
-          throws Exception {
-        boolean exceptionWasThrown = false;
-        OutputT result = null;
-        try {
-          result = fn.getClosure().apply(c.element(), Fn.Context.wrapProcessContext(c));
-        } catch (Exception e) {
-          exceptionWasThrown = true;
-          ExceptionElement<InputT> exceptionElement = ExceptionElement.of(element, e);
-          r.get(failureTag).output(exceptionHandler.apply(exceptionElement));
-        }
-        // We make sure our output occurs outside the try block, since runners may implement
-        // fusion by having output() directly call the body of another DoFn, potentially catching
-        // exceptions unrelated to this transform.
-        if (!exceptionWasThrown) {
-          r.get(outputTag).output(result);
-        }
-      }
-
       @Override
       public void populateDisplayData(DisplayData.Builder builder) {
         builder.delegate(MapWithFailures.this);

