This is an automated email from the ASF dual-hosted git repository. lcwik pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new c5c0cce [BEAM-2963] Remove layer of indirection in output name mapping in Dataflow simplifying what needs to be passed for all portable pipelines to the Java SDK harness. (#4460) c5c0cce is described below commit c5c0ccea952f6d200e4b8d8409542d2c6bb7aa50 Author: Lukasz Cwik <lc...@google.com> AuthorDate: Mon Jan 22 14:59:35 2018 -0800 [BEAM-2963] Remove layer of indirection in output name mapping in Dataflow simplifying what needs to be passed for all portable pipelines to the Java SDK harness. (#4460) --- runners/google-cloud-dataflow-java/build.gradle | 2 +- runners/google-cloud-dataflow-java/pom.xml | 2 +- .../dataflow/DataflowPipelineTranslator.java | 65 +++++++++------------- .../beam/runners/dataflow/DataflowRunner.java | 6 +- .../beam/runners/dataflow/ReadTranslator.java | 2 +- .../beam/runners/dataflow/TransformTranslator.java | 17 +++--- .../dataflow/DataflowPipelineTranslatorTest.java | 6 +- .../beam/runners/dataflow/DataflowRunnerTest.java | 7 +-- .../java/org/apache/beam/sdk/util/DoFnInfo.java | 22 ++------ .../apache/beam/fn/harness/FnApiDoFnRunner.java | 64 ++++++++------------- .../beam/fn/harness/FnApiDoFnRunnerTest.java | 18 ++---- 11 files changed, 81 insertions(+), 130 deletions(-) diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle index dc9e802..1656733 100644 --- a/runners/google-cloud-dataflow-java/build.gradle +++ b/runners/google-cloud-dataflow-java/build.gradle @@ -36,7 +36,7 @@ processResources { filter org.apache.tools.ant.filters.ReplaceTokens, tokens: [ 'dataflow.legacy_environment_major_version' : '6', 'dataflow.fnapi_environment_major_version' : '1', - 'dataflow.container_version' : 'beam-master-20180117' + 'dataflow.container_version' : 'beam-master-20180122' ] } diff --git a/runners/google-cloud-dataflow-java/pom.xml b/runners/google-cloud-dataflow-java/pom.xml index 952a140..8006ecc 100644 --- a/runners/google-cloud-dataflow-java/pom.xml +++ b/runners/google-cloud-dataflow-java/pom.xml @@ -33,7 +33,7 @@ <packaging>jar</packaging> <properties> - <dataflow.container_version>beam-master-20180117</dataflow.container_version> + <dataflow.container_version>beam-master-20180122</dataflow.container_version> <dataflow.fnapi_environment_major_version>1</dataflow.fnapi_environment_major_version> <dataflow.legacy_environment_major_version>6</dataflow.legacy_environment_major_version> </properties> diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index ce25745..9ff8a45 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -43,9 +43,6 @@ import com.google.api.services.dataflow.model.Step; import com.google.api.services.dataflow.model.WorkerPool; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Supplier; -import com.google.common.collect.BiMap; -import com.google.common.collect.ImmutableBiMap; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.protobuf.TextFormat; import java.util.ArrayList; @@ -606,18 +603,18 @@ public class DataflowPipelineTranslator { } @Override - public long addOutput(PCollection<?> value) { + public void addOutput(String name, PCollection<?> value) { translator.producers.put(value, translator.currentTransform); // Wrap the PCollection element Coder inside a WindowedValueCoder. Coder<?> coder = WindowedValue.getFullCoder( value.getCoder(), value.getWindowingStrategy().getWindowFn().windowCoder()); - return addOutput(value, coder); + addOutput(name, value, coder); } @Override - public long addCollectionToSingletonOutput( - PCollection<?> inputValue, PCollectionView<?> outputValue) { + public void addCollectionToSingletonOutput( + PCollection<?> inputValue, String outputName, PCollectionView<?> outputValue) { translator.producers.put(outputValue, translator.currentTransform); Coder<?> inputValueCoder = checkNotNull(translator.outputCoders.get(inputValue)); @@ -630,7 +627,7 @@ public class DataflowPipelineTranslator { // IterableCoder of the inputValueCoder. This is a property // of the backend "CollectionToSingleton" step. Coder<?> outputValueCoder = IterableCoder.of(inputValueCoder); - return addOutput(outputValue, outputValueCoder); + addOutput(outputName, outputValue, outputValueCoder); } /** @@ -638,9 +635,8 @@ public class DataflowPipelineTranslator { * Dataflow step, producing the specified output {@code PValue} * with the given {@code Coder} (if not {@code null}). */ - private long addOutput(PValue value, Coder<?> valueCoder) { - long id = translator.idGenerator.get(); - translator.registerOutputName(value, Long.toString(id)); + private void addOutput(String name, PValue value, Coder<?> valueCoder) { + translator.registerOutputName(value, name); Map<String, Object> properties = getProperties(); @Nullable List<Map<String, Object>> outputInfoList = null; @@ -657,7 +653,7 @@ public class DataflowPipelineTranslator { } Map<String, Object> outputInfo = new HashMap<>(); - addString(outputInfo, PropertyNames.OUTPUT_NAME, Long.toString(id)); + addString(outputInfo, PropertyNames.OUTPUT_NAME, name); String stepName = getString(properties, PropertyNames.USER_NAME); String generatedName = String.format( @@ -677,7 +673,6 @@ public class DataflowPipelineTranslator { } outputInfoList.add(outputInfo); - return id; } private void addDisplayData(Step step, String stepName, HasDisplayData hasDisplayData) { @@ -721,7 +716,8 @@ public class DataflowPipelineTranslator { context.addStep(transform, "CollectionToSingleton"); PCollection<ElemT> input = context.getInput(transform); stepContext.addInput(PropertyNames.PARALLEL_INPUT, input); - stepContext.addCollectionToSingletonOutput(input, transform.getView()); + stepContext.addCollectionToSingletonOutput( + input, PropertyNames.OUTPUT, transform.getView()); } }); @@ -739,7 +735,8 @@ public class DataflowPipelineTranslator { context.addStep(transform, "CollectionToSingleton"); PCollection<ElemT> input = context.getInput(transform); stepContext.addInput(PropertyNames.PARALLEL_INPUT, input); - stepContext.addCollectionToSingletonOutput(input, transform.getView()); + stepContext.addCollectionToSingletonOutput( + input, PropertyNames.OUTPUT, transform.getView()); } }); @@ -773,7 +770,7 @@ public class DataflowPipelineTranslator { stepContext.addEncodingInput(fn.getAccumulatorCoder()); stepContext.addInput( PropertyNames.SERIALIZED_FN, byteArrayToJsonString(serializeToByteArray(fn))); - stepContext.addOutput(context.getOutput(primitiveTransform)); + stepContext.addOutput(PropertyNames.OUTPUT, context.getOutput(primitiveTransform)); } }); @@ -797,7 +794,7 @@ public class DataflowPipelineTranslator { input, context.getProducer(input))); } stepContext.addInput(PropertyNames.INPUTS, inputs); - stepContext.addOutput(context.getOutput(transform)); + stepContext.addOutput(PropertyNames.OUTPUT, context.getOutput(transform)); } }); @@ -814,7 +811,7 @@ public class DataflowPipelineTranslator { StepTranslationContext stepContext = context.addStep(transform, "GroupByKey"); PCollection<KV<K1, KV<K2, V>>> input = context.getInput(transform); stepContext.addInput(PropertyNames.PARALLEL_INPUT, input); - stepContext.addOutput(context.getOutput(transform)); + stepContext.addOutput(PropertyNames.OUTPUT, context.getOutput(transform)); stepContext.addInput(PropertyNames.SORT_VALUES, true); // TODO: Add support for combiner lifting once the need arises. @@ -835,7 +832,7 @@ public class DataflowPipelineTranslator { StepTranslationContext stepContext = context.addStep(transform, "GroupByKey"); PCollection<KV<K, V>> input = context.getInput(transform); stepContext.addInput(PropertyNames.PARALLEL_INPUT, input); - stepContext.addOutput(context.getOutput(transform)); + stepContext.addOutput(PropertyNames.OUTPUT, context.getOutput(transform)); WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy(); boolean isStreaming = @@ -870,7 +867,6 @@ public class DataflowPipelineTranslator { StepTranslationContext stepContext = context.addStep(transform, "ParallelDo"); translateInputs( stepContext, context.getInput(transform), transform.getSideInputs(), context); - BiMap<Long, TupleTag<?>> outputMap = translateOutputs(context.getOutputs(transform), stepContext); String ptransformId = context.getSdkComponents().getPTransformIdOrThrow(context.getCurrentTransform()); @@ -882,8 +878,7 @@ public class DataflowPipelineTranslator { transform.getSideInputs(), context.getInput(transform).getCoder(), context, - outputMap.inverse().get(transform.getMainOutputTag()), - outputMap); + transform.getMainOutputTag()); } }); @@ -901,7 +896,8 @@ public class DataflowPipelineTranslator { StepTranslationContext stepContext = context.addStep(transform, "ParallelDo"); translateInputs( stepContext, context.getInput(transform), transform.getSideInputs(), context); - long mainOutput = stepContext.addOutput(context.getOutput(transform)); + stepContext.addOutput( + transform.getMainOutputTag().getId(), context.getOutput(transform)); String ptransformId = context.getSdkComponents().getPTransformIdOrThrow(context.getCurrentTransform()); translateFn( @@ -912,9 +908,7 @@ public class DataflowPipelineTranslator { transform.getSideInputs(), context.getInput(transform).getCoder(), context, - mainOutput, - ImmutableMap.<Long, TupleTag<?>>of( - mainOutput, new TupleTag<>(PropertyNames.OUTPUT))); + transform.getMainOutputTag()); } }); @@ -930,7 +924,7 @@ public class DataflowPipelineTranslator { StepTranslationContext stepContext = context.addStep(transform, "Bucket"); PCollection<T> input = context.getInput(transform); stepContext.addInput(PropertyNames.PARALLEL_INPUT, input); - stepContext.addOutput(context.getOutput(transform)); + stepContext.addOutput(PropertyNames.OUTPUT, context.getOutput(transform)); WindowingStrategy<?, ?> strategy = context.getOutput(transform).getWindowingStrategy(); byte[] serializedBytes = serializeWindowingStrategy(strategy); @@ -964,7 +958,6 @@ public class DataflowPipelineTranslator { translateInputs( stepContext, context.getInput(transform), transform.getSideInputs(), context); - BiMap<Long, TupleTag<?>> outputMap = translateOutputs(context.getOutputs(transform), stepContext); stepContext.addInput( PropertyNames.SERIALIZED_FN, @@ -975,8 +968,7 @@ public class DataflowPipelineTranslator { transform.getInputWindowingStrategy(), transform.getSideInputs(), transform.getElementCoder(), - outputMap.inverse().get(transform.getMainOutputTag()), - outputMap)))); + transform.getMainOutputTag())))); stepContext.addInput( PropertyNames.RESTRICTION_CODER, CloudObjects.asCloudObject(transform.getRestrictionCoder())); @@ -1017,8 +1009,7 @@ public class DataflowPipelineTranslator { Iterable<PCollectionView<?>> sideInputs, Coder inputCoder, TranslationContext context, - long mainOutput, - Map<Long, TupleTag<?>> outputMap) { + TupleTag<?> mainOutput) { DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass()); if (signature.processElement().isSplittable()) { @@ -1053,8 +1044,7 @@ public class DataflowPipelineTranslator { windowingStrategy, sideInputs, inputCoder, - mainOutput, - outputMap)))); + mainOutput)))); } // Setting USES_KEYED_STATE will cause an ungrouped shuffle, which works @@ -1065,19 +1055,16 @@ public class DataflowPipelineTranslator { } } - private static BiMap<Long, TupleTag<?>> translateOutputs( + private static void translateOutputs( Map<TupleTag<?>, PValue> outputs, StepTranslationContext stepContext) { - ImmutableBiMap.Builder<Long, TupleTag<?>> mapBuilder = ImmutableBiMap.builder(); for (Map.Entry<TupleTag<?>, PValue> taggedOutput : outputs.entrySet()) { TupleTag<?> tag = taggedOutput.getKey(); checkArgument(taggedOutput.getValue() instanceof PCollection, "Non %s returned from Multi-output %s", PCollection.class.getSimpleName(), stepContext); - PCollection<?> output = (PCollection<?>) taggedOutput.getValue(); - mapBuilder.put(stepContext.addOutput(output), tag); + stepContext.addOutput(tag.getId(), (PCollection<?>) taggedOutput.getValue()); } - return mapBuilder.build(); } } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index 7693392..7eeeffc 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -1084,7 +1084,7 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { byteArrayToJsonString( serializeToByteArray(new IdentityMessageFn()))); } - stepContext.addOutput(context.getOutput(transform)); + stepContext.addOutput(PropertyNames.OUTPUT, context.getOutput(transform)); } } @@ -1298,7 +1298,7 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { StepTranslationContext stepContext = context.addStep(transform, "ParallelRead"); stepContext.addInput(PropertyNames.FORMAT, "pubsub"); stepContext.addInput(PropertyNames.PUBSUB_SUBSCRIPTION, "_starting_signal/"); - stepContext.addOutput(context.getOutput(transform)); + stepContext.addOutput(PropertyNames.OUTPUT, context.getOutput(transform)); } else { StepTranslationContext stepContext = context.addStep(transform, "ParallelRead"); stepContext.addInput(PropertyNames.FORMAT, "impulse"); @@ -1314,7 +1314,7 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { } stepContext.addInput( PropertyNames.IMPULSE_ELEMENT, byteArrayToJsonString(encodedImpulse)); - stepContext.addOutput(context.getOutput(transform)); + stepContext.addOutput(PropertyNames.OUTPUT, context.getOutput(transform)); } } } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReadTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReadTranslator.java index 693748a..994ced8 100755 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReadTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReadTranslator.java @@ -51,7 +51,7 @@ class ReadTranslator implements TransformTranslator<Read.Bounded<?>> { PropertyNames.SOURCE_STEP_INPUT, cloudSourceToDictionary( CustomSources.serializeToCloudSource(source, context.getPipelineOptions()))); - stepContext.addOutput(context.getOutput(transform)); + stepContext.addOutput(PropertyNames.OUTPUT, context.getOutput(transform)); } catch (Exception e) { throw new RuntimeException(e); } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java index 7fec67a..3ea97b5 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java @@ -102,8 +102,8 @@ public interface TransformTranslator<TransformT extends PTransform> { * * <p>The input {@link PValue} must have already been produced by a step earlier in this * {@link Pipeline}. If the input value has not yet been produced yet (by a call to either - * {@link StepTranslationContext#addOutput(PCollection)} or - * {@link StepTranslationContext#addCollectionToSingletonOutput(PCollection, PCollectionView)}) + * {@link StepTranslationContext#addOutput} or + * {@link StepTranslationContext#addCollectionToSingletonOutput}) * this method will throw an exception. */ void addInput(String name, PInput value); @@ -115,18 +115,19 @@ public interface TransformTranslator<TransformT extends PTransform> { void addInput(String name, List<? extends Map<String, Object>> elements); /** - * Adds a primitive output to this Dataflow step, producing the specified output {@code PValue}, - * including its {@code Coder} if a {@code TypedPValue}. If the {@code PValue} is a {@code - * PCollection}, wraps its coder inside a {@code WindowedValueCoder}. Returns a pipeline level - * unique id. + * Adds a primitive output to this Dataflow step with the given name as the local output name, + * producing the specified output {@code PValue}, including its {@code Coder} if a + * {@code TypedPValue}. If the {@code PValue} is a {@code PCollection}, wraps its coder + * inside a {@code WindowedValueCoder}. */ - long addOutput(PCollection<?> value); + void addOutput(String name, PCollection<?> value); /** * Adds an output to this {@code CollectionToSingleton} Dataflow step, consuming the specified * input {@code PValue} and producing the specified output {@code PValue}. This step requires * special treatment for its output encoding. Returns a pipeline level unique id. */ - long addCollectionToSingletonOutput(PCollection<?> inputValue, PCollectionView<?> outputValue); + void addCollectionToSingletonOutput(PCollection<?> inputValue, + String outputName, PCollectionView<?> outputValue); } } diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java index 181edd4..3e3bf0d 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java @@ -1030,17 +1030,17 @@ public class DataflowPipelineTranslatorTest implements Serializable { private static void assertAllStepOutputsHaveUniqueIds(Job job) throws Exception { - List<Long> outputIds = new ArrayList<>(); + List<String> outputIds = new ArrayList<>(); for (Step step : job.getSteps()) { List<Map<String, Object>> outputInfoList = (List<Map<String, Object>>) step.getProperties().get(PropertyNames.OUTPUT_INFO); if (outputInfoList != null) { for (Map<String, Object> outputInfo : outputInfoList) { - outputIds.add(Long.parseLong(Structs.getString(outputInfo, PropertyNames.OUTPUT_NAME))); + outputIds.add(Structs.getString(outputInfo, PropertyNames.OUTPUT_NAME)); } } } - Set<Long> uniqueOutputNames = new HashSet<>(outputIds); + Set<String> uniqueOutputNames = new HashSet<>(outputIds); outputIds.removeAll(uniqueOutputNames); assertTrue(String.format("Found duplicate output ids %s", outputIds), outputIds.size() == 0); diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java index 7fc7aa7..23eddca 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java @@ -83,6 +83,7 @@ import org.apache.beam.runners.dataflow.DataflowRunner.StreamingShardedWriteFact import org.apache.beam.runners.dataflow.options.DataflowPipelineDebugOptions; import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; import org.apache.beam.runners.dataflow.options.DataflowPipelineWorkerPoolOptions; +import org.apache.beam.runners.dataflow.util.PropertyNames; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; @@ -1149,15 +1150,13 @@ public class DataflowRunnerTest implements Serializable { new TransformTranslator<TestTransform>() { @SuppressWarnings("unchecked") @Override - public void translate( - TestTransform transform, - TranslationContext context) { + public void translate(TestTransform transform, TranslationContext context) { transform.translated = true; // Note: This is about the minimum needed to fake out a // translation. This obviously isn't a real translation. StepTranslationContext stepContext = context.addStep(transform, "TestTranslate"); - stepContext.addOutput(context.getOutput(transform)); + stepContext.addOutput(PropertyNames.OUTPUT, context.getOutput(transform)); } }); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/DoFnInfo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/DoFnInfo.java index 0800b21..f0c5a26 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/DoFnInfo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/DoFnInfo.java @@ -18,7 +18,6 @@ package org.apache.beam.sdk.util; import java.io.Serializable; -import java.util.Map; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.values.PCollectionView; @@ -36,8 +35,7 @@ public class DoFnInfo<InputT, OutputT> implements Serializable { private final WindowingStrategy<?, ?> windowingStrategy; private final Iterable<PCollectionView<?>> sideInputViews; private final Coder<InputT> inputCoder; - private final long mainOutput; - private final Map<Long, TupleTag<?>> outputMap; + private final TupleTag<OutputT> mainOutput; /** * Creates a {@link DoFnInfo} for the given {@link DoFn}. @@ -47,10 +45,9 @@ public class DoFnInfo<InputT, OutputT> implements Serializable { WindowingStrategy<?, ?> windowingStrategy, Iterable<PCollectionView<?>> sideInputViews, Coder<InputT> inputCoder, - long mainOutput, - Map<Long, TupleTag<?>> outputMap) { + TupleTag<OutputT> mainOutput) { return new DoFnInfo<>( - doFn, windowingStrategy, sideInputViews, inputCoder, mainOutput, outputMap); + doFn, windowingStrategy, sideInputViews, inputCoder, mainOutput); } public DoFnInfo<InputT, OutputT> withFn(DoFn<InputT, OutputT> newFn) { @@ -58,8 +55,7 @@ public class DoFnInfo<InputT, OutputT> implements Serializable { windowingStrategy, sideInputViews, inputCoder, - mainOutput, - outputMap); + mainOutput); } private DoFnInfo( @@ -67,14 +63,12 @@ public class DoFnInfo<InputT, OutputT> implements Serializable { WindowingStrategy<?, ?> windowingStrategy, Iterable<PCollectionView<?>> sideInputViews, Coder<InputT> inputCoder, - long mainOutput, - Map<Long, TupleTag<?>> outputMap) { + TupleTag<OutputT> mainOutput) { this.doFn = doFn; this.windowingStrategy = windowingStrategy; this.sideInputViews = sideInputViews; this.inputCoder = inputCoder; this.mainOutput = mainOutput; - this.outputMap = outputMap; } /** Returns the embedded function. */ @@ -94,11 +88,7 @@ public class DoFnInfo<InputT, OutputT> implements Serializable { return inputCoder; } - public long getMainOutput() { + public TupleTag<OutputT> getMainOutput() { return mainOutput; } - - public Map<Long, TupleTag<?>> getOutputMap() { - return outputMap; - } } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java index 1d93eed..d983082 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java @@ -17,17 +17,16 @@ */ package org.apache.beam.fn.harness; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import com.google.auto.service.AutoService; import com.google.common.base.Suppliers; -import com.google.common.collect.Collections2; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Iterables; +import com.google.common.collect.ListMultimap; import com.google.common.collect.Multimap; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; @@ -35,10 +34,8 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; -import java.util.HashSet; import java.util.Iterator; import java.util.Map; -import java.util.Objects; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; @@ -94,6 +91,7 @@ import org.apache.beam.sdk.util.DoFnInfo; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; @@ -141,15 +139,15 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp Consumer<ThrowingRunnable> addFinishFunction) { // For every output PCollection, create a map from output name to Consumer - ImmutableMap.Builder<String, Collection<FnDataReceiver<WindowedValue<?>>>> - outputMapBuilder = ImmutableMap.builder(); + ImmutableListMultimap.Builder<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> + tagToOutputMapBuilder = ImmutableListMultimap.builder(); for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) { - outputMapBuilder.put( - entry.getKey(), + tagToOutputMapBuilder.putAll( + new TupleTag<>(entry.getKey()), pCollectionIdsToConsumers.get(entry.getValue())); } - ImmutableMap<String, Collection<FnDataReceiver<WindowedValue<?>>>> outputMap = - outputMapBuilder.build(); + ListMultimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> tagToOutputMap = + tagToOutputMapBuilder.build(); // Get the DoFnInfo from the serialized blob. ByteString serializedFn = pTransform.getSpec().getPayload(); @@ -157,30 +155,8 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp DoFnInfo<InputT, OutputT> doFnInfo = (DoFnInfo) SerializableUtils.deserializeFromByteArray( serializedFn.toByteArray(), "DoFnInfo"); - // Verify that the DoFnInfo tag to output map matches the output map on the PTransform. - checkArgument( - Objects.equals( - new HashSet<>(Collections2.transform(outputMap.keySet(), Long::parseLong)), - doFnInfo.getOutputMap().keySet()), - "Unexpected mismatch between transform output map %s and DoFnInfo output map %s.", - outputMap.keySet(), - doFnInfo.getOutputMap()); - - ImmutableMultimap.Builder<TupleTag<?>, - FnDataReceiver<WindowedValue<?>>> tagToOutputMapBuilder = - ImmutableMultimap.builder(); - for (Map.Entry<Long, TupleTag<?>> entry : doFnInfo.getOutputMap().entrySet()) { @SuppressWarnings({"unchecked", "rawtypes"}) - Collection<FnDataReceiver<WindowedValue<?>>> consumers = - outputMap.get(Long.toString(entry.getKey())); - tagToOutputMapBuilder.putAll(entry.getValue(), consumers); - } - - ImmutableMultimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> tagToOutputMap = - tagToOutputMapBuilder.build(); - - @SuppressWarnings({"unchecked", "rawtypes"}) - DoFnRunner<InputT, OutputT> runner = new FnApiDoFnRunner<>( + DoFnRunner<InputT, OutputT> runner = new FnApiDoFnRunner<InputT, OutputT>( pipelineOptions, beamFnStateClient, pTransformId, @@ -188,7 +164,7 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp doFnInfo.getDoFn(), doFnInfo.getInputCoder(), (Collection<FnDataReceiver<WindowedValue<OutputT>>>) (Collection) - tagToOutputMap.get(doFnInfo.getOutputMap().get(doFnInfo.getMainOutput())), + tagToOutputMap.get(doFnInfo.getMainOutput()), tagToOutputMap, doFnInfo.getWindowingStrategy()); @@ -244,13 +220,13 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp throw new IllegalArgumentException("Malformed ParDoPayload", exn); } - ImmutableMultimap.Builder<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> - tagToConsumerBuilder = ImmutableMultimap.builder(); + ImmutableListMultimap.Builder<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> + tagToConsumerBuilder = ImmutableListMultimap.builder(); for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) { tagToConsumerBuilder.putAll( new TupleTag<>(entry.getKey()), pCollectionIdsToConsumers.get(entry.getValue())); } - Multimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> tagToConsumer = + ListMultimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> tagToConsumer = tagToConsumerBuilder.build(); @SuppressWarnings({"unchecked", "rawtypes"}) @@ -980,12 +956,18 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp "Accessing state in unkeyed context. Current element is not a KV: %s.", currentElement); checkState( - inputCoder instanceof KvCoder, - "Accessing state in unkeyed context. No keyed coder found."); + // TODO: Stop passing windowed value coders within PCollections. + inputCoder instanceof KvCoder + || (inputCoder instanceof WindowedValueCoder + && (((WindowedValueCoder) inputCoder).getValueCoder() instanceof KvCoder)), + "Accessing state in unkeyed context. Keyed coder expected but found %s.", + inputCoder); ByteString.Output encodedKeyOut = ByteString.newOutput(); - Coder<K> keyCoder = ((KvCoder<K, ?>) inputCoder).getKeyCoder(); + Coder<K> keyCoder = inputCoder instanceof WindowedValueCoder + ? ((KvCoder<K, ?>) ((WindowedValueCoder) inputCoder).getValueCoder()).getKeyCoder() + : ((KvCoder<K, ?>) inputCoder).getKeyCoder(); try { keyCoder.encode(((KV<K, ?>) currentElement.getValue()).getKey(), encodedKeyOut); } catch (IOException e) { diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java index 5b17bf3..70aca2e 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java @@ -108,18 +108,13 @@ public class FnApiDoFnRunnerTest { @Test public void testCreatingAndProcessingDoFn() throws Exception { String pTransformId = "pTransformId"; - String mainOutputId = "101"; - String additionalOutputId = "102"; DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn( new TestDoFn(), WindowingStrategy.globalDefault(), ImmutableList.of(), StringUtf8Coder.of(), - Long.parseLong(mainOutputId), - ImmutableMap.of( - Long.parseLong(mainOutputId), TestDoFn.mainOutput, - Long.parseLong(additionalOutputId), TestDoFn.additionalOutput)); + TestDoFn.mainOutput); RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN) @@ -129,8 +124,8 @@ public class FnApiDoFnRunnerTest { .setSpec(functionSpec) .putInputs("inputA", "inputATarget") .putInputs("inputB", "inputBTarget") - .putOutputs(mainOutputId, "mainOutputTarget") - .putOutputs(additionalOutputId, "additionalOutputTarget") + .putOutputs(TestDoFn.mainOutput.getId(), "mainOutputTarget") + .putOutputs(TestDoFn.additionalOutput.getId(), "additionalOutputTarget") .build(); List<WindowedValue<String>> mainOutputValues = new ArrayList<>(); @@ -278,15 +273,12 @@ public class FnApiDoFnRunnerTest { @Test public void testUsingUserState() throws Exception { - String mainOutputId = "101"; - DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn( new TestStatefulDoFn(), WindowingStrategy.globalDefault(), ImmutableList.of(), KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()), - Long.parseLong(mainOutputId), - ImmutableMap.of(Long.parseLong(mainOutputId), new TupleTag<String>("mainOutput"))); + new TupleTag<>("mainOutput")); RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN) @@ -295,7 +287,7 @@ public class FnApiDoFnRunnerTest { RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() .setSpec(functionSpec) .putInputs("input", "inputTarget") - .putOutputs(mainOutputId, "mainOutputTarget") + .putOutputs("mainOutput", "mainOutputTarget") .build(); FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of( -- To stop receiving notification emails like this one, please contact lc...@apache.org.