This is an automated email from the ASF dual-hosted git repository. lcwik pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new c0de849c1ce [#24024] Stop wrapping light weight functions with Contextful as they add a lot of overhead for functions that are meant to do almost no work. (#24025) c0de849c1ce is described below commit c0de849c1ce9b56e70a99863ff6f9d57fbbd7c0c Author: Luke Cwik <lc...@google.com> AuthorDate: Thu Dec 8 12:48:49 2022 -0800 [#24024] Stop wrapping light weight functions with Contextful as they add a lot of overhead for functions that are meant to do almost no work. (#24025) * Stop wrapping light weight functions with Contextful as they add a lot of overhead for functions that are meant to do almost no work. Fixes #24024 * Fix Spark test expectations --- .../runners/spark/SparkRunnerDebuggerTest.java | 11 +- .../beam/sdk/transforms/FlatMapElements.java | 257 +++++++++++++-------- .../apache/beam/sdk/transforms/MapElements.java | 250 ++++++++++++-------- 3 files changed, 325 insertions(+), 193 deletions(-) diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java index 843d303bda8..157b3cb946e 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java @@ -86,14 +86,15 @@ public class SparkRunnerDebuggerTest { "sparkContext.<readFrom(org.apache.beam.sdk.transforms.Create$Values$CreateSource)>()\n" + "_.mapPartitions(" + "new org.apache.beam.runners.spark.examples.WordCount$ExtractWordsFn())\n" - + "_.mapPartitions(new org.apache.beam.sdk.transforms.Contextful())\n" + + "_.mapPartitions(new org.apache.beam.sdk.transforms.Count$PerElement$1())\n" + "_.combineByKey(..., new org.apache.beam.sdk.transforms.Count$CountFn(), ...)\n" + "_.groupByKey()\n" + "_.map(new org.apache.beam.sdk.transforms.Sum$SumLongFn())\n" - + "_.mapPartitions(new org.apache.beam.sdk.transforms.Contextful())\n" + + "_.mapPartitions(" + + "new org.apache.beam.runners.spark.SparkRunnerDebuggerTest$PlusOne())\n" + "sparkContext.union(...)\n" + "_.mapPartitions(" - + "new org.apache.beam.sdk.transforms.Contextful())\n" + + "new org.apache.beam.runners.spark.examples.WordCount$FormatAsTextFn())\n" + "_.<org.apache.beam.sdk.io.TextIO$Write>"; SparkRunnerDebugger.DebugSparkPipelineResult result = @@ -142,11 +143,11 @@ public class SparkRunnerDebuggerTest { + "_.map(new org.apache.beam.sdk.transforms.windowing.FixedWindows())\n" + "_.mapPartitions(new org.apache.beam.runners.spark." + "SparkRunnerDebuggerTest$FormatKVFn())\n" - + "_.mapPartitions(new org.apache.beam.sdk.transforms.Contextful())\n" + + "_.mapPartitions(new org.apache.beam.sdk.transforms.Distinct$1())\n" + "_.groupByKey()\n" + "_.map(new org.apache.beam.sdk.transforms.Combine$IterableCombineFn())\n" + "_.mapPartitions(new org.apache.beam.sdk.transforms.Distinct$3())\n" - + "_.mapPartitions(new org.apache.beam.sdk.transforms.Contextful())\n" + + "_.mapPartitions(new org.apache.beam.sdk.transforms.WithKeys$1())\n" + "_.<org.apache.beam.sdk.io.kafka.AutoValue_KafkaIO_Write>"; SparkRunnerDebugger.DebugSparkPipelineResult result = diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java index 19e6f6465b5..9a87aba766f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java @@ -45,16 +45,13 @@ public class FlatMapElements<InputT, OutputT> extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> { private final transient @Nullable TypeDescriptor<InputT> inputType; private final transient @Nullable TypeDescriptor<OutputT> outputType; - private final transient @Nullable Object originalFnForDisplayData; - private final @Nullable Contextful<Fn<InputT, Iterable<OutputT>>> fn; + private final @Nullable Object fn; private FlatMapElements( - @Nullable Contextful<Fn<InputT, Iterable<OutputT>>> fn, - @Nullable Object originalFnForDisplayData, + @Nullable Object fn, @Nullable TypeDescriptor<InputT> inputType, TypeDescriptor<OutputT> outputType) { this.fn = fn; - this.originalFnForDisplayData = originalFnForDisplayData; this.inputType = inputType; this.outputType = outputType; } @@ -83,14 +80,13 @@ public class FlatMapElements<InputT, OutputT> */ public static <InputT, OutputT> FlatMapElements<InputT, OutputT> via( InferableFunction<? super InputT, ? extends Iterable<OutputT>> fn) { - Contextful<Fn<InputT, Iterable<OutputT>>> wrapped = (Contextful) Contextful.fn(fn); TypeDescriptor<OutputT> outputType = TypeDescriptors.extractFromTypeParameters( (TypeDescriptor<Iterable<OutputT>>) fn.getOutputTypeDescriptor(), Iterable.class, new TypeDescriptors.TypeVariableExtractor<Iterable<OutputT>, OutputT>() {}); TypeDescriptor<InputT> inputType = (TypeDescriptor<InputT>) fn.getInputTypeDescriptor(); - return new FlatMapElements<>(wrapped, fn, inputType, outputType); + return new FlatMapElements<>(fn, inputType, outputType); } /** Binary compatibility adapter for {@link #via(ProcessFunction)}. */ @@ -105,7 +101,7 @@ public class FlatMapElements<InputT, OutputT> */ public static <OutputT> FlatMapElements<?, OutputT> into( final TypeDescriptor<OutputT> outputType) { - return new FlatMapElements<>(null, null, null, outputType); + return new FlatMapElements<>(null, null, outputType); } /** @@ -123,8 +119,7 @@ public class FlatMapElements<InputT, OutputT> */ public <NewInputT> FlatMapElements<NewInputT, OutputT> via( ProcessFunction<NewInputT, ? extends Iterable<OutputT>> fn) { - return new FlatMapElements<>( - (Contextful) Contextful.fn(fn), fn, TypeDescriptors.inputOf(fn), outputType); + return new FlatMapElements<>(fn, TypeDescriptors.inputOf(fn), outputType); } /** Binary compatibility adapter for {@link #via(ProcessFunction)}. */ @@ -137,57 +132,90 @@ public class FlatMapElements<InputT, OutputT> @Experimental(Kind.CONTEXTFUL) public <NewInputT> FlatMapElements<NewInputT, OutputT> via( Contextful<Fn<NewInputT, Iterable<OutputT>>> fn) { - return new FlatMapElements<>( - fn, fn.getClosure(), TypeDescriptors.inputOf(fn.getClosure()), outputType); + return new FlatMapElements<>(fn, TypeDescriptors.inputOf(fn.getClosure()), outputType); } @Override public PCollection<OutputT> expand(PCollection<? extends InputT> input) { checkArgument(fn != null, ".via() is required"); - return input.apply( - "FlatMap", - ParDo.of( - new DoFn<InputT, OutputT>() { - @ProcessElement - public void processElement(ProcessContext c) throws Exception { - Iterable<OutputT> res = - fn.getClosure().apply(c.element(), Fn.Context.wrapProcessContext(c)); - for (OutputT output : res) { - c.output(output); + if (fn instanceof Contextful) { + return input.apply( + "FlatMap", + ParDo.of( + new FlatMapDoFn() { + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + Iterable<OutputT> res = + ((Contextful<Fn<InputT, Iterable<OutputT>>>) fn) + .getClosure() + .apply(c.element(), Fn.Context.wrapProcessContext(c)); + for (OutputT output : res) { + c.output(output); + } } + }) + .withSideInputs( + ((Contextful<Fn<InputT, Iterable<OutputT>>>) fn) + .getRequirements() + .getSideInputs())); + } else if (fn instanceof ProcessFunction) { + return input.apply( + "FlatMap", + ParDo.of( + new FlatMapDoFn() { + @ProcessElement + public void processElement( + @Element InputT element, OutputReceiver<OutputT> receiver) throws Exception { + Iterable<OutputT> res = + ((ProcessFunction<InputT, Iterable<OutputT>>) fn).apply(element); + for (OutputT output : res) { + receiver.output(output); } + } + })); + } else { + throw new IllegalArgumentException( + String.format("Unknown type of fn class %s", fn.getClass())); + } + } - @Override - public TypeDescriptor<InputT> getInputTypeDescriptor() { - return inputType; - } + private abstract class FlatMapDoFn extends DoFn<InputT, OutputT> { - @Override - public TypeDescriptor<OutputT> getOutputTypeDescriptor() { - checkState( - outputType != null, - "%s output type descriptor was null; " - + "this probably means that getOutputTypeDescriptor() was called after " - + "serialization/deserialization, but it is only available prior to " - + "serialization, for constructing a pipeline and inferring coders", - FlatMapElements.class.getSimpleName()); - return outputType; - } + @Override + public TypeDescriptor<InputT> getInputTypeDescriptor() { + return inputType; + } - @Override - public void populateDisplayData(DisplayData.Builder builder) { - builder.delegate(FlatMapElements.this); - } - }) - .withSideInputs(fn.getRequirements().getSideInputs())); + @Override + public TypeDescriptor<OutputT> getOutputTypeDescriptor() { + checkState( + outputType != null, + "%s output type descriptor was null; " + + "this probably means that getOutputTypeDescriptor() was called after " + + "serialization/deserialization, but it is only available prior to " + + "serialization, for constructing a pipeline and inferring coders", + FlatMapElements.class.getSimpleName()); + return outputType; + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.delegate(FlatMapElements.this); + } } @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); - builder.add(DisplayData.item("class", originalFnForDisplayData.getClass())); - if (originalFnForDisplayData instanceof HasDisplayData) { - builder.include("fn", (HasDisplayData) originalFnForDisplayData); + Object fnForDisplayData; + if (fn instanceof Contextful) { + fnForDisplayData = ((Contextful<Fn<InputT, OutputT>>) fn).getClosure(); + } else { + fnForDisplayData = fn; + } + builder.add(DisplayData.item("class", fnForDisplayData.getClass())); + if (fnForDisplayData instanceof HasDisplayData) { + builder.include("fn", (HasDisplayData) fnForDisplayData); } } @@ -203,8 +231,7 @@ public class FlatMapElements<InputT, OutputT> @Experimental(Kind.WITH_EXCEPTIONS) public <NewFailureT> FlatMapWithFailures<InputT, OutputT, NewFailureT> exceptionsInto( TypeDescriptor<NewFailureT> failureTypeDescriptor) { - return new FlatMapWithFailures<>( - fn, originalFnForDisplayData, inputType, outputType, null, failureTypeDescriptor); + return new FlatMapWithFailures<>(fn, inputType, outputType, null, failureTypeDescriptor); } /** @@ -235,12 +262,7 @@ public class FlatMapElements<InputT, OutputT> public <FailureT> FlatMapWithFailures<InputT, OutputT, FailureT> exceptionsVia( InferableFunction<ExceptionElement<InputT>, FailureT> exceptionHandler) { return new FlatMapWithFailures<>( - fn, - originalFnForDisplayData, - inputType, - outputType, - exceptionHandler, - exceptionHandler.getOutputTypeDescriptor()); + fn, inputType, outputType, exceptionHandler, exceptionHandler.getOutputTypeDescriptor()); } /** A {@code PTransform} that adds exception handling to {@link FlatMapElements}. */ @@ -251,19 +273,16 @@ public class FlatMapElements<InputT, OutputT> private final transient TypeDescriptor<InputT> inputType; private final transient TypeDescriptor<OutputT> outputType; private final transient @Nullable TypeDescriptor<FailureT> failureType; - private final transient Object originalFnForDisplayData; - private final @Nullable Contextful<Fn<InputT, Iterable<OutputT>>> fn; + private final @Nullable Object fn; private final @Nullable ProcessFunction<ExceptionElement<InputT>, FailureT> exceptionHandler; FlatMapWithFailures( - @Nullable Contextful<Fn<InputT, Iterable<OutputT>>> fn, - Object originalFnForDisplayData, + Object fn, TypeDescriptor<InputT> inputType, TypeDescriptor<OutputT> outputType, @Nullable ProcessFunction<ExceptionElement<InputT>, FailureT> exceptionHandler, @Nullable TypeDescriptor<FailureT> failureType) { this.fn = fn; - this.originalFnForDisplayData = originalFnForDisplayData; this.inputType = inputType; this.outputType = outputType; this.exceptionHandler = exceptionHandler; @@ -291,29 +310,101 @@ public class FlatMapElements<InputT, OutputT> */ public FlatMapWithFailures<InputT, OutputT, FailureT> exceptionsVia( ProcessFunction<ExceptionElement<InputT>, FailureT> exceptionHandler) { - return new FlatMapWithFailures<>( - fn, originalFnForDisplayData, inputType, outputType, exceptionHandler, failureType); + return new FlatMapWithFailures<>(fn, inputType, outputType, exceptionHandler, failureType); } @Override public WithFailures.Result<PCollection<OutputT>, FailureT> expand(PCollection<InputT> input) { checkArgument(exceptionHandler != null, ".exceptionsVia() is required"); - MapFn doFn = new MapFn(); - PCollectionTuple tuple = - input.apply( - FlatMapWithFailures.class.getSimpleName(), - ParDo.of(doFn) - .withOutputTags(doFn.outputTag, TupleTagList.of(doFn.failureTag)) - .withSideInputs(this.fn.getRequirements().getSideInputs())); + MapWithFailuresDoFn doFn; + PCollectionTuple tuple; + if (fn instanceof Contextful) { + doFn = + new MapWithFailuresDoFn() { + @ProcessElement + public void processElement(@Element InputT element, ProcessContext c) + throws Exception { + boolean exceptionWasThrown = false; + Iterable<OutputT> res = null; + try { + res = + ((Contextful<Fn<InputT, Iterable<OutputT>>>) fn) + .getClosure() + .apply(element, Fn.Context.wrapProcessContext(c)); + } catch (Exception e) { + exceptionWasThrown = true; + ExceptionElement<InputT> exceptionElement = ExceptionElement.of(element, e); + c.output(failureTag, exceptionHandler.apply(exceptionElement)); + } + // We make sure our outputs occur outside the try block, since runners may implement + // fusion by having output() directly call the body of another DoFn, potentially + // catching + // exceptions unrelated to this transform. + if (!exceptionWasThrown) { + for (OutputT output : res) { + c.output(output); + } + } + } + }; + tuple = + input.apply( + FlatMapWithFailures.class.getSimpleName(), + ParDo.of(doFn) + .withOutputTags(doFn.outputTag, TupleTagList.of(doFn.failureTag)) + .withSideInputs( + ((Contextful<Fn<InputT, Iterable<OutputT>>>) fn) + .getRequirements() + .getSideInputs())); + } else if (fn instanceof ProcessFunction) { + doFn = + new MapWithFailuresDoFn() { + @ProcessElement + public void processElement(@Element InputT element, ProcessContext c) + throws Exception { + boolean exceptionWasThrown = false; + Iterable<OutputT> res = null; + try { + res = ((ProcessFunction<InputT, Iterable<OutputT>>) fn).apply(element); + } catch (Exception e) { + exceptionWasThrown = true; + ExceptionElement<InputT> exceptionElement = ExceptionElement.of(element, e); + c.output(failureTag, exceptionHandler.apply(exceptionElement)); + } + // We make sure our outputs occur outside the try block, since runners may implement + // fusion by having output() directly call the body of another DoFn, potentially + // catching + // exceptions unrelated to this transform. + if (!exceptionWasThrown) { + for (OutputT output : res) { + c.output(output); + } + } + } + }; + tuple = + input.apply( + FlatMapWithFailures.class.getSimpleName(), + ParDo.of(doFn).withOutputTags(doFn.outputTag, TupleTagList.of(doFn.failureTag))); + } else { + throw new IllegalArgumentException( + String.format("Unknown type of fn class %s", fn.getClass())); + } return WithFailures.Result.of(tuple, doFn.outputTag, doFn.failureTag); } @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); - builder.add(DisplayData.item("class", originalFnForDisplayData.getClass())); - if (originalFnForDisplayData instanceof HasDisplayData) { - builder.include("fn", (HasDisplayData) originalFnForDisplayData); + Object fnForDisplayData; + if (fn instanceof Contextful) { + fnForDisplayData = ((Contextful<?>) fn).getClosure(); + } else { + fnForDisplayData = fn; + } + builder.add(DisplayData.item("class", fnForDisplayData.getClass())); + if (fnForDisplayData instanceof HasDisplayData) { + builder.include("fn", (HasDisplayData) fnForDisplayData); } builder.add(DisplayData.item("exceptionHandler.class", exceptionHandler.getClass())); if (exceptionHandler instanceof HasDisplayData) { @@ -330,33 +421,11 @@ public class FlatMapElements<InputT, OutputT> } /** A DoFn implementation that handles exceptions and outputs a secondary failure collection. */ - private class MapFn extends DoFn<InputT, OutputT> { + private abstract class MapWithFailuresDoFn extends DoFn<InputT, OutputT> { final TupleTag<OutputT> outputTag = new TupleTag<OutputT>() {}; final TupleTag<FailureT> failureTag = new FailureTag(); - @ProcessElement - public void processElement(@Element InputT element, MultiOutputReceiver r, ProcessContext c) - throws Exception { - boolean exceptionWasThrown = false; - Iterable<OutputT> res = null; - try { - res = fn.getClosure().apply(c.element(), Fn.Context.wrapProcessContext(c)); - } catch (Exception e) { - exceptionWasThrown = true; - ExceptionElement<InputT> exceptionElement = ExceptionElement.of(element, e); - r.get(failureTag).output(exceptionHandler.apply(exceptionElement)); - } - // We make sure our outputs occur outside the try block, since runners may implement - // fusion by having output() directly call the body of another DoFn, potentially catching - // exceptions unrelated to this transform. - if (!exceptionWasThrown) { - for (OutputT output : res) { - r.get(outputTag).output(output); - } - } - } - @Override public void populateDisplayData(DisplayData.Builder builder) { builder.delegate(FlatMapWithFailures.this); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java index e7de5a14754..2159ff44367 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java @@ -41,18 +41,17 @@ import org.checkerframework.checker.nullness.qual.Nullable; }) public class MapElements<InputT, OutputT> extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> { + private final transient @Nullable TypeDescriptor<InputT> inputType; - private final transient @Nullable TypeDescriptor<OutputT> outputType; - private final transient @Nullable Object originalFnForDisplayData; - private final @Nullable Contextful<Fn<InputT, OutputT>> fn; + private final transient TypeDescriptor<OutputT> outputType; + + private final @Nullable Object fn; private MapElements( - @Nullable Contextful<Fn<InputT, OutputT>> fn, - @Nullable Object originalFnForDisplayData, + @Nullable Object fn, @Nullable TypeDescriptor<InputT> inputType, TypeDescriptor<OutputT> outputType) { this.fn = fn; - this.originalFnForDisplayData = originalFnForDisplayData; this.inputType = inputType; this.outputType = outputType; } @@ -80,8 +79,7 @@ public class MapElements<InputT, OutputT> */ public static <InputT, OutputT> MapElements<InputT, OutputT> via( final InferableFunction<InputT, OutputT> fn) { - return new MapElements<>( - Contextful.fn(fn), fn, fn.getInputTypeDescriptor(), fn.getOutputTypeDescriptor()); + return new MapElements<>(fn, fn.getInputTypeDescriptor(), fn.getOutputTypeDescriptor()); } /** Binary compatibility adapter for {@link #via(InferableFunction)}. */ @@ -95,7 +93,7 @@ public class MapElements<InputT, OutputT> * but the mapping function yet to be specified using {@link #via(ProcessFunction)}. */ public static <OutputT> MapElements<?, OutputT> into(final TypeDescriptor<OutputT> outputType) { - return new MapElements<>(null, null, null, outputType); + return new MapElements<>(null, null, outputType); } /** @@ -112,7 +110,7 @@ public class MapElements<InputT, OutputT> * }</pre> */ public <NewInputT> MapElements<NewInputT, OutputT> via(ProcessFunction<NewInputT, OutputT> fn) { - return new MapElements<>(Contextful.fn(fn), fn, TypeDescriptors.inputOf(fn), outputType); + return new MapElements<>(fn, TypeDescriptors.inputOf(fn), outputType); } /** Binary compatibility adapter for {@link #via(ProcessFunction)}. */ @@ -124,56 +122,81 @@ public class MapElements<InputT, OutputT> /** Like {@link #via(ProcessFunction)}, but supports access to context, such as side inputs. */ @Experimental(Kind.CONTEXTFUL) public <NewInputT> MapElements<NewInputT, OutputT> via(Contextful<Fn<NewInputT, OutputT>> fn) { - return new MapElements<>( - fn, fn.getClosure(), TypeDescriptors.inputOf(fn.getClosure()), outputType); + return new MapElements<>(fn, TypeDescriptors.inputOf(fn.getClosure()), outputType); } @Override public PCollection<OutputT> expand(PCollection<? extends InputT> input) { checkNotNull(fn, "Must specify a function on MapElements using .via()"); - return input.apply( - "Map", - ParDo.of( - new DoFn<InputT, OutputT>() { - @ProcessElement - public void processElement( - @Element InputT element, OutputReceiver<OutputT> receiver, ProcessContext c) - throws Exception { - receiver.output( - fn.getClosure().apply(element, Fn.Context.wrapProcessContext(c))); - } + if (fn instanceof Contextful) { + return input.apply( + "Map", + ParDo.of( + new MapDoFn() { + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + c.output( + ((Contextful<Fn<InputT, OutputT>>) fn) + .getClosure() + .apply(c.element(), Fn.Context.wrapProcessContext(c))); + } + }) + .withSideInputs( + ((Contextful<Fn<InputT, OutputT>>) fn).getRequirements().getSideInputs())); + } else if (fn instanceof ProcessFunction) { + return input.apply( + "Map", + ParDo.of( + new MapDoFn() { + @ProcessElement + public void processElement( + @Element InputT element, OutputReceiver<OutputT> receiver) throws Exception { + receiver.output(((ProcessFunction<InputT, OutputT>) fn).apply(element)); + } + })); + } else { + throw new IllegalArgumentException( + String.format("Unknown type of fn class %s", fn.getClass())); + } + } - @Override - public void populateDisplayData(DisplayData.Builder builder) { - builder.delegate(MapElements.this); - } + /** A DoFn implementation that handles a trivial map call. */ + private abstract class MapDoFn extends DoFn<InputT, OutputT> { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.delegate(MapElements.this); + } - @Override - public TypeDescriptor<InputT> getInputTypeDescriptor() { - return inputType; - } + @Override + public TypeDescriptor<InputT> getInputTypeDescriptor() { + return inputType; + } - @Override - public TypeDescriptor<OutputT> getOutputTypeDescriptor() { - checkState( - outputType != null, - "%s output type descriptor was null; " - + "this probably means that getOutputTypeDescriptor() was called after " - + "serialization/deserialization, but it is only available prior to " - + "serialization, for constructing a pipeline and inferring coders", - MapElements.class.getSimpleName()); - return outputType; - } - }) - .withSideInputs(fn.getRequirements().getSideInputs())); + @Override + public TypeDescriptor<OutputT> getOutputTypeDescriptor() { + checkState( + outputType != null, + "%s output type descriptor was null; " + + "this probably means that getOutputTypeDescriptor() was called after " + + "serialization/deserialization, but it is only available prior to " + + "serialization, for constructing a pipeline and inferring coders", + MapElements.class.getSimpleName()); + return outputType; + } } @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); - builder.add(DisplayData.item("class", originalFnForDisplayData.getClass())); - if (originalFnForDisplayData instanceof HasDisplayData) { - builder.include("fn", (HasDisplayData) originalFnForDisplayData); + Object fnForDisplayData; + if (fn instanceof Contextful) { + fnForDisplayData = ((Contextful<?>) fn).getClosure(); + } else { + fnForDisplayData = fn; + } + builder.add(DisplayData.item("class", fnForDisplayData.getClass())); + if (fnForDisplayData instanceof HasDisplayData) { + builder.include("fn", (HasDisplayData) fnForDisplayData); } } @@ -188,8 +211,7 @@ public class MapElements<InputT, OutputT> @Experimental(Kind.WITH_EXCEPTIONS) public <NewFailureT> MapWithFailures<InputT, OutputT, NewFailureT> exceptionsInto( TypeDescriptor<NewFailureT> failureTypeDescriptor) { - return new MapWithFailures<>( - fn, originalFnForDisplayData, inputType, outputType, null, failureTypeDescriptor); + return new MapWithFailures<>(fn, inputType, outputType, null, failureTypeDescriptor); } /** @@ -219,12 +241,7 @@ public class MapElements<InputT, OutputT> public <FailureT> MapWithFailures<InputT, OutputT, FailureT> exceptionsVia( InferableFunction<ExceptionElement<InputT>, FailureT> exceptionHandler) { return new MapWithFailures<>( - fn, - originalFnForDisplayData, - inputType, - outputType, - exceptionHandler, - exceptionHandler.getOutputTypeDescriptor()); + fn, inputType, outputType, exceptionHandler, exceptionHandler.getOutputTypeDescriptor()); } /** A {@code PTransform} that adds exception handling to {@link MapElements}. */ @@ -235,19 +252,16 @@ public class MapElements<InputT, OutputT> private final transient TypeDescriptor<InputT> inputType; private final transient TypeDescriptor<OutputT> outputType; private final transient @Nullable TypeDescriptor<FailureT> failureType; - private final transient Object originalFnForDisplayData; - private final Contextful<Fn<InputT, OutputT>> fn; + private final Object fn; private final @Nullable ProcessFunction<ExceptionElement<InputT>, FailureT> exceptionHandler; MapWithFailures( - Contextful<Fn<InputT, OutputT>> fn, - Object originalFnForDisplayData, + Object fn, TypeDescriptor<InputT> inputType, TypeDescriptor<OutputT> outputType, @Nullable ProcessFunction<ExceptionElement<InputT>, FailureT> exceptionHandler, @Nullable TypeDescriptor<FailureT> failureType) { this.fn = fn; - this.originalFnForDisplayData = originalFnForDisplayData; this.inputType = inputType; this.outputType = outputType; this.exceptionHandler = exceptionHandler; @@ -274,29 +288,98 @@ public class MapElements<InputT, OutputT> */ public MapWithFailures<InputT, OutputT, FailureT> exceptionsVia( ProcessFunction<ExceptionElement<InputT>, FailureT> exceptionHandler) { - return new MapWithFailures<>( - fn, originalFnForDisplayData, inputType, outputType, exceptionHandler, failureType); + return new MapWithFailures<>(fn, inputType, outputType, exceptionHandler, failureType); } @Override public WithFailures.Result<PCollection<OutputT>, FailureT> expand(PCollection<InputT> input) { checkArgument(exceptionHandler != null, ".exceptionsVia() is required"); - MapFn doFn = new MapFn(); - PCollectionTuple tuple = - input.apply( - MapWithFailures.class.getSimpleName(), - ParDo.of(doFn) - .withOutputTags(doFn.outputTag, TupleTagList.of(doFn.failureTag)) - .withSideInputs(this.fn.getRequirements().getSideInputs())); + MapWithFailuresDoFn doFn; + PCollectionTuple tuple; + if (fn instanceof Contextful) { + doFn = + new MapWithFailuresDoFn() { + @ProcessElement + public void processElement(@Element InputT element, ProcessContext c) + throws Exception { + boolean exceptionWasThrown = false; + OutputT result = null; + try { + result = + ((Contextful<Fn<InputT, OutputT>>) fn) + .getClosure() + .apply(element, Fn.Context.wrapProcessContext(c)); + } catch (Exception e) { + exceptionWasThrown = true; + ExceptionElement<InputT> exceptionElement = ExceptionElement.of(element, e); + c.output(failureTag, exceptionHandler.apply(exceptionElement)); + } + // We make sure our output occurs outside the try block, since runners may implement + // fusion by having output() directly call the body of another DoFn, potentially + // catching + // exceptions unrelated to this transform. + if (!exceptionWasThrown) { + c.output(result); + } + } + }; + tuple = + input.apply( + MapWithFailures.class.getSimpleName(), + ParDo.of(doFn) + .withOutputTags(doFn.outputTag, TupleTagList.of(doFn.failureTag)) + .withSideInputs( + ((Contextful<Fn<InputT, OutputT>>) fn).getRequirements().getSideInputs())); + + } else if (fn instanceof ProcessFunction) { + ProcessFunction<InputT, OutputT> closure = (ProcessFunction<InputT, OutputT>) fn; + doFn = + new MapWithFailuresDoFn() { + @ProcessElement + public void processElement(@Element InputT element, ProcessContext c) + throws Exception { + boolean exceptionWasThrown = false; + OutputT result; + try { + result = closure.apply(element); + } catch (Exception e) { + result = null; + exceptionWasThrown = true; + ExceptionElement<InputT> exceptionElement = ExceptionElement.of(element, e); + c.output(failureTag, exceptionHandler.apply(exceptionElement)); + } + // We make sure our output occurs outside the try block, since runners may implement + // fusion by having output() directly call the body of another DoFn, potentially + // catching + // exceptions unrelated to this transform. + if (!exceptionWasThrown) { + c.output(result); + } + } + }; + tuple = + input.apply( + MapWithFailures.class.getSimpleName(), + ParDo.of(doFn).withOutputTags(doFn.outputTag, TupleTagList.of(doFn.failureTag))); + } else { + throw new IllegalArgumentException( + String.format("Unknown type of fn class %s", fn.getClass())); + } return WithFailures.Result.of(tuple, doFn.outputTag, doFn.failureTag); } @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); - builder.add(DisplayData.item("class", originalFnForDisplayData.getClass())); - if (originalFnForDisplayData instanceof HasDisplayData) { - builder.include("fn", (HasDisplayData) originalFnForDisplayData); + Object fnForDisplayData; + if (fn instanceof Contextful) { + fnForDisplayData = ((Contextful<?>) fn).getClosure(); + } else { + fnForDisplayData = fn; + } + builder.add(DisplayData.item("class", fnForDisplayData.getClass())); + if (fnForDisplayData instanceof HasDisplayData) { + builder.include("fn", (HasDisplayData) fnForDisplayData); } builder.add(DisplayData.item("exceptionHandler.class", exceptionHandler.getClass())); if (exceptionHandler instanceof HasDisplayData) { @@ -313,31 +396,10 @@ public class MapElements<InputT, OutputT> } /** A DoFn implementation that handles exceptions and outputs a secondary failure collection. */ - private class MapFn extends DoFn<InputT, OutputT> { - + private abstract class MapWithFailuresDoFn extends DoFn<InputT, OutputT> { final TupleTag<OutputT> outputTag = new TupleTag<OutputT>() {}; final TupleTag<FailureT> failureTag = new FailureTag(); - @ProcessElement - public void processElement(@Element InputT element, MultiOutputReceiver r, ProcessContext c) - throws Exception { - boolean exceptionWasThrown = false; - OutputT result = null; - try { - result = fn.getClosure().apply(c.element(), Fn.Context.wrapProcessContext(c)); - } catch (Exception e) { - exceptionWasThrown = true; - ExceptionElement<InputT> exceptionElement = ExceptionElement.of(element, e); - r.get(failureTag).output(exceptionHandler.apply(exceptionElement)); - } - // We make sure our output occurs outside the try block, since runners may implement - // fusion by having output() directly call the body of another DoFn, potentially catching - // exceptions unrelated to this transform. - if (!exceptionWasThrown) { - r.get(outputTag).output(result); - } - } - @Override public void populateDisplayData(DisplayData.Builder builder) { builder.delegate(MapWithFailures.this);