Include Additional PTransform inputs in Transform Nodes. Add the value of PTransform.getAdditionalInputs to the inputs of a TransformHierarchy node.
Fork the Node constructor to reduce nullability This slightly simplifies the constructor implementation(s). Update the DirectRunner to track main inputs instead of all inputs. Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/9336230d Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/9336230d Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/9336230d Branch: refs/heads/master Commit: 9336230d2a5c18bae89908bcd60db8ea96b7906d Parents: b633abe Author: Thomas Groh <[email protected]> Authored: Tue Apr 4 17:43:48 2017 -0700 Committer: Thomas Groh <[email protected]> Committed: Mon May 22 18:12:11 2017 -0700 ---------------------------------------------------------------------- .../apex/translation/TranslationContext.java | 4 +- .../core/construction/TransformInputs.java | 50 ++++++ .../core/construction/TransformInputsTest.java | 166 +++++++++++++++++++ .../beam/runners/direct/DirectGraphVisitor.java | 15 +- .../runners/direct/ParDoEvaluatorFactory.java | 9 +- ...littableProcessElementsEvaluatorFactory.java | 2 + .../direct/StatefulParDoEvaluatorFactory.java | 1 + .../beam/runners/direct/WatermarkManager.java | 17 +- .../beam/runners/direct/ParDoEvaluatorTest.java | 6 +- .../flink/FlinkBatchTranslationContext.java | 3 +- .../flink/FlinkStreamingTranslationContext.java | 3 +- .../spark/translation/EvaluationContext.java | 4 +- .../beam/sdk/runners/TransformHierarchy.java | 28 +++- 13 files changed, 280 insertions(+), 28 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/9336230d/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java index aff3863..94d13e1 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java @@ -34,6 +34,7 @@ import org.apache.beam.runners.apex.translation.utils.ApexStateInternals; import org.apache.beam.runners.apex.translation.utils.ApexStateInternals.ApexStateBackend; import org.apache.beam.runners.apex.translation.utils.ApexStreamTuple; import org.apache.beam.runners.apex.translation.utils.CoderAdapterStreamCodec; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; @@ -93,7 +94,8 @@ class TranslationContext { } public <InputT extends PValue> InputT getInput() { - return (InputT) Iterables.getOnlyElement(getCurrentTransform().getInputs().values()); + return (InputT) + Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(getCurrentTransform())); } public Map<TupleTag<?>, PValue> getOutputs() { http://git-wip-us.apache.org/repos/asf/beam/blob/9336230d/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformInputs.java ---------------------------------------------------------------------- diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformInputs.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformInputs.java new file mode 100644 index 0000000..2baf93a --- /dev/null +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformInputs.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.core.construction; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.common.collect.ImmutableList; +import java.util.Collection; +import java.util.Map; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; + +/** Utilities for extracting subsets of inputs from an {@link AppliedPTransform}. */ +public class TransformInputs { + /** + * Gets all inputs of the {@link AppliedPTransform} that are not returned by {@link + * PTransform#getAdditionalInputs()}. 
+ */ + public static Collection<PValue> nonAdditionalInputs(AppliedPTransform<?, ?, ?> application) { + ImmutableList.Builder<PValue> mainInputs = ImmutableList.builder(); + PTransform<?, ?> transform = application.getTransform(); + for (Map.Entry<TupleTag<?>, PValue> input : application.getInputs().entrySet()) { + if (!transform.getAdditionalInputs().containsKey(input.getKey())) { + mainInputs.add(input.getValue()); + } + } + checkArgument( + !mainInputs.build().isEmpty() || application.getInputs().isEmpty(), + "Expected at least one main input if any inputs exist"); + return mainInputs.build(); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/9336230d/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java ---------------------------------------------------------------------- diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java new file mode 100644 index 0000000..f5b2c11 --- /dev/null +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.core.construction; + +import static org.junit.Assert.assertThat; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link TransformInputs}. 
*/ +@RunWith(JUnit4.class) +public class TransformInputsTest { + @Rule public TestPipeline pipeline = TestPipeline.create().enableAbandonedNodeEnforcement(false); + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Test + public void nonAdditionalInputsWithNoInputSucceeds() { + AppliedPTransform<PInput, POutput, TestTransform> transform = + AppliedPTransform.of( + "input-free", + Collections.<TupleTag<?>, PValue>emptyMap(), + Collections.<TupleTag<?>, PValue>emptyMap(), + new TestTransform(), + pipeline); + + assertThat(TransformInputs.nonAdditionalInputs(transform), Matchers.<PValue>empty()); + } + + @Test + public void nonAdditionalInputsWithOneMainInputSucceeds() { + PCollection<Long> input = pipeline.apply(GenerateSequence.from(1L)); + AppliedPTransform<PInput, POutput, TestTransform> transform = + AppliedPTransform.of( + "input-single", + Collections.<TupleTag<?>, PValue>singletonMap(new TupleTag<Long>() {}, input), + Collections.<TupleTag<?>, PValue>emptyMap(), + new TestTransform(), + pipeline); + + assertThat( + TransformInputs.nonAdditionalInputs(transform), Matchers.<PValue>containsInAnyOrder(input)); + } + + @Test + public void nonAdditionalInputsWithMultipleNonAdditionalInputsSucceeds() { + Map<TupleTag<?>, PValue> allInputs = new HashMap<>(); + PCollection<Integer> mainInts = pipeline.apply("MainInput", Create.of(12, 3)); + allInputs.put(new TupleTag<Integer>() {}, mainInts); + PCollection<Void> voids = pipeline.apply("VoidInput", Create.empty(VoidCoder.of())); + allInputs.put(new TupleTag<Void>() {}, voids); + AppliedPTransform<PInput, POutput, TestTransform> transform = + AppliedPTransform.of( + "additional-free", + allInputs, + Collections.<TupleTag<?>, PValue>emptyMap(), + new TestTransform(), + pipeline); + + assertThat( + TransformInputs.nonAdditionalInputs(transform), + Matchers.<PValue>containsInAnyOrder(voids, mainInts)); + } + + @Test + public void nonAdditionalInputsWithAdditionalInputsSucceeds() { + Map<TupleTag<?>, 
PValue> additionalInputs = new HashMap<>(); + additionalInputs.put(new TupleTag<String>() {}, pipeline.apply(Create.of("1, 2", "3"))); + additionalInputs.put(new TupleTag<Long>() {}, pipeline.apply(GenerateSequence.from(3L))); + + Map<TupleTag<?>, PValue> allInputs = new HashMap<>(); + PCollection<Integer> mainInts = pipeline.apply("MainInput", Create.of(12, 3)); + allInputs.put(new TupleTag<Integer>() {}, mainInts); + PCollection<Void> voids = pipeline.apply("VoidInput", Create.empty(VoidCoder.of())); + allInputs.put( + new TupleTag<Void>() {}, voids); + allInputs.putAll(additionalInputs); + + AppliedPTransform<PInput, POutput, TestTransform> transform = + AppliedPTransform.of( + "additional", + allInputs, + Collections.<TupleTag<?>, PValue>emptyMap(), + new TestTransform(additionalInputs), + pipeline); + + assertThat( + TransformInputs.nonAdditionalInputs(transform), + Matchers.<PValue>containsInAnyOrder(mainInts, voids)); + } + + @Test + public void nonAdditionalInputsWithOnlyAdditionalInputsThrows() { + Map<TupleTag<?>, PValue> additionalInputs = new HashMap<>(); + additionalInputs.put(new TupleTag<String>() {}, pipeline.apply(Create.of("1, 2", "3"))); + additionalInputs.put(new TupleTag<Long>() {}, pipeline.apply(GenerateSequence.from(3L))); + + AppliedPTransform<PInput, POutput, TestTransform> transform = + AppliedPTransform.of( + "additional-only", + additionalInputs, + Collections.<TupleTag<?>, PValue>emptyMap(), + new TestTransform(additionalInputs), + pipeline); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("at least one"); + TransformInputs.nonAdditionalInputs(transform); + } + + private static class TestTransform extends PTransform<PInput, POutput> { + private final Map<TupleTag<?>, PValue> additionalInputs; + + private TestTransform() { + this(Collections.<TupleTag<?>, PValue>emptyMap()); + } + + private TestTransform(Map<TupleTag<?>, PValue> additionalInputs) { + this.additionalInputs = additionalInputs; + } + + @Override + 
public POutput expand(PInput input) { + return PDone.in(input.getPipeline()); + } + + @Override + public Map<TupleTag<?>, PValue> getAdditionalInputs() { + return additionalInputs; + } + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/9336230d/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java index 01204e3..ed4282b 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java @@ -21,10 +21,12 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ListMultimap; +import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.runners.AppliedPTransform; @@ -34,6 +36,8 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Tracks the {@link AppliedPTransform AppliedPTransforms} that consume each {@link PValue} in the @@ -41,6 +45,7 @@ import org.apache.beam.sdk.values.PValue; * input after the upstream transform has produced and committed output. 
*/ class DirectGraphVisitor extends PipelineVisitor.Defaults { + private static final Logger LOG = LoggerFactory.getLogger(DirectGraphVisitor.class); private Map<POutput, AppliedPTransform<?, ?, ?>> producers = new HashMap<>(); @@ -83,7 +88,15 @@ class DirectGraphVisitor extends PipelineVisitor.Defaults { if (node.getInputs().isEmpty()) { rootTransforms.add(appliedTransform); } else { - for (PValue value : node.getInputs().values()) { + Collection<PValue> mainInputs = + TransformInputs.nonAdditionalInputs(node.toAppliedPTransform(getPipeline())); + if (!mainInputs.containsAll(node.getInputs().values())) { + LOG.debug( + "Inputs reduced to {} from {} by removing additional inputs", + mainInputs, + node.getInputs().values()); + } + for (PValue value : mainInputs) { primitiveConsumers.put(value, appliedTransform); } } http://git-wip-us.apache.org/repos/asf/beam/blob/9336230d/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java index 74470bf..c52091e 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java @@ -20,7 +20,6 @@ package org.apache.beam.runners.direct; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; -import com.google.common.collect.Iterables; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -79,6 +78,7 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator (TransformEvaluator<T>) createEvaluator( (AppliedPTransform) application, + (PCollection<InputT>) 
inputBundle.getPCollection(), inputBundle.getKey(), doFn, transform.getSideInputs(), @@ -102,6 +102,7 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator @SuppressWarnings({"unchecked", "rawtypes"}) DoFnLifecycleManagerRemovingTransformEvaluator<InputT> createEvaluator( AppliedPTransform<PCollection<InputT>, PCollectionTuple, ?> application, + PCollection<InputT> mainInput, StructuralKey<?> inputBundleKey, DoFn<InputT, OutputT> doFn, List<PCollectionView<?>> sideInputs, @@ -120,6 +121,7 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator createParDoEvaluator( application, inputBundleKey, + mainInput, sideInputs, mainOutputTag, additionalOutputTags, @@ -132,6 +134,7 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator ParDoEvaluator<InputT> createParDoEvaluator( AppliedPTransform<PCollection<InputT>, PCollectionTuple, ?> application, StructuralKey<?> key, + PCollection<InputT> mainInput, List<PCollectionView<?>> sideInputs, TupleTag<OutputT> mainOutputTag, List<TupleTag<?>> additionalOutputTags, @@ -144,8 +147,7 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator evaluationContext, stepContext, application, - ((PCollection<InputT>) Iterables.getOnlyElement(application.getInputs().values())) - .getWindowingStrategy(), + mainInput.getWindowingStrategy(), fn, key, sideInputs, @@ -173,5 +175,4 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator } return pcs; } - } http://git-wip-us.apache.org/repos/asf/beam/blob/9336230d/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java index dc85d87..4e7f4db 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java @@ -116,6 +116,8 @@ class SplittableProcessElementsEvaluatorFactory< delegateFactory.createParDoEvaluator( application, inputBundle.getKey(), + (PCollection<KeyedWorkItem<String, ElementAndRestriction<InputT, RestrictionT>>>) + inputBundle.getPCollection(), transform.getSideInputs(), transform.getMainOutputTag(), transform.getAdditionalOutputTags().getAll(), http://git-wip-us.apache.org/repos/asf/beam/blob/9336230d/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java index 985c3be..e22edd1 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java @@ -117,6 +117,7 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo DoFnLifecycleManagerRemovingTransformEvaluator<KV<K, InputT>> delegateEvaluator = delegateFactory.createEvaluator( (AppliedPTransform) application, + (PCollection) inputBundle.getPCollection(), inputBundle.getKey(), doFn, application.getTransform().getUnderlyingParDo().getSideInputs(), http://git-wip-us.apache.org/repos/asf/beam/blob/9336230d/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java 
---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java index 4f1b831..b15b52e 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java @@ -823,10 +823,11 @@ class WatermarkManager { inputWmsBuilder.add(THE_END_OF_TIME); } for (PValue pvalue : inputs.values()) { - Watermark producerOutputWatermark = - getTransformWatermark(graph.getProducer(pvalue)) - .synchronizedProcessingOutputWatermark; - inputWmsBuilder.add(producerOutputWatermark); + if (graph.getPrimitiveConsumers(pvalue).contains(transform)) { + Watermark producerOutputWatermark = + getTransformWatermark(graph.getProducer(pvalue)).synchronizedProcessingOutputWatermark; + inputWmsBuilder.add(producerOutputWatermark); + } } return inputWmsBuilder.build(); } @@ -838,9 +839,11 @@ class WatermarkManager { inputWatermarksBuilder.add(THE_END_OF_TIME); } for (PValue pvalue : inputs.values()) { - Watermark producerOutputWatermark = - getTransformWatermark(graph.getProducer(pvalue)).outputWatermark; - inputWatermarksBuilder.add(producerOutputWatermark); + if (graph.getPrimitiveConsumers(pvalue).contains(transform)) { + Watermark producerOutputWatermark = + getTransformWatermark(graph.getProducer(pvalue)).outputWatermark; + inputWatermarksBuilder.add(producerOutputWatermark); + } } List<Watermark> inputCollectionWatermarks = inputWatermarksBuilder.build(); return inputCollectionWatermarks; http://git-wip-us.apache.org/repos/asf/beam/blob/9336230d/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java ---------------------------------------------------------------------- diff --git 
a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java index 286e44d..3b2a22e 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java @@ -98,7 +98,7 @@ public class ParDoEvaluatorTest { when(evaluationContext.createBundle(output)).thenReturn(outputBundle); ParDoEvaluator<Integer> evaluator = - createEvaluator(singletonView, fn, output); + createEvaluator(singletonView, fn, inputPc, output); IntervalWindow nonGlobalWindow = new IntervalWindow(new Instant(0), new Instant(10_000L)); WindowedValue<Integer> first = WindowedValue.valueInGlobalWindow(3); @@ -132,6 +132,7 @@ public class ParDoEvaluatorTest { private ParDoEvaluator<Integer> createEvaluator( PCollectionView<Integer> singletonView, RecorderFn fn, + PCollection<Integer> input, PCollection<Integer> output) { when( evaluationContext.createSideInputReader( @@ -156,8 +157,7 @@ public class ParDoEvaluatorTest { evaluationContext, stepContext, transform, - ((PCollection<?>) Iterables.getOnlyElement(transform.getInputs().values())) - .getWindowingStrategy(), + input.getWindowingStrategy(), fn, null /* key */, ImmutableList.<PCollectionView<?>>of(singletonView), http://git-wip-us.apache.org/repos/asf/beam/blob/9336230d/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java index 0439119..6e70198 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java +++ 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java @@ -20,6 +20,7 @@ package org.apache.beam.runners.flink; import com.google.common.collect.Iterables; import java.util.HashMap; import java.util.Map; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; @@ -143,7 +144,7 @@ class FlinkBatchTranslationContext { @SuppressWarnings("unchecked") <T extends PValue> T getInput(PTransform<T, ?> transform) { - return (T) Iterables.getOnlyElement(currentTransform.getInputs().values()); + return (T) Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(currentTransform)); } Map<TupleTag<?>, PValue> getOutputs(PTransform<?, ?> transform) { http://git-wip-us.apache.org/repos/asf/beam/blob/9336230d/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java index ea5f6b3..74a5fb9 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java @@ -22,6 +22,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.collect.Iterables; import java.util.HashMap; import java.util.Map; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; @@ -113,7 +114,7 @@ class 
FlinkStreamingTranslationContext { @SuppressWarnings("unchecked") public <T extends PValue> T getInput(PTransform<T, ?> transform) { - return (T) Iterables.getOnlyElement(currentTransform.getInputs().values()); + return (T) Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(currentTransform)); } public <T extends PInput> Map<TupleTag<?>, PValue> getInputs(PTransform<T, ?> transform) { http://git-wip-us.apache.org/repos/asf/beam/blob/9336230d/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java index 8102926..0c6c4d1 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java @@ -26,6 +26,7 @@ import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.spark.SparkPipelineOptions; import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.sdk.Pipeline; @@ -103,7 +104,8 @@ public class EvaluationContext { public <T extends PValue> T getInput(PTransform<T, ?> transform) { @SuppressWarnings("unchecked") - T input = (T) Iterables.getOnlyElement(getInputs(transform).values()); + T input = + (T) Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(getCurrentTransform())); return input; } http://git-wip-us.apache.org/repos/asf/beam/blob/9336230d/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java ---------------------------------------------------------------------- diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java index 2f0e8ef..9d73b45 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java @@ -32,7 +32,6 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set; -import javax.annotation.Nullable; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior; @@ -68,7 +67,7 @@ public class TransformHierarchy { producers = new HashMap<>(); producerInput = new HashMap<>(); unexpandedInputs = new HashMap<>(); - root = new Node(null, null, "", null); + root = new Node(); current = root; } @@ -253,25 +252,36 @@ public class TransformHierarchy { boolean finishedSpecifying = false; /** + * Creates the root-level node. The root level node has a null enclosing node, a null transform, + * an empty map of inputs, and a name equal to the empty string. + */ + private Node() { + this.enclosingNode = null; + this.transform = null; + this.fullName = ""; + this.inputs = Collections.emptyMap(); + } + + /** * Creates a new Node with the given parent and transform. * - * <p>EnclosingNode and transform may both be null for a root-level node, which holds all other - * nodes. 
- * * @param enclosingNode the composite node containing this node * @param transform the PTransform tracked by this node * @param fullName the fully qualified name of the transform * @param input the unexpanded input to the transform */ private Node( - @Nullable Node enclosingNode, - @Nullable PTransform<?, ?> transform, + Node enclosingNode, + PTransform<?, ?> transform, String fullName, - @Nullable PInput input) { + PInput input) { this.enclosingNode = enclosingNode; this.transform = transform; this.fullName = fullName; - this.inputs = input == null ? Collections.<TupleTag<?>, PValue>emptyMap() : input.expand(); + ImmutableMap.Builder<TupleTag<?>, PValue> inputs = ImmutableMap.builder(); + inputs.putAll(input.expand()); + inputs.putAll(transform.getAdditionalInputs()); + this.inputs = inputs.build(); } /**
