Repository: incubator-beam Updated Branches: refs/heads/master 8cb2689f8 -> 1abbb9007
Stop using Maps of Transforms in the DirectRunner. Instead, add a "DirectGraph" class, which adds a layer of indirection to all lookup methods. Remove all remaining uses of getProducingTransformInternal, and instead use DirectGraph methods to obtain the producing transform. Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/8162cd29 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/8162cd29 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/8162cd29 Branch: refs/heads/master Commit: 8162cd29d97ef307b6fac588f453e4e39d70fca7 Parents: 8cb2689 Author: Thomas Groh <tg...@google.com> Authored: Thu Dec 1 15:39:30 2016 -0800 Committer: Thomas Groh <tg...@google.com> Committed: Fri Dec 2 14:02:24 2016 -0800 ---------------------------------------------------------------------- .../direct/ConsumerTrackingPipelineVisitor.java | 108 +++++++------------ .../apache/beam/runners/direct/DirectGraph.java | 89 +++++++++++++++ .../beam/runners/direct/DirectRunner.java | 31 +++--- .../beam/runners/direct/EvaluationContext.java | 76 ++++--------- .../direct/ExecutorServiceParallelExecutor.java | 15 +-- .../ImmutabilityCheckingBundleFactory.java | 21 ++-- .../beam/runners/direct/WatermarkManager.java | 50 ++++----- .../ConsumerTrackingPipelineVisitorTest.java | 98 +++++------------ .../runners/direct/EvaluationContextTest.java | 25 ++--- .../ImmutabilityCheckingBundleFactoryTest.java | 6 +- .../runners/direct/WatermarkManagerTest.java | 23 ++-- 11 files changed, 252 insertions(+), 290 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8162cd29/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitor.java ---------------------------------------------------------------------- diff --git
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitor.java index acfad16..b9e77c5 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitor.java @@ -19,8 +19,8 @@ package org.apache.beam.runners.direct; import static com.google.common.base.Preconditions.checkState; -import java.util.ArrayList; -import java.util.Collection; +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ListMultimap; import java.util.HashMap; import java.util.HashSet; import java.util.Map; @@ -33,6 +33,7 @@ import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; /** @@ -41,9 +42,13 @@ import org.apache.beam.sdk.values.PValue; * input after the upstream transform has produced and committed output. 
*/ public class ConsumerTrackingPipelineVisitor extends PipelineVisitor.Defaults { - private Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers = new HashMap<>(); - private Collection<AppliedPTransform<?, ?, ?>> rootTransforms = new ArrayList<>(); - private Collection<PCollectionView<?>> views = new ArrayList<>(); + private Map<POutput, AppliedPTransform<?, ?, ?>> producers = new HashMap<>(); + + private ListMultimap<PInput, AppliedPTransform<?, ?, ?>> primitiveConsumers = + ArrayListMultimap.create(); + + private Set<PCollectionView<?>> views = new HashSet<>(); + private Set<AppliedPTransform<?, ?, ?>> rootTransforms = new HashSet<>(); private Map<AppliedPTransform<?, ?, ?>, String> stepNames = new HashMap<>(); private Set<PValue> toFinalize = new HashSet<>(); private int numTransforms = 0; @@ -81,81 +86,38 @@ public class ConsumerTrackingPipelineVisitor extends PipelineVisitor.Defaults { rootTransforms.add(appliedTransform); } else { for (PValue value : node.getInput().expand()) { - valueToConsumers.get(value).add(appliedTransform); + primitiveConsumers.put(value, appliedTransform); } } } - private AppliedPTransform<?, ?, ?> getAppliedTransform(TransformHierarchy.Node node) { - @SuppressWarnings({"rawtypes", "unchecked"}) - AppliedPTransform<?, ?, ?> application = AppliedPTransform.of( - node.getFullName(), node.getInput(), node.getOutput(), (PTransform) node.getTransform()); - return application; - } - - @Override + @Override public void visitValue(PValue value, TransformHierarchy.Node producer) { toFinalize.add(value); + + AppliedPTransform<?, ?, ?> appliedTransform = getAppliedTransform(producer); + if (!producers.containsKey(value)) { + producers.put(value, appliedTransform); + } for (PValue expandedValue : value.expand()) { - valueToConsumers.put(expandedValue, new ArrayList<AppliedPTransform<?, ?, ?>>()); if (expandedValue instanceof PCollectionView) { views.add((PCollectionView<?>) expandedValue); } - 
expandedValue.recordAsOutput(getAppliedTransform(producer)); + if (!producers.containsKey(expandedValue)) { + producers.put(expandedValue, appliedTransform); + } } - value.recordAsOutput(getAppliedTransform(producer)); - } - - private String genStepName() { - return String.format("s%s", numTransforms++); - } - - - /** - * Returns a mapping of each fully-expanded {@link PValue} to each - * {@link AppliedPTransform} that consumes it. For each AppliedPTransform in the collection - * returned from {@code getValueToCustomers().get(PValue)}, - * {@code AppliedPTransform#getInput().expand()} will contain the argument {@link PValue}. - */ - public Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> getValueToConsumers() { - checkState( - finalized, - "Can't call getValueToConsumers before the Pipeline has been completely traversed"); - - return valueToConsumers; } - /** - * Returns the mapping for each {@link AppliedPTransform} in the {@link Pipeline} to a unique step - * name. - */ - public Map<AppliedPTransform<?, ?, ?>, String> getStepNames() { - checkState( - finalized, "Can't call getStepNames before the Pipeline has been completely traversed"); - - return stepNames; - } - - /** - * Returns the root transforms of the {@link Pipeline}. A root {@link AppliedPTransform} consumes - * a {@link PInput} where the {@link PInput#expand()} returns an empty collection.
- */ - public Collection<AppliedPTransform<?, ?, ?>> getRootTransforms() { - checkState( - finalized, - "Can't call getRootTransforms before the Pipeline has been completely traversed"); - - return rootTransforms; + private AppliedPTransform<?, ?, ?> getAppliedTransform(TransformHierarchy.Node node) { + @SuppressWarnings({"rawtypes", "unchecked"}) + AppliedPTransform<?, ?, ?> application = AppliedPTransform.of( + node.getFullName(), node.getInput(), node.getOutput(), (PTransform) node.getTransform()); + return application; } - /** - * Returns all of the {@link PCollectionView PCollectionViews} contained in the visited - * {@link Pipeline}. - */ - public Collection<PCollectionView<?>> getViews() { - checkState(finalized, "Can't call getViews before the Pipeline has been completely traversed"); - - return views; + private String genStepName() { + return String.format("s%s", numTransforms++); } /** @@ -163,11 +125,21 @@ public class ConsumerTrackingPipelineVisitor extends PipelineVisitor.Defaults { * {@link PValue PValues} should be finalized by the {@link PipelineRunner} before the * {@link Pipeline} is executed. */ - public Set<PValue> getUnfinalizedPValues() { + public void finishSpecifyingRemainder() { checkState( finalized, - "Can't call getUnfinalizedPValues before the Pipeline has been completely traversed"); + "Can't call finishSpecifyingRemainder before the Pipeline has been completely traversed"); + for (PValue unfinalized : toFinalize) { + unfinalized.finishSpecifying(); + } + } - return toFinalize; + /** + * Get the graph constructed by this {@link ConsumerTrackingPipelineVisitor}, which provides + * lookups for producers and consumers of {@link PValue PValues}. 
+ */ + public DirectGraph getGraph() { + checkState(finalized, "Can't get a graph before the Pipeline has been completely traversed"); + return DirectGraph.create(producers, primitiveConsumers, views, rootTransforms, stepNames); } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8162cd29/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java new file mode 100644 index 0000000..f208f6e --- /dev/null +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.direct; + +import com.google.common.collect.ListMultimap; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.transforms.AppliedPTransform; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.PValue; + +/** + * Methods for interacting with the underlying structure of a {@link Pipeline} that is being + * executed with the {@link DirectRunner}. + */ +class DirectGraph { + private final Map<POutput, AppliedPTransform<?, ?, ?>> producers; + private final ListMultimap<PInput, AppliedPTransform<?, ?, ?>> primitiveConsumers; + private final Set<PCollectionView<?>> views; + + private final Set<AppliedPTransform<?, ?, ?>> rootTransforms; + private final Map<AppliedPTransform<?, ?, ?>, String> stepNames; + + public static DirectGraph create( + Map<POutput, AppliedPTransform<?, ?, ?>> producers, + ListMultimap<PInput, AppliedPTransform<?, ?, ?>> primitiveConsumers, + Set<PCollectionView<?>> views, + Set<AppliedPTransform<?, ?, ?>> rootTransforms, + Map<AppliedPTransform<?, ?, ?>, String> stepNames) { + return new DirectGraph(producers, primitiveConsumers, views, rootTransforms, stepNames); + } + + private DirectGraph( + Map<POutput, AppliedPTransform<?, ?, ?>> producers, + ListMultimap<PInput, AppliedPTransform<?, ?, ?>> primitiveConsumers, + Set<PCollectionView<?>> views, + Set<AppliedPTransform<?, ?, ?>> rootTransforms, + Map<AppliedPTransform<?, ?, ?>, String> stepNames) { + this.producers = producers; + this.primitiveConsumers = primitiveConsumers; + this.views = views; + this.rootTransforms = rootTransforms; + this.stepNames = stepNames; + } + + public AppliedPTransform<?, ?, ?> getProducer(PValue produced) { + return producers.get(produced); + } + + public List<AppliedPTransform<?, ?, ?>> 
getPrimitiveConsumers(PValue consumed) { + return primitiveConsumers.get(consumed); + } + + public Set<AppliedPTransform<?, ?, ?>> getRootTransforms() { + return rootTransforms; + } + + public Set<PCollectionView<?>> getViews() { + return views; + } + + public String getStepName(AppliedPTransform<?, ?, ?> step) { + return stepNames.get(step); + } + + public Collection<AppliedPTransform<?, ?, ?>> getPrimitiveTransforms() { + return stepNames.keySet(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8162cd29/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index 82de9ab..0ad5836 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -62,7 +62,6 @@ import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; -import org.apache.beam.sdk.values.PValue; import org.joda.time.Duration; import org.joda.time.Instant; @@ -198,18 +197,18 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> { enum Enforcement { ENCODABILITY { @Override - public boolean appliesTo(PTransform<?, ?> transform) { + public boolean appliesTo(PCollection<?> collection, DirectGraph graph) { return true; } }, IMMUTABILITY { @Override - public boolean appliesTo(PTransform<?, ?> transform) { - return CONTAINS_UDF.contains(transform.getClass()); + public boolean appliesTo(PCollection<?> collection, DirectGraph graph) { + return CONTAINS_UDF.contains(graph.getProducer(collection).getTransform().getClass()); } }; - public abstract 
boolean appliesTo(PTransform<?, ?> transform); + public abstract boolean appliesTo(PCollection<?> collection, DirectGraph graph); //////////////////////////////////////////////////////////////////////////////////////////////// // Utilities for creating enforcements @@ -224,13 +223,13 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> { return Collections.unmodifiableSet(enabled); } - public static BundleFactory bundleFactoryFor(Set<Enforcement> enforcements) { + public static BundleFactory bundleFactoryFor(Set<Enforcement> enforcements, DirectGraph graph) { BundleFactory bundleFactory = enforcements.contains(Enforcement.ENCODABILITY) ? CloningBundleFactory.create() : ImmutableListBundleFactory.create(); if (enforcements.contains(Enforcement.IMMUTABILITY)) { - bundleFactory = ImmutabilityCheckingBundleFactory.create(bundleFactory); + bundleFactory = ImmutabilityCheckingBundleFactory.create(bundleFactory, graph); } return bundleFactory; } @@ -301,9 +300,8 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> { MetricsEnvironment.setMetricsSupported(true); ConsumerTrackingPipelineVisitor consumerTrackingVisitor = new ConsumerTrackingPipelineVisitor(); pipeline.traverseTopologically(consumerTrackingVisitor); - for (PValue unfinalized : consumerTrackingVisitor.getUnfinalizedPValues()) { - unfinalized.finishSpecifying(); - } + consumerTrackingVisitor.finishSpecifyingRemainder(); + @SuppressWarnings("rawtypes") KeyedPValueTrackingVisitor keyedPValueVisitor = KeyedPValueTrackingVisitor.create( @@ -315,28 +313,25 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> { DisplayDataValidator.validatePipeline(pipeline); + DirectGraph graph = consumerTrackingVisitor.getGraph(); EvaluationContext context = EvaluationContext.create( getPipelineOptions(), clockSupplier.get(), - Enforcement.bundleFactoryFor(enabledEnforcements), - consumerTrackingVisitor.getRootTransforms(), - 
consumerTrackingVisitor.getValueToConsumers(), - consumerTrackingVisitor.getStepNames(), - consumerTrackingVisitor.getViews()); + Enforcement.bundleFactoryFor(enabledEnforcements, graph), + graph); RootProviderRegistry rootInputProvider = RootProviderRegistry.defaultRegistry(context); TransformEvaluatorRegistry registry = TransformEvaluatorRegistry.defaultRegistry(context); PipelineExecutor executor = ExecutorServiceParallelExecutor.create( - options.getTargetParallelism(), - consumerTrackingVisitor.getValueToConsumers(), + options.getTargetParallelism(), graph, keyedPValueVisitor.getKeyedPValues(), rootInputProvider, registry, Enforcement.defaultModelEnforcements(enabledEnforcements), context); - executor.start(consumerTrackingVisitor.getRootTransforms()); + executor.start(graph.getRootTransforms()); Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps = pipeline.getAggregatorSteps(); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8162cd29/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java index 201aaed..b5a23d7 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java @@ -74,8 +74,10 @@ import org.joda.time.Instant; * can be executed. */ class EvaluationContext { - /** The step name for each {@link AppliedPTransform} in the {@link Pipeline}. */ - private final Map<AppliedPTransform<?, ?, ?>, String> stepNames; + /** + * The graph representing this {@link Pipeline}. + */ + private final DirectGraph graph; /** The options that were used to create this {@link Pipeline}. 
*/ private final DirectOptions options; @@ -99,36 +101,19 @@ class EvaluationContext { private final DirectMetrics metrics; public static EvaluationContext create( - DirectOptions options, - Clock clock, - BundleFactory bundleFactory, - Collection<AppliedPTransform<?, ?, ?>> rootTransforms, - Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers, - Map<AppliedPTransform<?, ?, ?>, String> stepNames, - Collection<PCollectionView<?>> views) { - return new EvaluationContext( - options, clock, bundleFactory, rootTransforms, valueToConsumers, stepNames, views); + DirectOptions options, Clock clock, BundleFactory bundleFactory, DirectGraph graph) { + return new EvaluationContext(options, clock, bundleFactory, graph); } private EvaluationContext( - DirectOptions options, - Clock clock, - BundleFactory bundleFactory, - Collection<AppliedPTransform<?, ?, ?>> rootTransforms, - Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers, - Map<AppliedPTransform<?, ?, ?>, String> stepNames, - Collection<PCollectionView<?>> views) { + DirectOptions options, Clock clock, BundleFactory bundleFactory, DirectGraph graph) { this.options = checkNotNull(options); this.clock = clock; this.bundleFactory = checkNotNull(bundleFactory); - checkNotNull(rootTransforms); - checkNotNull(valueToConsumers); - checkNotNull(stepNames); - checkNotNull(views); - this.stepNames = stepNames; + this.graph = checkNotNull(graph); - this.watermarkManager = WatermarkManager.create(clock, rootTransforms, valueToConsumers); - this.sideInputContainer = SideInputContainer.create(this, views); + this.watermarkManager = WatermarkManager.create(clock, graph); + this.sideInputContainer = SideInputContainer.create(this, graph.getViews()); this.applicationStateInternals = new ConcurrentHashMap<>(); this.mergedAggregators = AggregatorContainer.create(); @@ -211,7 +196,7 @@ class EvaluationContext { ImmutableList.Builder<CommittedBundle<?>> completed = ImmutableList.builder(); for 
(UncommittedBundle<?> inProgress : bundles) { AppliedPTransform<?, ?, ?> producing = - inProgress.getPCollection().getProducingTransformInternal(); + graph.getProducer(inProgress.getPCollection()); TransformWatermarks watermarks = watermarkManager.getWatermarks(producing); CommittedBundle<?> committed = inProgress.commit(watermarks.getSynchronizedProcessingOutputTime()); @@ -225,7 +210,7 @@ class EvaluationContext { } private void fireAllAvailableCallbacks() { - for (AppliedPTransform<?, ?, ?> transform : stepNames.keySet()) { + for (AppliedPTransform<?, ?, ?> transform : graph.getPrimitiveTransforms()) { fireAvailableCallbacks(transform); } } @@ -290,10 +275,10 @@ class EvaluationContext { BoundedWindow window, WindowingStrategy<?, ?> windowingStrategy, Runnable runnable) { - AppliedPTransform<?, ?, ?> producing = getProducing(value); + AppliedPTransform<?, ?, ?> producing = graph.getProducer(value); callbackExecutor.callOnGuaranteedFiring(producing, window, windowingStrategy, runnable); - fireAvailableCallbacks(lookupProducing(value)); + fireAvailableCallbacks(producing); } /** @@ -311,22 +296,6 @@ class EvaluationContext { fireAvailableCallbacks(producing); } - private AppliedPTransform<?, ?, ?> getProducing(PValue value) { - if (value.getProducingTransformInternal() != null) { - return value.getProducingTransformInternal(); - } - return lookupProducing(value); - } - - private AppliedPTransform<?, ?, ?> lookupProducing(PValue value) { - for (AppliedPTransform<?, ?, ?> transform : stepNames.keySet()) { - if (transform.getOutput().equals(value) || transform.getOutput().expand().contains(value)) { - return transform; - } - } - return null; - } - /** * Get the options used by this {@link Pipeline}. */ @@ -347,18 +316,17 @@ class EvaluationContext { watermarkManager.getWatermarks(application)); } - /** - * Get all of the steps used in this {@link Pipeline}. 
- */ - public Collection<AppliedPTransform<?, ?, ?>> getSteps() { - return stepNames.keySet(); - } /** * Get the Step Name for the provided application. */ - public String getStepName(AppliedPTransform<?, ?, ?> application) { - return stepNames.get(application); + String getStepName(AppliedPTransform<?, ?, ?> application) { + return graph.getStepName(application); + } + + /** Returns all of the steps in this {@link Pipeline}. */ + Collection<AppliedPTransform<?, ?, ?>> getSteps() { + return graph.getPrimitiveTransforms(); } /** @@ -450,7 +418,7 @@ class EvaluationContext { * Returns true if all steps are done. */ public boolean isDone() { - for (AppliedPTransform<?, ?, ?> transform : stepNames.keySet()) { + for (AppliedPTransform<?, ?, ?> transform : graph.getPrimitiveTransforms()) { if (!isDone(transform)) { return false; } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8162cd29/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java index b7908c5..929d09d 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java @@ -69,7 +69,7 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor { private final int targetParallelism; private final ExecutorService executorService; - private final Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers; + private final DirectGraph graph; private final Set<PValue> keyedPValues; private final RootProviderRegistry rootProviderRegistry; private final TransformEvaluatorRegistry registry; @@ -104,7 
+104,7 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor { public static ExecutorServiceParallelExecutor create( int targetParallelism, - Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers, + DirectGraph graph, Set<PValue> keyedPValues, RootProviderRegistry rootProviderRegistry, TransformEvaluatorRegistry registry, @@ -114,7 +114,7 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor { EvaluationContext context) { return new ExecutorServiceParallelExecutor( targetParallelism, - valueToConsumers, + graph, keyedPValues, rootProviderRegistry, registry, @@ -124,7 +124,7 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor { private ExecutorServiceParallelExecutor( int targetParallelism, - Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers, + DirectGraph graph, Set<PValue> keyedPValues, RootProviderRegistry rootProviderRegistry, TransformEvaluatorRegistry registry, @@ -133,7 +133,7 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor { EvaluationContext context) { this.targetParallelism = targetParallelism; this.executorService = Executors.newFixedThreadPool(targetParallelism); - this.valueToConsumers = valueToConsumers; + this.graph = graph; this.keyedPValues = keyedPValues; this.rootProviderRegistry = rootProviderRegistry; this.registry = registry; @@ -273,8 +273,9 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor { CommittedBundle<?> inputBundle, TransformResult<?> result) { CommittedResult committedResult = evaluationContext.handleResult(inputBundle, timers, result); for (CommittedBundle<?> outputBundle : committedResult.getOutputs()) { - allUpdates.offer(ExecutorUpdate.fromBundle(outputBundle, - valueToConsumers.get(outputBundle.getPCollection()))); + allUpdates.offer( + ExecutorUpdate.fromBundle( + outputBundle, graph.getPrimitiveConsumers(outputBundle.getPCollection()))); } CommittedBundle<?> unprocessedInputs = 
committedResult.getUnprocessedInputs(); if (unprocessedInputs != null && !Iterables.isEmpty(unprocessedInputs.getElements())) { http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8162cd29/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactory.java index 4f72f68..8d77e25 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactory.java @@ -46,17 +46,20 @@ import org.joda.time.Instant; */ class ImmutabilityCheckingBundleFactory implements BundleFactory { /** - * Create a new {@link ImmutabilityCheckingBundleFactory} that uses the underlying - * {@link BundleFactory} to create the output bundle. + * Create a new {@link ImmutabilityCheckingBundleFactory} that uses the underlying {@link + * BundleFactory} to create the output bundle. 
*/ - public static ImmutabilityCheckingBundleFactory create(BundleFactory underlying) { - return new ImmutabilityCheckingBundleFactory(underlying); + public static ImmutabilityCheckingBundleFactory create( + BundleFactory underlying, DirectGraph graph) { + return new ImmutabilityCheckingBundleFactory(underlying, graph); } private final BundleFactory underlying; + private final DirectGraph graph; - private ImmutabilityCheckingBundleFactory(BundleFactory underlying) { + private ImmutabilityCheckingBundleFactory(BundleFactory underlying, DirectGraph graph) { this.underlying = checkNotNull(underlying); + this.graph = graph; } /** @@ -72,7 +75,7 @@ class ImmutabilityCheckingBundleFactory implements BundleFactory { @Override public <T> UncommittedBundle<T> createBundle(PCollection<T> output) { - if (Enforcement.IMMUTABILITY.appliesTo(output.getProducingTransformInternal().getTransform())) { + if (Enforcement.IMMUTABILITY.appliesTo(output, graph)) { return new ImmutabilityEnforcingBundle<>(underlying.createBundle(output)); } return underlying.createBundle(output); @@ -81,13 +84,13 @@ class ImmutabilityCheckingBundleFactory implements BundleFactory { @Override public <K, T> UncommittedBundle<T> createKeyedBundle( StructuralKey<K> key, PCollection<T> output) { - if (Enforcement.IMMUTABILITY.appliesTo(output.getProducingTransformInternal().getTransform())) { + if (Enforcement.IMMUTABILITY.appliesTo(output, graph)) { return new ImmutabilityEnforcingBundle<>(underlying.createKeyedBundle(key, output)); } return underlying.createKeyedBundle(key, output); } - private static class ImmutabilityEnforcingBundle<T> implements UncommittedBundle<T> { + private class ImmutabilityEnforcingBundle<T> implements UncommittedBundle<T> { private final UncommittedBundle<T> underlying; private final SetMultimap<WindowedValue<T>, MutationDetector> mutationDetectors; private Coder<T> coder; @@ -125,7 +128,7 @@ class ImmutabilityCheckingBundleFactory implements BundleFactory { String.format( 
"PTransform %s mutated value %s after it was output (new value was %s)." + " Values must not be mutated in any way after being output.", - underlying.getPCollection().getProducingTransformInternal().getFullName(), + graph.getProducer(underlying.getPCollection()).getFullName(), exn.getSavedValue(), exn.getNewValue()), exn.getSavedValue(), http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8162cd29/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java index a53c11c..247b1cc 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java @@ -669,10 +669,10 @@ public class WatermarkManager { private final Clock clock; /** - * A map from each {@link PCollection} to all {@link AppliedPTransform PTransform applications} - * that consume that {@link PCollection}. + * The {@link DirectGraph} representing the {@link Pipeline} this {@link WatermarkManager} tracks + * watermarks for. */ - private final Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> consumers; + private final DirectGraph graph; /** * The input and output watermark of each {@link AppliedPTransform}. @@ -697,27 +697,21 @@ public class WatermarkManager { private final Set<AppliedPTransform<?, ?, ?>> pendingRefreshes; /** - * Creates a new {@link WatermarkManager}. All watermarks within the newly created - * {@link WatermarkManager} start at {@link BoundedWindow#TIMESTAMP_MIN_VALUE}, the - * minimum watermark, with no watermark holds or pending elements. + * Creates a new {@link WatermarkManager}. 
All watermarks within the newly created {@link + * WatermarkManager} start at {@link BoundedWindow#TIMESTAMP_MIN_VALUE}, the minimum watermark, + * with no watermark holds or pending elements. * - * @param rootTransforms the root-level transforms of the {@link Pipeline} - * @param consumers a mapping between each {@link PCollection} in the {@link Pipeline} to the - * transforms that consume it as a part of their input + * @param clock the clock to use to determine processing time + * @param graph the graph representing this pipeline */ - public static WatermarkManager create( - Clock clock, - Collection<AppliedPTransform<?, ?, ?>> rootTransforms, - Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> consumers) { - return new WatermarkManager(clock, rootTransforms, consumers); + public static WatermarkManager create(Clock clock, DirectGraph graph) { + return new WatermarkManager(clock, graph); } - private WatermarkManager( - Clock clock, - Collection<AppliedPTransform<?, ?, ?>> rootTransforms, - Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> consumers) { + private WatermarkManager(Clock clock, DirectGraph graph) { this.clock = clock; - this.consumers = consumers; + this.graph = graph; + this.pendingUpdates = new ConcurrentLinkedQueue<>(); this.refreshLock = new ReentrantLock(); @@ -725,13 +719,11 @@ public class WatermarkManager { transformToWatermarks = new HashMap<>(); - for (AppliedPTransform<?, ?, ?> rootTransform : rootTransforms) { + for (AppliedPTransform<?, ?, ?> rootTransform : graph.getRootTransforms()) { getTransformWatermark(rootTransform); } - for (Collection<AppliedPTransform<?, ?, ?>> intermediateTransforms : consumers.values()) { - for (AppliedPTransform<?, ?, ?> transform : intermediateTransforms) { - getTransformWatermark(transform); - } + for (AppliedPTransform<?, ?, ?> primitiveTransform : graph.getPrimitiveTransforms()) { + getTransformWatermark(primitiveTransform); } } @@ -769,8 +761,7 @@ public class WatermarkManager { } for (PValue 
pvalue : inputs) { Watermark producerOutputWatermark = - getTransformWatermark(pvalue.getProducingTransformInternal()) - .synchronizedProcessingOutputWatermark; + getTransformWatermark(graph.getProducer(pvalue)).synchronizedProcessingOutputWatermark; inputWmsBuilder.add(producerOutputWatermark); } return inputWmsBuilder.build(); @@ -784,7 +775,7 @@ public class WatermarkManager { } for (PValue pvalue : inputs) { Watermark producerOutputWatermark = - getTransformWatermark(pvalue.getProducingTransformInternal()).outputWatermark; + getTransformWatermark(graph.getProducer(pvalue)).outputWatermark; inputWatermarksBuilder.add(producerOutputWatermark); } List<Watermark> inputCollectionWatermarks = inputWatermarksBuilder.build(); @@ -920,7 +911,8 @@ public class WatermarkManager { // do not share a Mutex within this call and thus can be interleaved with external calls to // refresh. for (CommittedBundle<?> bundle : result.getOutputs()) { - for (AppliedPTransform<?, ?, ?> consumer : consumers.get(bundle.getPCollection())) { + for (AppliedPTransform<?, ?, ?> consumer : + graph.getPrimitiveConsumers(bundle.getPCollection())) { TransformWatermarks watermarks = transformToWatermarks.get(consumer); watermarks.addPending(bundle); } @@ -968,7 +960,7 @@ public class WatermarkManager { if (updateResult.isAdvanced()) { Set<AppliedPTransform<?, ?, ?>> additionalRefreshes = new HashSet<>(); for (PValue outputPValue : toRefresh.getOutput().expand()) { - additionalRefreshes.addAll(consumers.get(outputPValue)); + additionalRefreshes.addAll(graph.getPrimitiveConsumers(outputPValue)); } return additionalRefreshes; } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8162cd29/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitorTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitorTest.java 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitorTest.java index f7f4b71..02fe007 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitorTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitorTest.java @@ -18,6 +18,8 @@ package org.apache.beam.runners.direct; import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertThat; import java.io.Serializable; @@ -36,7 +38,6 @@ import org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PDone; import org.apache.beam.sdk.values.PInput; -import org.apache.beam.sdk.values.PValue; import org.hamcrest.Matchers; import org.junit.Rule; import org.junit.Test; @@ -72,7 +73,7 @@ public class ConsumerTrackingPipelineVisitorTest implements Serializable { p.apply("singletonCreate", Create.<Object>of(1, 2, 3)).apply(View.<Object>asSingleton()); p.traverseTopologically(visitor); assertThat( - visitor.getViews(), + visitor.getGraph().getViews(), Matchers.<PCollectionView<?>>containsInAnyOrder(listView, singletonView)); } @@ -83,7 +84,7 @@ public class ConsumerTrackingPipelineVisitorTest implements Serializable { PCollection<Long> unCounted = p.apply(CountingInput.unbounded()); p.traverseTopologically(visitor); assertThat( - visitor.getRootTransforms(), + visitor.getGraph().getRootTransforms(), Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder( created.getProducingTransformInternal(), counted.getProducingTransformInternal(), @@ -96,7 +97,7 @@ public class ConsumerTrackingPipelineVisitorTest implements Serializable { PCollectionList.<String>empty(p).apply(Flatten.<String>pCollections()); p.traverseTopologically(visitor); assertThat( - visitor.getRootTransforms(), + 
visitor.getGraph().getRootTransforms(), Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder( empty.getProducingTransformInternal())); } @@ -121,15 +122,15 @@ public class ConsumerTrackingPipelineVisitorTest implements Serializable { p.traverseTopologically(visitor); assertThat( - visitor.getValueToConsumers().get(created), + visitor.getGraph().getPrimitiveConsumers(created), Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder( transformed.getProducingTransformInternal(), flattened.getProducingTransformInternal())); assertThat( - visitor.getValueToConsumers().get(transformed), + visitor.getGraph().getPrimitiveConsumers(transformed), Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder( flattened.getProducingTransformInternal())); - assertThat(visitor.getValueToConsumers().get(flattened), emptyIterable()); + assertThat(visitor.getGraph().getPrimitiveConsumers(flattened), emptyIterable()); } @Test @@ -142,11 +143,11 @@ public class ConsumerTrackingPipelineVisitorTest implements Serializable { p.traverseTopologically(visitor); assertThat( - visitor.getValueToConsumers().get(created), + visitor.getGraph().getPrimitiveConsumers(created), Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder( flattened.getProducingTransformInternal(), flattened.getProducingTransformInternal())); - assertThat(visitor.getValueToConsumers().get(flattened), emptyIterable()); + assertThat(visitor.getGraph().getPrimitiveConsumers(flattened), emptyIterable()); } @Test @@ -163,32 +164,11 @@ public class ConsumerTrackingPipelineVisitorTest implements Serializable { } })); - p.traverseTopologically(visitor); - assertThat(visitor.getUnfinalizedPValues(), Matchers.<PValue>contains(transformed)); - } - - @Test - public void getUnfinalizedPValuesEmpty() { - p.apply(Create.of("1", "2", "3")) - .apply( - ParDo.of( - new DoFn<String, String>() { - @ProcessElement - public void processElement(DoFn<String, String>.ProcessContext c) - throws Exception { - 
c.output(Integer.toString(c.element().length())); - } - })) - .apply( - new PTransform<PInput, PDone>() { - @Override - public PDone apply(PInput input) { - return PDone.in(input.getPipeline()); - } - }); + assertThat(transformed.isFinishedSpecifyingInternal(), is(false)); p.traverseTopologically(visitor); - assertThat(visitor.getUnfinalizedPValues(), emptyIterable()); + visitor.finishSpecifyingRemainder(); + assertThat(transformed.isFinishedSpecifyingInternal(), is(true)); } @Test @@ -214,18 +194,12 @@ public class ConsumerTrackingPipelineVisitorTest implements Serializable { }); p.traverseTopologically(visitor); - assertThat( - visitor.getStepNames(), - Matchers.<AppliedPTransform<?, ?, ?>, String>hasEntry( - created.getProducingTransformInternal(), "s0")); - assertThat( - visitor.getStepNames(), - Matchers.<AppliedPTransform<?, ?, ?>, String>hasEntry( - transformed.getProducingTransformInternal(), "s1")); - assertThat( - visitor.getStepNames(), - Matchers.<AppliedPTransform<?, ?, ?>, String>hasEntry( - finished.getProducingTransformInternal(), "s2")); + DirectGraph graph = visitor.getGraph(); + assertThat(graph.getStepName(graph.getProducer(created)), equalTo("s0")); + assertThat(graph.getStepName(graph.getProducer(transformed)), equalTo("s1")); + // finished doesn't have a producer, because it's not a PValue. + // TODO: Demonstrate that PCollectionList/Tuple and other composite PValues are either safe to + // use, or make them so. 
} @Test @@ -248,40 +222,18 @@ public class ConsumerTrackingPipelineVisitorTest implements Serializable { } @Test - public void getRootTransformsWithoutVisitingThrows() { - thrown.expect(IllegalStateException.class); - thrown.expectMessage("completely traversed"); - thrown.expectMessage("getRootTransforms"); - visitor.getRootTransforms(); - } - @Test - public void getStepNamesWithoutVisitingThrows() { - thrown.expect(IllegalStateException.class); - thrown.expectMessage("completely traversed"); - thrown.expectMessage("getStepNames"); - visitor.getStepNames(); - } - @Test - public void getUnfinalizedPValuesWithoutVisitingThrows() { - thrown.expect(IllegalStateException.class); - thrown.expectMessage("completely traversed"); - thrown.expectMessage("getUnfinalizedPValues"); - visitor.getUnfinalizedPValues(); - } - - @Test - public void getValueToConsumersWithoutVisitingThrows() { + public void getGraphWithoutVisitingThrows() { thrown.expect(IllegalStateException.class); thrown.expectMessage("completely traversed"); - thrown.expectMessage("getValueToConsumers"); - visitor.getValueToConsumers(); + thrown.expectMessage("get a graph"); + visitor.getGraph(); } @Test - public void getViewsWithoutVisitingThrows() { + public void finishSpecifyingRemainderWithoutVisitingThrows() { thrown.expect(IllegalStateException.class); thrown.expectMessage("completely traversed"); - thrown.expectMessage("getViews"); - visitor.getViews(); + thrown.expectMessage("finishSpecifyingRemainder"); + visitor.finishSpecifyingRemainder(); } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8162cd29/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java index 9a3959d..1c2bf14 100644 --- 
a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java @@ -29,7 +29,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import java.util.Collection; import java.util.Collections; -import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import org.apache.beam.runners.direct.DirectExecutionContext.DirectStepContext; @@ -67,7 +66,6 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.PValue; import org.hamcrest.Matchers; import org.joda.time.Instant; import org.junit.Before; @@ -87,10 +85,9 @@ public class EvaluationContextTest { private PCollection<KV<String, Integer>> downstream; private PCollectionView<Iterable<Integer>> view; private PCollection<Long> unbounded; - private Collection<AppliedPTransform<?, ?, ?>> rootTransforms; - private Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers; private BundleFactory bundleFactory; + private DirectGraph graph; @Before public void setup() { @@ -106,20 +103,12 @@ public class EvaluationContextTest { ConsumerTrackingPipelineVisitor cVis = new ConsumerTrackingPipelineVisitor(); p.traverseTopologically(cVis); - rootTransforms = cVis.getRootTransforms(); - valueToConsumers = cVis.getValueToConsumers(); bundleFactory = ImmutableListBundleFactory.create(); - + graph = cVis.getGraph(); context = EvaluationContext.create( - runner.getPipelineOptions(), - NanosOffsetClock.create(), - ImmutableListBundleFactory.create(), - rootTransforms, - valueToConsumers, - cVis.getStepNames(), - cVis.getViews()); + runner.getPipelineOptions(), NanosOffsetClock.create(), bundleFactory, graph); } @Test @@ -427,13 +416,13 @@ public 
class EvaluationContextTest { @Test public void isDoneWithUnboundedPCollectionAndNotShutdown() { context.getPipelineOptions().setShutdownUnboundedProducersWithMaxWatermark(false); - assertThat(context.isDone(unbounded.getProducingTransformInternal()), is(false)); + assertThat(context.isDone(graph.getProducer(unbounded)), is(false)); context.handleResult( null, ImmutableList.<TimerData>of(), - StepTransformResult.withoutHold(unbounded.getProducingTransformInternal()).build()); - assertThat(context.isDone(unbounded.getProducingTransformInternal()), is(false)); + StepTransformResult.withoutHold(graph.getProducer(unbounded)).build()); + assertThat(context.isDone(graph.getProducer(unbounded)), is(false)); } @Test @@ -472,7 +461,7 @@ public class EvaluationContextTest { StepTransformResult.withoutHold(unbounded.getProducingTransformInternal()).build()); assertThat(context.isDone(), is(false)); - for (AppliedPTransform<?, ?, ?> consumers : valueToConsumers.get(created)) { + for (AppliedPTransform<?, ?, ?> consumers : graph.getPrimitiveConsumers(created)) { context.handleResult( committedBundle, ImmutableList.<TimerData>of(), http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8162cd29/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactoryTest.java index ea44125..e7e1e62 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactoryTest.java @@ -56,7 +56,11 @@ public class ImmutabilityCheckingBundleFactoryTest { TestPipeline p = TestPipeline.create(); created = 
p.apply(Create.<byte[]>of().withCoder(ByteArrayCoder.of())); transformed = created.apply(ParDo.of(new IdentityDoFn<byte[]>())); - factory = ImmutabilityCheckingBundleFactory.create(ImmutableListBundleFactory.create()); + ConsumerTrackingPipelineVisitor visitor = new ConsumerTrackingPipelineVisitor(); + p.traverseTopologically(visitor); + factory = + ImmutabilityCheckingBundleFactory.create( + ImmutableListBundleFactory.create(), visitor.getGraph()); } @Test http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8162cd29/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WatermarkManagerTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WatermarkManagerTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WatermarkManagerTest.java index 2e8ab84..5cde4d6 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WatermarkManagerTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WatermarkManagerTest.java @@ -94,6 +94,7 @@ public class WatermarkManagerTest implements Serializable { private transient WatermarkManager manager; private transient BundleFactory bundleFactory; + private DirectGraph graph; @Before public void setup() { @@ -139,8 +140,11 @@ public class WatermarkManagerTest implements Serializable { consumers.put(flattened, Collections.<AppliedPTransform<?, ?, ?>>emptyList()); clock = MockClock.fromInstant(new Instant(1000)); + ConsumerTrackingPipelineVisitor visitor = new ConsumerTrackingPipelineVisitor(); + p.traverseTopologically(visitor); + graph = visitor.getGraph(); - manager = WatermarkManager.create(clock, rootTransforms, consumers); + manager = WatermarkManager.create(clock, graph); bundleFactory = ImmutableListBundleFactory.create(); } @@ -305,20 +309,13 @@ public class WatermarkManagerTest implements Serializable { PCollection<Integer> created = 
p.apply(Create.of(1, 2, 3)); PCollection<Integer> multiConsumer = PCollectionList.of(created).and(created).apply(Flatten.<Integer>pCollections()); - AppliedPTransform<?, ?, ?> theFlatten = multiConsumer.getProducingTransformInternal(); + ConsumerTrackingPipelineVisitor trackingVisitor = new ConsumerTrackingPipelineVisitor(); + p.traverseTopologically(trackingVisitor); + DirectGraph graph = trackingVisitor.getGraph(); - Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers = - ImmutableMap.<PValue, Collection<AppliedPTransform<?, ?, ?>>>builder() - .put(created, ImmutableList.<AppliedPTransform<?, ?, ?>>of(theFlatten, theFlatten)) - .put(multiConsumer, Collections.<AppliedPTransform<?, ?, ?>>emptyList()) - .build(); + AppliedPTransform<?, ?, ?> theFlatten = graph.getProducer(multiConsumer); - WatermarkManager tstMgr = - WatermarkManager.create( - clock, - Collections.<AppliedPTransform<?, ?, ?>>singleton( - created.getProducingTransformInternal()), - valueToConsumers); + WatermarkManager tstMgr = WatermarkManager.create(clock, graph); CommittedBundle<Void> root = bundleFactory .<Void>createRootBundle()