Repository: beam Updated Branches: refs/heads/master 7f5753f1f -> 014614b69
Supports side inputs in MapElements and FlatMapElements Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/e2ad925d Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/e2ad925d Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/e2ad925d Branch: refs/heads/master Commit: e2ad925dc4d8bb33a264a21c48b8ceef63ac6eb3 Parents: 4b908c2 Author: Eugene Kirpichov <[email protected]> Authored: Mon Oct 2 17:36:48 2017 -0700 Committer: Eugene Kirpichov <[email protected]> Committed: Fri Oct 13 18:43:48 2017 -0700 ---------------------------------------------------------------------- .../runners/spark/SparkRunnerDebuggerTest.java | 11 +- .../org/apache/beam/sdk/io/FileBasedSink.java | 9 +- .../beam/sdk/transforms/FlatMapElements.java | 142 ++++++++----------- .../apache/beam/sdk/transforms/MapElements.java | 71 +++++----- .../beam/sdk/transforms/Requirements.java | 5 + .../apache/beam/sdk/values/TypeDescriptors.java | 1 - .../sdk/transforms/FlatMapElementsTest.java | 35 ++++- .../beam/sdk/transforms/MapElementsTest.java | 42 +++++- .../io/gcp/bigquery/DynamicDestinations.java | 9 +- 9 files changed, 196 insertions(+), 129 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/e2ad925d/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java index 246eb81..49e36ca 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java @@ -88,15 +88,14 @@ public class SparkRunnerDebuggerTest { "sparkContext.parallelize(Arrays.asList(...))\n" + "_.mapPartitions(" + "new org.apache.beam.runners.spark.examples.WordCount$ExtractWordsFn())\n" - + "_.mapPartitions(new org.apache.beam.sdk.transforms.Count$PerElement$1())\n" + + "_.mapPartitions(new org.apache.beam.sdk.transforms.Contextful())\n" + "_.combineByKey(..., new org.apache.beam.sdk.transforms.Count$CountFn(), ...)\n" + "_.groupByKey()\n" + "_.map(new org.apache.beam.sdk.transforms.Sum$SumLongFn())\n" - + "_.mapPartitions(new org.apache.beam.runners.spark" - + ".SparkRunnerDebuggerTest$PlusOne())\n" + + "_.mapPartitions(new org.apache.beam.sdk.transforms.Contextful())\n" + "sparkContext.union(...)\n" + "_.mapPartitions(" - + "new org.apache.beam.runners.spark.examples.WordCount$FormatAsTextFn())\n" + + "new org.apache.beam.sdk.transforms.Contextful())\n" + "_.<org.apache.beam.sdk.io.TextIO$Write>"; SparkRunnerDebugger.DebugSparkPipelineResult result = @@ -141,11 +140,11 @@ public class SparkRunnerDebuggerTest { + "_.map(new org.apache.beam.sdk.transforms.windowing.FixedWindows())\n" + "_.mapPartitions(new org.apache.beam.runners.spark." + "SparkRunnerDebuggerTest$FormatKVFn())\n" - + "_.mapPartitions(new org.apache.beam.sdk.transforms.Distinct$2())\n" + + "_.mapPartitions(new org.apache.beam.sdk.transforms.Contextful())\n" + "_.groupByKey()\n" + "_.map(new org.apache.beam.sdk.transforms.Combine$IterableCombineFn())\n" + "_.mapPartitions(new org.apache.beam.sdk.transforms.Distinct$3())\n" - + "_.mapPartitions(new org.apache.beam.sdk.transforms.WithKeys$2())\n" + + "_.mapPartitions(new org.apache.beam.sdk.transforms.Contextful())\n" + "_.<org.apache.beam.sdk.io.kafka.AutoValue_KafkaIO_Write>"; SparkRunnerDebugger.DebugSparkPipelineResult result = http://git-wip-us.apache.org/repos/asf/beam/blob/e2ad925d/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java index 9834e6e..d577fea 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java @@ -319,7 +319,14 @@ public abstract class FileBasedSink<UserT, DestinationT, OutputT> DynamicDestinations.class, new TypeVariableExtractor< DynamicDestinations<UserT, DestinationT, OutputT>, DestinationT>() {}); - return registry.getCoder(descriptor); + try { + return registry.getCoder(descriptor); + } catch (CannotProvideCoderException e) { + throw new CannotProvideCoderException( + "Failed to infer coder for DestinationT from type " + + descriptor + ", please provide it explicitly by overriding getDestinationCoder()", + e); + } } } http://git-wip-us.apache.org/repos/asf/beam/blob/e2ad925d/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java index a8a94f9..97e1dfb 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java @@ -17,11 +17,13 @@ */ package org.apache.beam.sdk.transforms; -import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkArgument; -import java.lang.reflect.ParameterizedType; import javax.annotation.Nullable; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.transforms.Contextful.Fn; import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.transforms.display.HasDisplayData; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.TypeDescriptors; @@ -32,30 +34,20 @@ import org.apache.beam.sdk.values.TypeDescriptors; */ public class FlatMapElements<InputT, OutputT> extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> { - /** - * Temporarily stores the argument of {@link #into(TypeDescriptor)} until combined with the - * argument of {@link #via(SerializableFunction)} into the fully-specified {@link #fn}. Stays null - * if constructed using {@link #via(SimpleFunction)} directly. - */ - @Nullable - private final transient TypeDescriptor<Iterable<OutputT>> outputType; - - /** - * Non-null on a fully specified transform - is null only when constructed using {@link - * #into(TypeDescriptor)}, until the fn is specified using {@link #via(SerializableFunction)}. - */ - @Nullable - private final SimpleFunction<InputT, Iterable<OutputT>> fn; - private final DisplayData.ItemSpec<?> fnClassDisplayData; + private final transient TypeDescriptor<InputT> inputType; + private final transient TypeDescriptor<OutputT> outputType; + @Nullable private final transient Object originalFnForDisplayData; + @Nullable private final Contextful<Fn<InputT, Iterable<OutputT>>> fn; private FlatMapElements( - @Nullable SimpleFunction<InputT, Iterable<OutputT>> fn, - @Nullable TypeDescriptor<Iterable<OutputT>> outputType, - @Nullable Class<?> fnClass) { + @Nullable Contextful<Fn<InputT, Iterable<OutputT>>> fn, + @Nullable Object originalFnForDisplayData, + TypeDescriptor<InputT> inputType, + TypeDescriptor<OutputT> outputType) { this.fn = fn; + this.originalFnForDisplayData = originalFnForDisplayData; + this.inputType = inputType; this.outputType = outputType; - this.fnClassDisplayData = DisplayData.item("flatMapFn", fnClass).withLabel("FlatMap Function"); - } /** @@ -82,7 +74,14 @@ extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> { */ public static <InputT, OutputT> FlatMapElements<InputT, OutputT> via(SimpleFunction<? super InputT, ? extends Iterable<OutputT>> fn) { - return new FlatMapElements(fn, null, fn.getClass()); + Contextful<Fn<InputT, Iterable<OutputT>>> wrapped = (Contextful) Contextful.fn(fn); + TypeDescriptor<OutputT> outputType = + TypeDescriptors.extractFromTypeParameters( + (TypeDescriptor<Iterable<OutputT>>) fn.getOutputTypeDescriptor(), + Iterable.class, + new TypeDescriptors.TypeVariableExtractor<Iterable<OutputT>, OutputT>() {}); + TypeDescriptor<InputT> inputType = (TypeDescriptor<InputT>) fn.getInputTypeDescriptor(); + return new FlatMapElements<>(wrapped, fn, inputType, outputType); } /** @@ -91,7 +90,7 @@ extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> { */ public static <OutputT> FlatMapElements<?, OutputT> into(final TypeDescriptor<OutputT> outputType) { - return new FlatMapElements<>(null, TypeDescriptors.iterables(outputType), null); + return new FlatMapElements<>(null, null, null, outputType); } /** @@ -112,73 +111,58 @@ extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> { */ public <NewInputT> FlatMapElements<NewInputT, OutputT> via(SerializableFunction<NewInputT, ? extends Iterable<OutputT>> fn) { - return new FlatMapElements( - SimpleFunction.fromSerializableFunctionWithOutputType(fn, (TypeDescriptor) outputType), - null, - fn.getClass()); + return new FlatMapElements<>( + (Contextful) Contextful.fn(fn), fn, TypeDescriptors.inputOf(fn), outputType); + } + + /** Like {@link #via(SerializableFunction)}, but allows access to additional context. */ + @Experimental(Experimental.Kind.CONTEXTFUL) + public <NewInputT> FlatMapElements<NewInputT, OutputT> via( + Contextful<Fn<NewInputT, Iterable<OutputT>>> fn) { + return new FlatMapElements<>( + fn, fn.getClosure(), TypeDescriptors.inputOf(fn.getClosure()), outputType); } @Override public PCollection<OutputT> expand(PCollection<? extends InputT> input) { - checkNotNull(fn, "Must specify a function on FlatMapElements using .via()"); + checkArgument(fn != null, ".via() is required"); return input.apply( "FlatMap", ParDo.of( - new DoFn<InputT, OutputT>() { - private static final long serialVersionUID = 0L; - - @ProcessElement - public void processElement(ProcessContext c) { - for (OutputT element : fn.apply(c.element())) { - c.output(element); - } - } - - @Override - public TypeDescriptor<InputT> getInputTypeDescriptor() { - return fn.getInputTypeDescriptor(); - } - - @Override - public TypeDescriptor<OutputT> getOutputTypeDescriptor() { - @SuppressWarnings({"rawtypes", "unchecked"}) // safe by static typing - TypeDescriptor<Iterable<?>> iterableType = - (TypeDescriptor) fn.getOutputTypeDescriptor(); - - @SuppressWarnings("unchecked") // safe by correctness of getIterableElementType - TypeDescriptor<OutputT> outputType = - (TypeDescriptor<OutputT>) getIterableElementType(iterableType); - - return outputType; - } - })); + new DoFn<InputT, OutputT>() { + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + Iterable<OutputT> res = + fn.getClosure().apply(c.element(), Fn.Context.wrapProcessContext(c)); + for (OutputT output : res) { + c.output(output); + } + } + + @Override + public TypeDescriptor<InputT> getInputTypeDescriptor() { + return inputType; + } + + @Override + public TypeDescriptor<OutputT> getOutputTypeDescriptor() { + return outputType; + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.delegate(FlatMapElements.this); + } + }) + .withSideInputs(fn.getRequirements().getSideInputs())); } @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); - builder - .include("flatMapFn", fn) - .add(fnClassDisplayData); - } - - /** - * Does a best-effort job of getting the best {@link TypeDescriptor} for the type of the - * elements contained in the iterable described by the given {@link TypeDescriptor}. - */ - private static TypeDescriptor<?> getIterableElementType( - TypeDescriptor<Iterable<?>> iterableTypeDescriptor) { - - // If a rawtype was used, the type token may be for Object, not a subtype of Iterable. - // In this case, we rely on static typing of the function elsewhere to ensure it is - // at least some kind of iterable, and grossly overapproximate the element type to be Object. - if (!iterableTypeDescriptor.isSubtypeOf(new TypeDescriptor<Iterable<?>>() {})) { - return new TypeDescriptor<Object>() {}; + builder.add(DisplayData.item("class", originalFnForDisplayData.getClass())); + if (originalFnForDisplayData instanceof HasDisplayData) { + builder.include("fn", (HasDisplayData) originalFnForDisplayData); } - - // Otherwise we can do the proper thing and get the actual type parameter. - ParameterizedType iterableType = - (ParameterizedType) iterableTypeDescriptor.getSupertype(Iterable.class).getType(); - return TypeDescriptor.of(iterableType.getActualTypeArguments()[0]); } } http://git-wip-us.apache.org/repos/asf/beam/blob/e2ad925d/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java index 792a6d5..1d259ac 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java @@ -20,36 +20,34 @@ package org.apache.beam.sdk.transforms; import static com.google.common.base.Preconditions.checkNotNull; import javax.annotation.Nullable; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.apache.beam.sdk.transforms.Contextful.Fn; import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.transforms.display.HasDisplayData; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; /** * {@code PTransform}s for mapping a simple function over the elements of a {@link PCollection}. */ public class MapElements<InputT, OutputT> extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> { - /** - * Temporarily stores the argument of {@link #into(TypeDescriptor)} until combined with the - * argument of {@link #via(SerializableFunction)} into the fully-specified {@link #fn}. Stays null - * if constructed using {@link #via(SimpleFunction)} directly. - */ - @Nullable private final transient TypeDescriptor<OutputT> outputType; - - /** - * Non-null on a fully specified transform - is null only when constructed using {@link - * #into(TypeDescriptor)}, until the fn is specified using {@link #via(SerializableFunction)}. - */ - @Nullable private final SimpleFunction<InputT, OutputT> fn; - private final DisplayData.ItemSpec<?> fnClassDisplayData; + private final transient TypeDescriptor<InputT> inputType; + private final transient TypeDescriptor<OutputT> outputType; + @Nullable private final transient Object originalFnForDisplayData; + @Nullable private final Contextful<Fn<InputT, OutputT>> fn; private MapElements( - @Nullable SimpleFunction<InputT, OutputT> fn, - @Nullable TypeDescriptor<OutputT> outputType, - @Nullable Class<?> fnClass) { + @Nullable Contextful<Fn<InputT, OutputT>> fn, + @Nullable Object originalFnForDisplayData, + TypeDescriptor<InputT> inputType, + TypeDescriptor<OutputT> outputType) { this.fn = fn; + this.originalFnForDisplayData = originalFnForDisplayData; + this.inputType = inputType; this.outputType = outputType; - this.fnClassDisplayData = DisplayData.item("mapFn", fnClass).withLabel("Map Function"); } /** @@ -57,10 +55,11 @@ extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> { * takes an input {@code PCollection<InputT>} and returns a {@code PCollection<OutputT>} * containing {@code fn.apply(v)} for every element {@code v} in the input. * - * <p>This overload is intended primarily for use in Java 7. In Java 8, the overload - * {@link #via(SerializableFunction)} supports use of lambda for greater concision. + * <p>This overload is intended primarily for use in Java 7. In Java 8, the overload {@link + * #via(SerializableFunction)} supports use of lambda for greater concision. * * <p>Example of use in Java 7: + * * <pre>{@code * PCollection<String> words = ...; * PCollection<Integer> wordsPerLine = words.apply(MapElements.via( @@ -73,7 +72,8 @@ extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> { */ public static <InputT, OutputT> MapElements<InputT, OutputT> via( final SimpleFunction<InputT, OutputT> fn) { - return new MapElements<>(fn, null, fn.getClass()); + return new MapElements<>( + Contextful.fn(fn), fn, fn.getInputTypeDescriptor(), fn.getOutputTypeDescriptor()); } /** @@ -82,7 +82,7 @@ extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> { */ public static <OutputT> MapElements<?, OutputT> into(final TypeDescriptor<OutputT> outputType) { - return new MapElements<>(null, outputType, null); + return new MapElements<>(null, null, null, outputType); } /** @@ -104,10 +104,16 @@ extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> { */ public <NewInputT> MapElements<NewInputT, OutputT> via( SerializableFunction<NewInputT, OutputT> fn) { + return new MapElements<>(Contextful.fn(fn), fn, TypeDescriptors.inputOf(fn), outputType); + } + + /** + * Like {@link #via(SerializableFunction)}, but supports access to context, such as side inputs. + */ + @Experimental(Kind.CONTEXTFUL) + public <NewInputT> MapElements<NewInputT, OutputT> via(Contextful<Fn<NewInputT, OutputT>> fn) { return new MapElements<>( - SimpleFunction.fromSerializableFunctionWithOutputType(fn, outputType), - null, - fn.getClass()); + fn, fn.getClosure(), TypeDescriptors.inputOf(fn.getClosure()), outputType); } @Override @@ -118,8 +124,8 @@ extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> { ParDo.of( new DoFn<InputT, OutputT>() { @ProcessElement - public void processElement(ProcessContext c) { - c.output(fn.apply(c.element())); + public void processElement(ProcessContext c) throws Exception { + c.output(fn.getClosure().apply(c.element(), Fn.Context.wrapProcessContext(c))); } @Override @@ -129,21 +135,22 @@ extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> { @Override public TypeDescriptor<InputT> getInputTypeDescriptor() { - return fn.getInputTypeDescriptor(); + return inputType; } @Override public TypeDescriptor<OutputT> getOutputTypeDescriptor() { - return fn.getOutputTypeDescriptor(); + return outputType; } - })); + }).withSideInputs(fn.getRequirements().getSideInputs())); } @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); - builder - .include("mapFn", fn) - .add(fnClassDisplayData); + builder.add(DisplayData.item("class", originalFnForDisplayData.getClass())); + if (originalFnForDisplayData instanceof HasDisplayData) { + builder.include("fn", (HasDisplayData) originalFnForDisplayData); + } } } http://git-wip-us.apache.org/repos/asf/beam/blob/e2ad925d/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Requirements.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Requirements.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Requirements.java index acc409f..f90e8f3 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Requirements.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Requirements.java @@ -53,4 +53,9 @@ public final class Requirements implements Serializable { public static Requirements empty() { return new Requirements(Collections.<PCollectionView<?>>emptyList()); } + + /** Whether this is an empty set of requirements. */ + public boolean isEmpty() { + return sideInputs.isEmpty(); + } } http://git-wip-us.apache.org/repos/asf/beam/blob/e2ad925d/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptors.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptors.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptors.java index 29a2496..e59f84b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptors.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptors.java @@ -23,7 +23,6 @@ import java.math.BigDecimal; import java.math.BigInteger; import java.util.List; import java.util.Set; -import javax.annotation.Nullable; import org.apache.beam.sdk.transforms.Contextful; import org.apache.beam.sdk.transforms.SerializableFunction; http://git-wip-us.apache.org/repos/asf/beam/blob/e2ad925d/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java index 11f284f..68ceafb 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java @@ -17,7 +17,10 @@ */ package org.apache.beam.sdk.transforms; +import static org.apache.beam.sdk.transforms.Contextful.fn; +import static org.apache.beam.sdk.transforms.Requirements.requiresSideInputs; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; +import static org.apache.beam.sdk.values.TypeDescriptors.integers; import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertThat; @@ -30,9 +33,11 @@ import java.util.Set; import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Contextful.Fn; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TypeDescriptor; import org.junit.Rule; import org.junit.Test; @@ -77,6 +82,32 @@ public class FlatMapElementsTest implements Serializable { } /** + * Basic test of {@link FlatMapElements} with a {@link Fn} and a side input. + */ + @Test + @Category(NeedsRunner.class) + public void testFlatMapBasicWithSideInput() throws Exception { + final PCollectionView<Integer> view = + pipeline.apply("Create base", Create.of(40)).apply(View.<Integer>asSingleton()); + PCollection<Integer> output = + pipeline + .apply(Create.of(0, 1, 2)) + .apply( + FlatMapElements.into(integers()).via(fn( + new Fn<Integer, Iterable<Integer>>() { + @Override + public List<Integer> apply(Integer input, Context c) { + return ImmutableList.of( + c.sideInput(view) - input, c.sideInput(view) + input); + } + }, + requiresSideInputs(view)))); + + PAssert.that(output).containsInAnyOrder(38, 39, 40, 40, 41, 42); + pipeline.run(); + } + + /** * Tests that when built with a concrete subclass of {@link SimpleFunction}, the type descriptor * of the output reflects its static type. */ @@ -144,7 +175,7 @@ public class FlatMapElementsTest implements Serializable { }; FlatMapElements<?, ?> simpleMap = FlatMapElements.via(simpleFn); - assertThat(DisplayData.from(simpleMap), hasDisplayItem("flatMapFn", simpleFn.getClass())); + assertThat(DisplayData.from(simpleMap), hasDisplayItem("class", simpleFn.getClass())); } @Test @@ -162,7 +193,7 @@ public class FlatMapElementsTest implements Serializable { }; FlatMapElements<?, ?> simpleFlatMap = FlatMapElements.via(simpleFn); - assertThat(DisplayData.from(simpleFlatMap), hasDisplayItem("flatMapFn", simpleFn.getClass())); + assertThat(DisplayData.from(simpleFlatMap), hasDisplayItem("class", simpleFn.getClass())); assertThat(DisplayData.from(simpleFlatMap), hasDisplayItem("foo", "baz")); } http://git-wip-us.apache.org/repos/asf/beam/blob/e2ad925d/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java index 241b60e..2c24f10 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java @@ -17,7 +17,10 @@ */ package org.apache.beam.sdk.transforms; +import static org.apache.beam.sdk.transforms.Contextful.fn; +import static org.apache.beam.sdk.transforms.Requirements.requiresSideInputs; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; +import static org.apache.beam.sdk.values.TypeDescriptors.integers; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; import static org.junit.Assert.assertThat; @@ -28,12 +31,13 @@ import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.ValidatesRunner; +import org.apache.beam.sdk.transforms.Contextful.Fn; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.display.DisplayDataEvaluator; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TypeDescriptor; -import org.apache.beam.sdk.values.TypeDescriptors; import org.junit.Rule; import org.junit.Test; import org.junit.experimental.categories.Category; @@ -96,6 +100,30 @@ public class MapElementsTest implements Serializable { } /** + * Basic test of {@link MapElements} with a {@link Fn} and a side input. + */ + @Test + @Category(NeedsRunner.class) + public void testMapBasicWithSideInput() throws Exception { + final PCollectionView<Integer> view = + pipeline.apply("Create base", Create.of(40)).apply(View.<Integer>asSingleton()); + PCollection<Integer> output = + pipeline + .apply(Create.of(0, 1, 2)) + .apply(MapElements.into(integers()) + .via(fn(new Fn<Integer, Integer>() { + @Override + public Integer apply(Integer element, Context c) { + return element + c.sideInput(view); + } + }, + requiresSideInputs(view)))); + + PAssert.that(output).containsInAnyOrder(40, 41, 42); + pipeline.run(); + } + + /** * Basic test of {@link MapElements} coder propagation with a parametric {@link SimpleFunction}. */ @Test @@ -157,7 +185,7 @@ public class MapElementsTest implements Serializable { pipeline .apply(Create.of(1, 2, 3)) .apply( - MapElements.into(TypeDescriptors.integers()) + MapElements.into(integers()) .via( new SerializableFunction<Integer, Integer>() { @Override @@ -216,9 +244,9 @@ public class MapElementsTest implements Serializable { }; MapElements<?, ?> serializableMap = - MapElements.into(TypeDescriptors.integers()).via(serializableFn); + MapElements.into(integers()).via(serializableFn); assertThat(DisplayData.from(serializableMap), - hasDisplayItem("mapFn", serializableFn.getClass())); + hasDisplayItem("class", serializableFn.getClass())); } @Test @@ -231,7 +259,7 @@ public class MapElementsTest implements Serializable { }; MapElements<?, ?> simpleMap = MapElements.via(simpleFn); - assertThat(DisplayData.from(simpleMap), hasDisplayItem("mapFn", simpleFn.getClass())); + assertThat(DisplayData.from(simpleMap), hasDisplayItem("class", simpleFn.getClass())); } @Test public void testSimpleFunctionDisplayData() { @@ -250,7 +278,7 @@ public class MapElementsTest implements Serializable { MapElements<?, ?> simpleMap = MapElements.via(simpleFn); - assertThat(DisplayData.from(simpleMap), hasDisplayItem("mapFn", simpleFn.getClass())); + assertThat(DisplayData.from(simpleMap), hasDisplayItem("class", simpleFn.getClass())); assertThat(DisplayData.from(simpleMap), hasDisplayItem("foo", "baz")); } @@ -269,7 +297,7 @@ public class MapElementsTest implements Serializable { Set<DisplayData> displayData = evaluator.<Integer>displayDataForPrimitiveTransforms(map); assertThat("MapElements should include the mapFn in its primitive display data", - displayData, hasItem(hasDisplayItem("mapFn", mapFn.getClass()))); + displayData, hasItem(hasDisplayItem("class", mapFn.getClass()))); } static class VoidValues<K, V> http://git-wip-us.apache.org/repos/asf/beam/blob/e2ad925d/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinations.java ---------------------------------------------------------------------- diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinations.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinations.java index ecfc990..e351138 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinations.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinations.java @@ -164,6 +164,13 @@ public abstract class DynamicDestinations<T, DestinationT> implements Serializab DynamicDestinations.class, new TypeDescriptors.TypeVariableExtractor< DynamicDestinations<T, DestinationT>, DestinationT>() {}); - return registry.getCoder(descriptor); + try { + return registry.getCoder(descriptor); + } catch (CannotProvideCoderException e) { + throw new CannotProvideCoderException( + "Failed to infer coder for DestinationT from type " + + descriptor + ", please provide it explicitly by overriding getDestinationCoder()", + e); + } } }
