Repository: incubator-beam
Updated Branches:
  refs/heads/master 8daf518bc -> 2b5c6bcb2
Use input type in coder inference for MapElements and FlatMapElements

Previously, the input TypeDescriptor was unknown, so we would fail to infer
a coder for things like MapElements.via(SimpleFunction<T, T>) even if the
input PCollection provided a coder for T.

Now, the input type is plumbed appropriately and the coder is inferred.

Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/4ac5cafe
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/4ac5cafe
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/4ac5cafe

Branch: refs/heads/master
Commit: 4ac5cafe90a371cf616f97cb202d5016b68616d1
Parents: 8daf518
Author: Kenneth Knowles <[email protected]>
Authored: Fri Jul 29 10:35:01 2016 -0700
Committer: Kenneth Knowles <[email protected]>
Committed: Thu Aug 4 20:18:59 2016 -0700

----------------------------------------------------------------------
 .../beam/sdk/transforms/FlatMapElements.java    | 126 +++++++++++++------
 .../apache/beam/sdk/transforms/MapElements.java |  60 +++++----
 .../beam/sdk/transforms/SimpleFunction.java     |  34 +++++
 .../sdk/transforms/FlatMapElementsTest.java     |  48 +++++++
 .../beam/sdk/transforms/MapElementsTest.java    |  84 +++++++++++++
 5 files changed, 288 insertions(+), 64 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ac5cafe/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java
index 694592e..04d993c 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java
@@
-17,8 +17,10 @@ */ package org.apache.beam.sdk.transforms; +import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; import java.lang.reflect.ParameterizedType; @@ -45,8 +47,16 @@ extends PTransform<PCollection<InputT>, PCollection<OutputT>> { * descriptor need not be provided. */ public static <InputT, OutputT> MissingOutputTypeDescriptor<InputT, OutputT> - via(SerializableFunction<InputT, ? extends Iterable<OutputT>> fn) { - return new MissingOutputTypeDescriptor<>(fn); + via(SerializableFunction<? super InputT, ? extends Iterable<OutputT>> fn) { + + // TypeDescriptor interacts poorly with the wildcards needed to correctly express + // covariance and contravariance in Java, so instead we cast it to an invariant + // function here. + @SuppressWarnings("unchecked") // safe covariant cast + SerializableFunction<InputT, Iterable<OutputT>> simplerFn = + (SerializableFunction<InputT, Iterable<OutputT>>) fn; + + return new MissingOutputTypeDescriptor<>(simplerFn); } /** @@ -72,16 +82,15 @@ extends PTransform<PCollection<InputT>, PCollection<OutputT>> { * <p>To use a Java 8 lambda, see {@link #via(SerializableFunction)}. */ public static <InputT, OutputT> FlatMapElements<InputT, OutputT> - via(SimpleFunction<InputT, ? extends Iterable<OutputT>> fn) { - - @SuppressWarnings({"rawtypes", "unchecked"}) // safe by static typing - TypeDescriptor<Iterable<?>> iterableType = (TypeDescriptor) fn.getOutputTypeDescriptor(); - - @SuppressWarnings("unchecked") // safe by correctness of getIterableElementType - TypeDescriptor<OutputT> outputType = - (TypeDescriptor<OutputT>) getIterableElementType(iterableType); - - return new FlatMapElements<>(fn, outputType); + via(SimpleFunction<? super InputT, ? 
extends Iterable<OutputT>> fn) { + // TypeDescriptor interacts poorly with the wildcards needed to correctly express + // covariance and contravariance in Java, so instead we cast it to an invariant + // function here. + @SuppressWarnings("unchecked") // safe covariant cast + SimpleFunction<InputT, Iterable<OutputT>> simplerFn = + (SimpleFunction<InputT, Iterable<OutputT>>) fn; + + return new FlatMapElements<>(simplerFn, fn.getClass()); } /** @@ -91,18 +100,80 @@ extends PTransform<PCollection<InputT>, PCollection<OutputT>> { */ public static final class MissingOutputTypeDescriptor<InputT, OutputT> { - private final SerializableFunction<InputT, ? extends Iterable<OutputT>> fn; + private final SerializableFunction<InputT, Iterable<OutputT>> fn; private MissingOutputTypeDescriptor( - SerializableFunction<InputT, ? extends Iterable<OutputT>> fn) { + SerializableFunction<InputT, Iterable<OutputT>> fn) { this.fn = fn; } public FlatMapElements<InputT, OutputT> withOutputType(TypeDescriptor<OutputT> outputType) { - return new FlatMapElements<>(fn, outputType); + TypeDescriptor<Iterable<OutputT>> iterableOutputType = TypeDescriptors.iterables(outputType); + + return new FlatMapElements<>( + SimpleFunction.fromSerializableFunctionWithOutputType(fn, + iterableOutputType), + fn.getClass()); } } + ////////////////////////////////////////////////////////////////////////////////////////////////// + + private final SimpleFunction<InputT, ? extends Iterable<OutputT>> fn; + private final DisplayData.Item<?> fnClassDisplayData; + + private FlatMapElements( + SimpleFunction<InputT, ? 
extends Iterable<OutputT>> fn, + Class<?> fnClass) { + this.fn = fn; + this.fnClassDisplayData = DisplayData.item("flatMapFn", fnClass).withLabel("FlatMap Function"); + } + + @Override + public PCollection<OutputT> apply(PCollection<InputT> input) { + return input.apply( + "FlatMap", + ParDo.of( + new DoFn<InputT, OutputT>() { + private static final long serialVersionUID = 0L; + + @ProcessElement + public void processElement(ProcessContext c) { + for (OutputT element : fn.apply(c.element())) { + c.output(element); + } + } + + @Override + public TypeDescriptor<InputT> getInputTypeDescriptor() { + return fn.getInputTypeDescriptor(); + } + + @Override + public TypeDescriptor<OutputT> getOutputTypeDescriptor() { + @SuppressWarnings({"rawtypes", "unchecked"}) // safe by static typing + TypeDescriptor<Iterable<?>> iterableType = + (TypeDescriptor) fn.getOutputTypeDescriptor(); + + @SuppressWarnings("unchecked") // safe by correctness of getIterableElementType + TypeDescriptor<OutputT> outputType = + (TypeDescriptor<OutputT>) getIterableElementType(iterableType); + + return outputType; + } + })); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + builder.add(fnClassDisplayData); + } + + /** + * Does a best-effort job of getting the best {@link TypeDescriptor} for the type of the + * elements contained in the iterable described by the given {@link TypeDescriptor}. + */ private static TypeDescriptor<?> getIterableElementType( TypeDescriptor<Iterable<?>> iterableTypeDescriptor) { @@ -118,29 +189,4 @@ extends PTransform<PCollection<InputT>, PCollection<OutputT>> { (ParameterizedType) iterableTypeDescriptor.getSupertype(Iterable.class).getType(); return TypeDescriptor.of(iterableType.getActualTypeArguments()[0]); } - - ////////////////////////////////////////////////////////////////////////////////////////////////// - - private final SerializableFunction<InputT, ? 
extends Iterable<OutputT>> fn; - private final transient TypeDescriptor<OutputT> outputType; - - private FlatMapElements( - SerializableFunction<InputT, ? extends Iterable<OutputT>> fn, - TypeDescriptor<OutputT> outputType) { - this.fn = fn; - this.outputType = outputType; - } - - @Override - public PCollection<OutputT> apply(PCollection<InputT> input) { - return input.apply("Map", ParDo.of(new DoFn<InputT, OutputT>() { - private static final long serialVersionUID = 0L; - @ProcessElement - public void processElement(ProcessContext c) { - for (OutputT element : fn.apply(c.element())) { - c.output(element); - } - } - })).setTypeDescriptorInternal(outputType); - } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ac5cafe/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java index b7b9a5f..429d3fc 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java @@ -67,9 +67,9 @@ extends PTransform<PCollection<InputT>, PCollection<OutputT>> { * })); * }</pre> */ - public static <InputT, OutputT> MapElements<InputT, OutputT> - via(final SimpleFunction<InputT, OutputT> fn) { - return new MapElements<>(fn, fn.getOutputTypeDescriptor()); + public static <InputT, OutputT> MapElements<InputT, OutputT> via( + final SimpleFunction<InputT, OutputT> fn) { + return new MapElements<>(fn, fn.getClass()); } /** @@ -85,42 +85,54 @@ extends PTransform<PCollection<InputT>, PCollection<OutputT>> { this.fn = fn; } - public MapElements<InputT, OutputT> withOutputType(TypeDescriptor<OutputT> outputType) { - return new MapElements<>(fn, outputType); + public MapElements<InputT, OutputT> withOutputType(final 
TypeDescriptor<OutputT> outputType) { + return new MapElements<>( + SimpleFunction.fromSerializableFunctionWithOutputType(fn, outputType), fn.getClass()); } + } /////////////////////////////////////////////////////////////////// - private final SerializableFunction<InputT, OutputT> fn; - private final transient TypeDescriptor<OutputT> outputType; + private final SimpleFunction<InputT, OutputT> fn; + private final DisplayData.Item<?> fnClassDisplayData; - private MapElements( - SerializableFunction<InputT, OutputT> fn, - TypeDescriptor<OutputT> outputType) { + private MapElements(SimpleFunction<InputT, OutputT> fn, Class<?> fnClass) { this.fn = fn; - this.outputType = outputType; + this.fnClassDisplayData = DisplayData.item("mapFn", fnClass).withLabel("Map Function"); } @Override public PCollection<OutputT> apply(PCollection<InputT> input) { - return input.apply("Map", ParDo.of(new DoFn<InputT, OutputT>() { - @ProcessElement - public void processElement(ProcessContext c) { - c.output(fn.apply(c.element())); - } - - @Override - public void populateDisplayData(DisplayData.Builder builder) { - MapElements.this.populateDisplayData(builder); - } - })).setTypeDescriptorInternal(outputType); + return input.apply( + "Map", + ParDo.of( + new DoFn<InputT, OutputT>() { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(fn.apply(c.element())); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + MapElements.this.populateDisplayData(builder); + } + + @Override + public TypeDescriptor<InputT> getInputTypeDescriptor() { + return fn.getInputTypeDescriptor(); + } + + @Override + public TypeDescriptor<OutputT> getOutputTypeDescriptor() { + return fn.getOutputTypeDescriptor(); + } + })); } @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); - builder.add(DisplayData.item("mapFn", fn.getClass()) - .withLabel("Map Function")); + builder.add(fnClassDisplayData); } } 
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ac5cafe/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java index 8894352..6c540cc 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java @@ -27,6 +27,12 @@ import org.apache.beam.sdk.values.TypeDescriptor; public abstract class SimpleFunction<InputT, OutputT> implements SerializableFunction<InputT, OutputT> { + public static <InputT, OutputT> + SimpleFunction<InputT, OutputT> fromSerializableFunctionWithOutputType( + SerializableFunction<InputT, OutputT> fn, TypeDescriptor<OutputT> outputType) { + return new SimpleFunctionWithOutputType<>(fn, outputType); + } + /** * Returns a {@link TypeDescriptor} capturing what is known statically * about the input type of this {@code OldDoFn} instance's most-derived @@ -52,4 +58,32 @@ public abstract class SimpleFunction<InputT, OutputT> public TypeDescriptor<OutputT> getOutputTypeDescriptor() { return new TypeDescriptor<OutputT>(this) {}; } + + /** + * A {@link SimpleFunction} built from a {@link SerializableFunction}, having + * a known output type that is explicitly set. 
+ */ + private static class SimpleFunctionWithOutputType<InputT, OutputT> + extends SimpleFunction<InputT, OutputT> { + + private final SerializableFunction<InputT, OutputT> fn; + private final TypeDescriptor<OutputT> outputType; + + public SimpleFunctionWithOutputType( + SerializableFunction<InputT, OutputT> fn, + TypeDescriptor<OutputT> outputType) { + this.fn = fn; + this.outputType = outputType; + } + + @Override + public OutputT apply(InputT input) { + return fn.apply(input); + } + + @Override + public TypeDescriptor<OutputT> getOutputTypeDescriptor() { + return outputType; + } + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ac5cafe/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java index 057fd19..781e143 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.transforms; +import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; + import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertThat; @@ -24,6 +26,7 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TypeDescriptor; @@ -102,6 +105,51 @@ public class FlatMapElementsTest implements Serializable { pipeline.run(); } + /** + * A {@link SimpleFunction} to test that the 
coder registry can propagate coders + * that are bound to type variables. + */ + private static class PolymorphicSimpleFunction<T> extends SimpleFunction<T, Iterable<T>> { + @Override + public Iterable<T> apply(T input) { + return Collections.<T>emptyList(); + } + } + + /** + * Basic test of {@link MapElements} coder propagation with a parametric {@link SimpleFunction}. + */ + @Test + public void testPolymorphicSimpleFunction() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection<Integer> output = pipeline + .apply(Create.of(1, 2, 3)) + + // This is the function that needs to propagate the input T to output T + .apply("Polymorphic Identity", MapElements.via(new PolymorphicSimpleFunction<Integer>())) + + // This is a consumer to ensure that all coder inference logic is executed. + .apply("Test Consumer", MapElements.via(new SimpleFunction<Iterable<Integer>, Integer>() { + @Override + public Integer apply(Iterable<Integer> input) { + return 42; + } + })); + } + + @Test + public void testSimpleFunctionClassDisplayData() { + SimpleFunction<Integer, List<Integer>> simpleFn = new SimpleFunction<Integer, List<Integer>>() { + @Override + public List<Integer> apply(Integer input) { + return Collections.emptyList(); + } + }; + + FlatMapElements<?, ?> simpleMap = FlatMapElements.via(simpleFn); + assertThat(DisplayData.from(simpleMap), hasDisplayItem("flatMapFn", simpleFn.getClass())); + } + @Test @Category(NeedsRunner.class) public void testVoidValues() throws Exception { http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ac5cafe/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java index b4751d2..dbf8844 100644 --- 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java @@ -54,6 +54,29 @@ public class MapElementsTest implements Serializable { public transient ExpectedException thrown = ExpectedException.none(); /** + * A {@link SimpleFunction} to test that the coder registry can propagate coders + * that are bound to type variables. + */ + private static class PolymorphicSimpleFunction<T> extends SimpleFunction<T, T> { + @Override + public T apply(T input) { + return input; + } + } + + /** + * A {@link SimpleFunction} to test that the coder registry can propagate coders + * that are bound to type variables, when the variable appears nested in the + * output. + */ + private static class NestedPolymorphicSimpleFunction<T> extends SimpleFunction<T, KV<T, String>> { + @Override + public KV<T, String> apply(T input) { + return KV.of(input, "hello"); + } + } + + /** * Basic test of {@link MapElements} with a {@link SimpleFunction}. */ @Test @@ -74,6 +97,55 @@ public class MapElementsTest implements Serializable { } /** + * Basic test of {@link MapElements} coder propagation with a parametric {@link SimpleFunction}. + */ + @Test + public void testPolymorphicSimpleFunction() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection<Integer> output = pipeline + .apply(Create.of(1, 2, 3)) + + // This is the function that needs to propagate the input T to output T + .apply("Polymorphic Identity", MapElements.via(new PolymorphicSimpleFunction<Integer>())) + + // This is a consumer to ensure that all coder inference logic is executed. 
+ .apply("Test Consumer", MapElements.via(new SimpleFunction<Integer, Integer>() { + @Override + public Integer apply(Integer input) { + return input; + } + })); + } + + /** + * Test of {@link MapElements} coder propagation with a parametric {@link SimpleFunction} + * where the type variable occurs nested within other concrete type constructors. + */ + @Test + public void testNestedPolymorphicSimpleFunction() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection<Integer> output = + pipeline + .apply(Create.of(1, 2, 3)) + + // This is the function that needs to propagate the input T to output T + .apply( + "Polymorphic Identity", + MapElements.via(new NestedPolymorphicSimpleFunction<Integer>())) + + // This is a consumer to ensure that all coder inference logic is executed. + .apply( + "Test Consumer", + MapElements.via( + new SimpleFunction<KV<Integer, String>, Integer>() { + @Override + public Integer apply(KV<Integer, String> input) { + return 42; + } + })); + } + + /** * Basic test of {@link MapElements} with a {@link SerializableFunction}. This style is * generally discouraged in Java 7, in favor of {@link SimpleFunction}. */ @@ -148,6 +220,18 @@ public class MapElementsTest implements Serializable { } @Test + public void testSimpleFunctionClassDisplayData() { + SimpleFunction<?, ?> simpleFn = new SimpleFunction<Integer, Integer>() { + @Override + public Integer apply(Integer input) { + return input; + } + }; + + MapElements<?, ?> simpleMap = MapElements.via(simpleFn); + assertThat(DisplayData.from(simpleMap), hasDisplayItem("mapFn", simpleFn.getClass())); + } + @Test public void testSimpleFunctionDisplayData() { SimpleFunction<?, ?> simpleFn = new SimpleFunction<Integer, Integer>() { @Override
