Implement InProcessPipelineRunner#run Appropriately construct an evaluation context and executor, and start the pipeline when run is called.
Implement InProcessPipelineResult. Apply PTransform overrides. Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/158f9f8d Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/158f9f8d Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/158f9f8d Branch: refs/heads/master Commit: 158f9f8d41c63f5a002c6187f4f05f169579dd6d Parents: 5ecb7aa Author: Thomas Groh <tg...@google.com> Authored: Fri Feb 26 17:30:13 2016 -0800 Committer: Maximilian Michels <m...@apache.org> Committed: Wed Mar 23 19:27:51 2016 +0100 ---------------------------------------------------------------------- .../CachedThreadPoolExecutorServiceFactory.java | 42 ++++ .../ConsumerTrackingPipelineVisitor.java | 173 ++++++++++++++ .../inprocess/ExecutorServiceFactory.java | 32 +++ .../ExecutorServiceParallelExecutor.java | 2 +- .../inprocess/GroupByKeyEvaluatorFactory.java | 4 +- .../inprocess/InProcessPipelineOptions.java | 56 +++++ .../inprocess/InProcessPipelineRunner.java | 228 +++++++++++++++--- .../inprocess/KeyedPValueTrackingVisitor.java | 95 ++++++++ .../ConsumerTrackingPipelineVisitorTest.java | 233 +++++++++++++++++++ .../inprocess/InProcessPipelineRunnerTest.java | 77 ++++++ .../KeyedPValueTrackingVisitorTest.java | 189 +++++++++++++++ 11 files changed, 1101 insertions(+), 30 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java new file mode 100644 index 0000000..3350d2b --- 
/dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.options.DefaultValueFactory; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +/** + * A {@link ExecutorServiceFactory} that produces cached thread pools via + * {@link Executors#newCachedThreadPool()}. 
+ */ +class CachedThreadPoolExecutorServiceFactory + implements DefaultValueFactory<ExecutorServiceFactory>, ExecutorServiceFactory { + private static final CachedThreadPoolExecutorServiceFactory INSTANCE = + new CachedThreadPoolExecutorServiceFactory(); + + @Override + public ExecutorServiceFactory create(PipelineOptions options) { + return INSTANCE; + } + + @Override + public ExecutorService create() { + return Executors.newCachedThreadPool(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java new file mode 100644 index 0000000..c602b23 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java @@ -0,0 +1,173 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static com.google.common.base.Preconditions.checkState; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.runners.TransformTreeNode; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.PValue; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** + * Tracks the {@link AppliedPTransform AppliedPTransforms} that consume each {@link PValue} in the + * {@link Pipeline}. This is used to schedule consuming {@link PTransform PTransforms} to consume + * input after the upstream transform has produced and committed output. 
+ */ +public class ConsumerTrackingPipelineVisitor implements PipelineVisitor { + private Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers = new HashMap<>(); + private Collection<AppliedPTransform<?, ?, ?>> rootTransforms = new ArrayList<>(); + private Collection<PCollectionView<?>> views = new ArrayList<>(); + private Map<AppliedPTransform<?, ?, ?>, String> stepNames = new HashMap<>(); + private Set<PValue> toFinalize = new HashSet<>(); + private int numTransforms = 0; + private boolean finalized = false; + + @Override + public void enterCompositeTransform(TransformTreeNode node) { + checkState( + !finalized, + "Attempting to traverse a pipeline (node %s) with a %s " + + "which has already visited a Pipeline and is finalized", + node.getFullName(), + ConsumerTrackingPipelineVisitor.class.getSimpleName()); + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { + checkState( + !finalized, + "Attempting to traverse a pipeline (node %s) with a %s which is already finalized", + node.getFullName(), + ConsumerTrackingPipelineVisitor.class.getSimpleName()); + if (node.isRootNode()) { + finalized = true; + } + } + + @Override + public void visitTransform(TransformTreeNode node) { + toFinalize.removeAll(node.getInput().expand()); + AppliedPTransform<?, ?, ?> appliedTransform = getAppliedTransform(node); + if (node.getInput().expand().isEmpty()) { + rootTransforms.add(appliedTransform); + } else { + for (PValue value : node.getInput().expand()) { + valueToConsumers.get(value).add(appliedTransform); + stepNames.put(appliedTransform, genStepName()); + } + } + } + + private AppliedPTransform<?, ?, ?> getAppliedTransform(TransformTreeNode node) { + @SuppressWarnings({"rawtypes", "unchecked"}) + AppliedPTransform<?, ?, ?> application = AppliedPTransform.of( + node.getFullName(), node.getInput(), node.getOutput(), (PTransform) node.getTransform()); + return application; + } + + @Override + public void visitValue(PValue value, 
TransformTreeNode producer) { + toFinalize.add(value); + for (PValue expandedValue : value.expand()) { + valueToConsumers.put(expandedValue, new ArrayList<AppliedPTransform<?, ?, ?>>()); + if (expandedValue instanceof PCollectionView) { + views.add((PCollectionView<?>) expandedValue); + } + expandedValue.recordAsOutput(getAppliedTransform(producer)); + } + value.recordAsOutput(getAppliedTransform(producer)); + } + + private String genStepName() { + return String.format("s%s", numTransforms++); + } + + + /** + * Returns a mapping of each fully-expanded {@link PValue} to each + * {@link AppliedPTransform} that consumes it. For each AppliedPTransform in the collection + * returned from {@code getValueToCustomers().get(PValue)}, + * {@code AppliedPTransform#getInput().expand()} will contain the argument {@link PValue}. + */ + public Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> getValueToConsumers() { + checkState( + finalized, + "Can't call getValueToConsumers before the Pipeline has been completely traversed"); + + return valueToConsumers; + } + + /** + * Returns the mapping for each {@link AppliedPTransform} in the {@link Pipeline} to a unique step + * name. + */ + public Map<AppliedPTransform<?, ?, ?>, String> getStepNames() { + checkState( + finalized, "Can't call getStepNames before the Pipeline has been completely traversed"); + + return stepNames; + } + + /** + * Returns the root transforms of the {@link Pipeline}. A root {@link AppliedPTransform} consumes + * a {@link PInput} where the {@link PInput#expand()} returns an empty collection. + */ + public Collection<AppliedPTransform<?, ?, ?>> getRootTransforms() { + checkState( + finalized, + "Can't call getRootTransforms before the Pipeline has been completely traversed"); + + return rootTransforms; + } + + /** + * Returns all of the {@link PCollectionView PCollectionViews} contained in the visited + * {@link Pipeline}. 
+ */ + public Collection<PCollectionView<?>> getViews() { + checkState(finalized, "Can't call getViews before the Pipeline has been completely traversed"); + + return views; + } + + /** + * Returns all of the {@link PValue PValues} that have been produced but not consumed. These + * {@link PValue PValues} should be finalized by the {@link PipelineRunner} before the + * {@link Pipeline} is executed. + */ + public Set<PValue> getUnfinalizedPValues() { + checkState( + finalized, + "Can't call getUnfinalizedPValues before the Pipeline has been completely traversed"); + + return toFinalize; + } +} + + http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceFactory.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceFactory.java new file mode 100644 index 0000000..480bcde --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceFactory.java @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import java.util.concurrent.ExecutorService; + +/** + * A factory that creates {@link ExecutorService ExecutorServices}. + * {@link ExecutorService ExecutorServices} created by this factory should be independent of one + * another (e.g., if any executor is shut down the remaining executors should continue to process + * work). + */ +public interface ExecutorServiceFactory { + /** + * Create a new {@link ExecutorService}. + */ + ExecutorService create(); +} + http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java index ae686f2..c72a115 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java @@ -126,7 +126,7 @@ final class ExecutorServiceParallelExecutor implements InProcessExecutor { @Nullable final CommittedBundle<T> bundle, final CompletionCallback onComplete) { TransformExecutorService transformExecutor; - if (isKeyed(bundle.getPCollection())) { + if (bundle != null && isKeyed(bundle.getPCollection())) { final StepAndKey stepAndKey = StepAndKey.of(transform, bundle == null ? 
null : bundle.getKey()); transformExecutor = getSerialExecutorService(stepAndKey); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java index dec78d6..3ec4af1 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java @@ -59,7 +59,7 @@ class GroupByKeyEvaluatorFactory implements TransformEvaluatorFactory { CommittedBundle<?> inputBundle, InProcessEvaluationContext evaluationContext) { @SuppressWarnings({"cast", "unchecked", "rawtypes"}) - TransformEvaluator<InputT> evaluator = (TransformEvaluator<InputT>) createEvaluator( + TransformEvaluator<InputT> evaluator = createEvaluator( (AppliedPTransform) application, (CommittedBundle) inputBundle, evaluationContext); return evaluator; } @@ -184,7 +184,7 @@ class GroupByKeyEvaluatorFactory implements TransformEvaluatorFactory { extends ForwardingPTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> { private final GroupByKey<K, V> original; - public InProcessGroupByKey(GroupByKey<K, V> from) { + private InProcessGroupByKey(GroupByKey<K, V> from) { this.original = from; } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java 
b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java index 27e9a4b..5ee0e88 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java @@ -15,20 +15,76 @@ */ package com.google.cloud.dataflow.sdk.runners.inprocess; +import com.google.cloud.dataflow.sdk.Pipeline; import com.google.cloud.dataflow.sdk.options.ApplicationNameOptions; import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.Hidden; import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.Validation.Required; +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +import com.fasterxml.jackson.annotation.JsonIgnore; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** * Options that can be used to configure the {@link InProcessPipelineRunner}. */ public interface InProcessPipelineOptions extends PipelineOptions, ApplicationNameOptions { + /** + * Gets the {@link ExecutorServiceFactory} to use to create instances of {@link ExecutorService} + * to execute {@link PTransform PTransforms}. + * + * <p>Note that {@link ExecutorService ExecutorServices} returned by the factory must ensure that + * it cannot enter a state in which it will not schedule additional pending work unless currently + * scheduled work completes, as this may cause the {@link Pipeline} to cease processing. + * + * <p>Defaults to a {@link CachedThreadPoolExecutorServiceFactory}, which produces instances of + * {@link Executors#newCachedThreadPool()}. 
+ */ + @JsonIgnore + @Required + @Hidden + @Default.InstanceFactory(CachedThreadPoolExecutorServiceFactory.class) + ExecutorServiceFactory getExecutorServiceFactory(); + + void setExecutorServiceFactory(ExecutorServiceFactory executorService); + + /** + * Gets the {@link Clock} used by this pipeline. The clock is used in place of accessing the + * system time when time values are required by the evaluator. + */ @Default.InstanceFactory(NanosOffsetClock.Factory.class) + @JsonIgnore + @Required + @Hidden + @Description( + "The processing time source used by the pipeline. When the current time is " + + "needed by the evaluator, the result of clock#now() is used.") Clock getClock(); void setClock(Clock clock); + @Default.Boolean(false) + @Description( + "If the pipeline should shut down producers which have reached the maximum " + + "representable watermark. If this is set to true, a pipeline in which all PTransforms " + + "have reached the maximum watermark will be shut down, even if there are unbounded " + + "sources that could produce additional (late) data. By default, if the pipeline " + + "contains any unbounded PCollections, it will run until explicitly shut down.") boolean isShutdownUnboundedProducersWithMaxWatermark(); void setShutdownUnboundedProducersWithMaxWatermark(boolean shutdown); + + @Default.Boolean(true) + @Description( + "If the pipeline should block awaiting completion of the pipeline. If set to true, " + + "a call to Pipeline#run() will block until all PTransforms are complete. Otherwise, " + + "the Pipeline will execute asynchronously. 
If set to false, the completion of the " + + "pipeline can be awaited on by use of InProcessPipelineResult#awaitCompletion().") + boolean isBlockOnRun(); + + void setBlockOnRun(boolean b); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java index 32859da..a1c8756 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java @@ -1,5 +1,5 @@ /* - * Copyright (C) 2015 Google Inc. + * Copyright (C) 2016 Google Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. 
You may obtain a copy of @@ -15,25 +15,46 @@ */ package com.google.cloud.dataflow.sdk.runners.inprocess; -import static com.google.common.base.Preconditions.checkArgument; - +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineExecutionException; +import com.google.cloud.dataflow.sdk.PipelineResult; import com.google.cloud.dataflow.sdk.annotations.Experimental; import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.AggregatorPipelineExtractor; +import com.google.cloud.dataflow.sdk.runners.AggregatorRetrievalException; +import com.google.cloud.dataflow.sdk.runners.AggregatorValues; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; import com.google.cloud.dataflow.sdk.runners.inprocess.GroupByKeyEvaluatorFactory.InProcessGroupByKey; -import com.google.cloud.dataflow.sdk.runners.inprocess.ViewEvaluatorFactory.InProcessCreatePCollectionView; +import com.google.cloud.dataflow.sdk.runners.inprocess.GroupByKeyEvaluatorFactory.InProcessGroupByKeyOnly; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Create; import com.google.cloud.dataflow.sdk.transforms.GroupByKey; import com.google.cloud.dataflow.sdk.transforms.PTransform; import com.google.cloud.dataflow.sdk.transforms.View.CreatePCollectionView; +import com.google.cloud.dataflow.sdk.util.InstanceBuilder; +import com.google.cloud.dataflow.sdk.util.MapAggregatorValues; import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.UserCodeException; import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; import com.google.cloud.dataflow.sdk.values.PCollection; +import 
com.google.cloud.dataflow.sdk.values.PCollection.IsBounded; import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import org.joda.time.Instant; +import java.util.Collection; +import java.util.HashMap; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; import javax.annotation.Nullable; @@ -42,28 +63,25 @@ import javax.annotation.Nullable; * {@link PCollection PCollections}. */ @Experimental -public class InProcessPipelineRunner { - @SuppressWarnings({"rawtypes", "unused"}) +public class InProcessPipelineRunner + extends PipelineRunner<InProcessPipelineRunner.InProcessPipelineResult> { + /** + * The default set of transform overrides to use in the {@link InProcessPipelineRunner}. + * + * <p>A transform override must have a single-argument constructor that takes an instance of the + * type of transform it is overriding. + */ + @SuppressWarnings("rawtypes") private static Map<Class<? extends PTransform>, Class<? extends PTransform>> defaultTransformOverrides = ImmutableMap.<Class<? extends PTransform>, Class<? extends PTransform>>builder() + .put(Create.Values.class, InProcessCreate.class) .put(GroupByKey.class, InProcessGroupByKey.class) - .put(CreatePCollectionView.class, InProcessCreatePCollectionView.class) + .put( + CreatePCollectionView.class, + ViewEvaluatorFactory.InProcessCreatePCollectionView.class) .build(); - private static Map<Class<?>, TransformEvaluatorFactory> defaultEvaluatorFactories = - new ConcurrentHashMap<>(); - - /** - * Register a default transform evaluator. 
- */ - public static <TransformT extends PTransform<?, ?>> void registerTransformEvaluatorFactory( - Class<TransformT> clazz, TransformEvaluatorFactory evaluator) { - checkArgument(defaultEvaluatorFactories.put(clazz, evaluator) == null, - "Defining a default factory %s to evaluate Transforms of type %s multiple times", evaluator, - clazz); - } - /** * Part of a {@link PCollection}. Elements are output to a bundle, which will cause them to be * executed by {@link PTransform PTransforms} that consume the {@link PCollection} this bundle is @@ -73,7 +91,7 @@ public class InProcessPipelineRunner { */ public static interface UncommittedBundle<T> { /** - * Returns the PCollection that the elements of this bundle belong to. + * Returns the PCollection that the elements of this {@link UncommittedBundle} belong to. */ PCollection<T> getPCollection(); @@ -103,14 +121,13 @@ public class InProcessPipelineRunner { * @param <T> the type of elements contained within this bundle */ public static interface CommittedBundle<T> { - /** * Returns the PCollection that the elements of this bundle belong to. */ PCollection<T> getPCollection(); /** - * Returns weather this bundle is keyed. A bundle that is part of a {@link PCollection} that + * Returns whether this bundle is keyed. A bundle that is part of a {@link PCollection} that * occurs after a {@link GroupByKey} is keyed by the result of the last {@link GroupByKey}. */ boolean isKeyed(); @@ -119,11 +136,12 @@ public class InProcessPipelineRunner { * Returns the (possibly null) key that was output in the most recent {@link GroupByKey} in the * execution of this bundle. */ - @Nullable Object getKey(); + @Nullable + Object getKey(); /** - * @return an {@link Iterable} containing all of the elements that have been added to this - * {@link CommittedBundle} + * Returns an {@link Iterable} containing all of the elements that have been added to this + * {@link CommittedBundle}. 
*/ Iterable<WindowedValue<T>> getElements(); @@ -166,4 +184,160 @@ public class InProcessPipelineRunner { public InProcessPipelineOptions getPipelineOptions() { return options; } + + @Override + public <OutputT extends POutput, InputT extends PInput> OutputT apply( + PTransform<InputT, OutputT> transform, InputT input) { + Class<?> overrideClass = defaultTransformOverrides.get(transform.getClass()); + if (overrideClass != null) { + // It is the responsibility of whoever constructs overrides to ensure this is type safe. + @SuppressWarnings("unchecked") + Class<PTransform<InputT, OutputT>> transformClass = + (Class<PTransform<InputT, OutputT>>) transform.getClass(); + + @SuppressWarnings("unchecked") + Class<PTransform<InputT, OutputT>> customTransformClass = + (Class<PTransform<InputT, OutputT>>) overrideClass; + + PTransform<InputT, OutputT> customTransform = + InstanceBuilder.ofType(customTransformClass) + .withArg(transformClass, transform) + .build(); + + // This overrides the contents of the apply method without changing the TransformTreeNode that + // is generated by the PCollection application. + return super.apply(customTransform, input); + } else { + return super.apply(transform, input); + } + } + + @Override + public InProcessPipelineResult run(Pipeline pipeline) { + ConsumerTrackingPipelineVisitor consumerTrackingVisitor = new ConsumerTrackingPipelineVisitor(); + pipeline.traverseTopologically(consumerTrackingVisitor); + for (PValue unfinalized : consumerTrackingVisitor.getUnfinalizedPValues()) { + unfinalized.finishSpecifying(); + } + @SuppressWarnings("rawtypes") + KeyedPValueTrackingVisitor keyedPValueVisitor = + KeyedPValueTrackingVisitor.create( + ImmutableSet.<Class<? 
extends PTransform>>of( + GroupByKey.class, InProcessGroupByKeyOnly.class)); + pipeline.traverseTopologically(keyedPValueVisitor); + + InProcessEvaluationContext context = + InProcessEvaluationContext.create( + getPipelineOptions(), + consumerTrackingVisitor.getRootTransforms(), + consumerTrackingVisitor.getValueToConsumers(), + consumerTrackingVisitor.getStepNames(), + consumerTrackingVisitor.getViews()); + + // independent executor service for each run + ExecutorService executorService = + context.getPipelineOptions().getExecutorServiceFactory().create(); + InProcessExecutor executor = + ExecutorServiceParallelExecutor.create( + executorService, + consumerTrackingVisitor.getValueToConsumers(), + keyedPValueVisitor.getKeyedPValues(), + TransformEvaluatorRegistry.defaultRegistry(), + context); + executor.start(consumerTrackingVisitor.getRootTransforms()); + + Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps = + new AggregatorPipelineExtractor(pipeline).getAggregatorSteps(); + InProcessPipelineResult result = + new InProcessPipelineResult(executor, context, aggregatorSteps); + if (options.isBlockOnRun()) { + try { + result.awaitCompletion(); + } catch (UserCodeException userException) { + throw new PipelineExecutionException(userException.getCause()); + } catch (Throwable t) { + Throwables.propagate(t); + } + } + return result; + } + + /** + * The result of running a {@link Pipeline} with the {@link InProcessPipelineRunner}. + * + * Throws {@link UnsupportedOperationException} for all methods. 
+ */ + public static class InProcessPipelineResult implements PipelineResult { + private final InProcessExecutor executor; + private final InProcessEvaluationContext evaluationContext; + private final Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps; + private State state; + + private InProcessPipelineResult( + InProcessExecutor executor, + InProcessEvaluationContext evaluationContext, + Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps) { + this.executor = executor; + this.evaluationContext = evaluationContext; + this.aggregatorSteps = aggregatorSteps; + // Only ever constructed after the executor has started. + this.state = State.RUNNING; + } + + @Override + public State getState() { + return state; + } + + @Override + public <T> AggregatorValues<T> getAggregatorValues(Aggregator<?, T> aggregator) + throws AggregatorRetrievalException { + CounterSet counters = evaluationContext.getCounters(); + Collection<PTransform<?, ?>> steps = aggregatorSteps.get(aggregator); + Map<String, T> stepValues = new HashMap<>(); + for (AppliedPTransform<?, ?, ?> transform : evaluationContext.getSteps()) { + if (steps.contains(transform.getTransform())) { + String stepName = + String.format( + "user-%s-%s", evaluationContext.getStepName(transform), aggregator.getName()); + Counter<T> counter = (Counter<T>) counters.getExistingCounter(stepName); + if (counter != null) { + stepValues.put(transform.getFullName(), counter.getAggregate()); + } + } + } + return new MapAggregatorValues<>(stepValues); + } + + /** + * Blocks until the {@link Pipeline} execution represented by this + * {@link InProcessPipelineResult} is complete, returning the terminal state. + * + * <p>If the pipeline terminates abnormally by throwing an exception, this will rethrow the + * exception. Future calls to {@link #getState()} will return + * {@link com.google.cloud.dataflow.sdk.PipelineResult.State#FAILED}. 
+ * + * <p>NOTE: if the {@link Pipeline} contains an {@link IsBounded#UNBOUNDED unbounded} + * {@link PCollection}, and the {@link PipelineRunner} was created with + * {@link InProcessPipelineOptions#isShutdownUnboundedProducersWithMaxWatermark()} set to false, + * this method will never return. + * + * See also {@link InProcessExecutor#awaitCompletion()}. + */ + public State awaitCompletion() throws Throwable { + if (!state.isTerminal()) { + try { + executor.awaitCompletion(); + state = State.DONE; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw e; + } catch (Throwable t) { + state = State.FAILED; + throw t; + } + } + return state; + } + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitor.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitor.java new file mode 100644 index 0000000..23a8c0f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitor.java @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static com.google.common.base.Preconditions.checkState; + +import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor; +import com.google.cloud.dataflow.sdk.runners.TransformTreeNode; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PValue; + +import java.util.HashSet; +import java.util.Set; + +/** + * A pipeline visitor that tracks all keyed {@link PValue PValues}. A {@link PValue} is keyed if it + * is the result of a {@link PTransform} that produces keyed outputs. A {@link PTransform} that + * produces keyed outputs is assumed to colocate output elements that share a key. + * + * <p>All {@link GroupByKey} transforms, or their runner-specific implementation primitive, produce + * keyed output. + */ +// TODO: Handle Key-preserving transforms when appropriate and more aggressively make PTransforms +// unkeyed +class KeyedPValueTrackingVisitor implements PipelineVisitor { + @SuppressWarnings("rawtypes") + private final Set<Class<? extends PTransform>> producesKeyedOutputs; + private final Set<PValue> keyedValues; + private boolean finalized; + + public static KeyedPValueTrackingVisitor create( + @SuppressWarnings("rawtypes") Set<Class<? extends PTransform>> producesKeyedOutputs) { + return new KeyedPValueTrackingVisitor(producesKeyedOutputs); + } + + private KeyedPValueTrackingVisitor( + @SuppressWarnings("rawtypes") Set<Class<? 
extends PTransform>> producesKeyedOutputs) { + this.producesKeyedOutputs = producesKeyedOutputs; + this.keyedValues = new HashSet<>(); + } + + @Override + public void enterCompositeTransform(TransformTreeNode node) { + checkState( + !finalized, + "Attempted to use a %s that has already been finalized on a pipeline (visiting node %s)", + KeyedPValueTrackingVisitor.class.getSimpleName(), + node); + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { + checkState( + !finalized, + "Attempted to use a %s that has already been finalized on a pipeline (visiting node %s)", + KeyedPValueTrackingVisitor.class.getSimpleName(), + node); + if (node.isRootNode()) { + finalized = true; + } else if (producesKeyedOutputs.contains(node.getTransform().getClass())) { + keyedValues.addAll(node.getExpandedOutputs()); + } + } + + @Override + public void visitTransform(TransformTreeNode node) {} + + @Override + public void visitValue(PValue value, TransformTreeNode producer) { + if (producesKeyedOutputs.contains(producer.getTransform().getClass())) { + keyedValues.addAll(value.expand()); + } + } + + public Set<PValue> getKeyedPValues() { + checkState( + finalized, "can't call getKeyedPValues before a Pipeline has been completely traversed"); + return keyedValues; + } +} + http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitorTest.java ---------------------------------------------------------------------- diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitorTest.java new file mode 100644 index 0000000..d921f6c --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitorTest.java @@ -0,0 +1,233 @@ +/* + * Copyright (C) 2016 Google Inc. 
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package com.google.cloud.dataflow.sdk.runners.inprocess;

import static org.hamcrest.Matchers.emptyIterable;
import static org.junit.Assert.assertThat;

import com.google.cloud.dataflow.sdk.io.CountingInput;
import com.google.cloud.dataflow.sdk.testing.TestPipeline;
import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform;
import com.google.cloud.dataflow.sdk.transforms.Create;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.Flatten;
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.transforms.ParDo;
import com.google.cloud.dataflow.sdk.transforms.View;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.dataflow.sdk.values.PCollectionList;
import com.google.cloud.dataflow.sdk.values.PCollectionView;
import com.google.cloud.dataflow.sdk.values.PDone;
import com.google.cloud.dataflow.sdk.values.PInput;
import com.google.cloud.dataflow.sdk.values.PValue;

import org.hamcrest.Matchers;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

import java.io.Serializable;
import java.util.List;

/**
 * Tests for {@link ConsumerTrackingPipelineVisitor}.
 */
@RunWith(JUnit4.class)
public class ConsumerTrackingPipelineVisitorTest implements Serializable {
  @Rule public transient ExpectedException thrown = ExpectedException.none();

  // A fresh pipeline and visitor per test; the visitor is single-use by contract.
  private transient TestPipeline p = TestPipeline.create();
  private transient ConsumerTrackingPipelineVisitor visitor = new ConsumerTrackingPipelineVisitor();

  // The visitor should report every PCollectionView created in the pipeline.
  @Test
  public void getViewsReturnsViews() {
    PCollectionView<List<String>> listView =
        p.apply("listCreate", Create.of("foo", "bar"))
            .apply(
                ParDo.of(
                    new DoFn<String, String>() {
                      @Override
                      public void processElement(DoFn<String, String>.ProcessContext c)
                          throws Exception {
                        c.output(Integer.toString(c.element().length()));
                      }
                    }))
            .apply(View.<String>asList());
    PCollectionView<Object> singletonView =
        p.apply("singletonCreate", Create.<Object>of(1, 2, 3)).apply(View.<Object>asSingleton());
    p.traverseTopologically(visitor);
    assertThat(
        visitor.getViews(),
        Matchers.<PCollectionView<?>>containsInAnyOrder(listView, singletonView));
  }

  // Transforms applied directly to the pipeline (PBegin inputs) are roots, whether bounded,
  // unbounded, or created in-memory.
  @Test
  public void getRootTransformsContainsPBegins() {
    PCollection<String> created = p.apply(Create.of("foo", "bar"));
    PCollection<Long> counted = p.apply(CountingInput.upTo(1234L));
    PCollection<Long> unCounted = p.apply(CountingInput.unbounded());
    p.traverseTopologically(visitor);
    assertThat(
        visitor.getRootTransforms(),
        Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder(
            created.getProducingTransformInternal(),
            counted.getProducingTransformInternal(),
            unCounted.getProducingTransformInternal()));
  }

  // A Flatten of an empty PCollectionList consumes no inputs, so it also counts as a root.
  @Test
  public void getRootTransformsContainsEmptyFlatten() {
    PCollection<String> empty =
        PCollectionList.<String>empty(p).apply(Flatten.<String>pCollections());
    p.traverseTopologically(visitor);
    assertThat(
        visitor.getRootTransforms(),
        Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder(
            empty.getProducingTransformInternal()));
  }

  // Each produced value maps to exactly the applied transforms that consume it.
  @Test
  public void getValueToConsumersSucceeds() {
    PCollection<String> created = p.apply(Create.of("1", "2", "3"));
    PCollection<String> transformed =
        created.apply(
            ParDo.of(
                new DoFn<String, String>() {
                  @Override
                  public void processElement(DoFn<String, String>.ProcessContext c)
                      throws Exception {
                    c.output(Integer.toString(c.element().length()));
                  }
                }));

    PCollection<String> flattened =
        PCollectionList.of(created).and(transformed).apply(Flatten.<String>pCollections());

    p.traverseTopologically(visitor);

    assertThat(
        visitor.getValueToConsumers().get(created),
        Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder(
            transformed.getProducingTransformInternal(),
            flattened.getProducingTransformInternal()));
    assertThat(
        visitor.getValueToConsumers().get(transformed),
        Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder(
            flattened.getProducingTransformInternal()));
    assertThat(visitor.getValueToConsumers().get(flattened), emptyIterable());
  }

  // A PCollection that nothing consumes is reported as unfinalized.
  @Test
  public void getUnfinalizedPValuesContainsDanglingOutputs() {
    PCollection<String> created = p.apply(Create.of("1", "2", "3"));
    PCollection<String> transformed =
        created.apply(
            ParDo.of(
                new DoFn<String, String>() {
                  @Override
                  public void processElement(DoFn<String, String>.ProcessContext c)
                      throws Exception {
                    c.output(Integer.toString(c.element().length()));
                  }
                }));

    p.traverseTopologically(visitor);
    assertThat(visitor.getUnfinalizedPValues(), Matchers.<PValue>contains(transformed));
  }

  // Terminating the chain with a PDone-producing transform leaves nothing unfinalized.
  @Test
  public void getUnfinalizedPValuesEmpty() {
    p.apply(Create.of("1", "2", "3"))
        .apply(
            ParDo.of(
                new DoFn<String, String>() {
                  @Override
                  public void processElement(DoFn<String, String>.ProcessContext c)
                      throws Exception {
                    c.output(Integer.toString(c.element().length()));
                  }
                }))
        .apply(
            new PTransform<PInput, PDone>() {
              @Override
              public PDone apply(PInput input) {
                return PDone.in(input.getPipeline());
              }
            });

    p.traverseTopologically(visitor);
    assertThat(visitor.getUnfinalizedPValues(), emptyIterable());
  }

  // The visitor is single-use; traversing a second time must fail fast.
  @Test
  public void traverseMultipleTimesThrows() {
    p.apply(Create.of(1, 2, 3));

    p.traverseTopologically(visitor);
    thrown.expect(IllegalStateException.class);
    thrown.expectMessage(ConsumerTrackingPipelineVisitor.class.getSimpleName());
    thrown.expectMessage("is finalized");
    p.traverseTopologically(visitor);
  }

  // Disconnected subgraphs within one pipeline traverse without error.
  @Test
  public void traverseIndependentPathsSucceeds() {
    p.apply("left", Create.of(1, 2, 3));
    p.apply("right", Create.of("foo", "bar", "baz"));

    p.traverseTopologically(visitor);
  }

  // Each accessor below must reject calls made before any traversal has completed.
  @Test
  public void getRootTransformsWithoutVisitingThrows() {
    thrown.expect(IllegalStateException.class);
    thrown.expectMessage("completely traversed");
    thrown.expectMessage("getRootTransforms");
    visitor.getRootTransforms();
  }

  @Test
  public void getStepNamesWithoutVisitingThrows() {
    thrown.expect(IllegalStateException.class);
    thrown.expectMessage("completely traversed");
    thrown.expectMessage("getStepNames");
    visitor.getStepNames();
  }

  @Test
  public void getUnfinalizedPValuesWithoutVisitingThrows() {
    thrown.expect(IllegalStateException.class);
    thrown.expectMessage("completely traversed");
    thrown.expectMessage("getUnfinalizedPValues");
    visitor.getUnfinalizedPValues();
  }

  @Test
  public void getValueToConsumersWithoutVisitingThrows() {
    thrown.expect(IllegalStateException.class);
    thrown.expectMessage("completely traversed");
    thrown.expectMessage("getValueToConsumers");
    visitor.getValueToConsumers();
  }

  @Test
  public void getViewsWithoutVisitingThrows() {
    thrown.expect(IllegalStateException.class);
    thrown.expectMessage("completely traversed");
    thrown.expectMessage("getViews");
    visitor.getViews();
  }
}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java
----------------------------------------------------------------------
diff --git
a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java new file mode 100644 index 0000000..adb64cd --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessPipelineResult; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.MapElements; +import com.google.cloud.dataflow.sdk.transforms.SimpleFunction; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; + +/** + * Tests for basic {@link InProcessPipelineRunner} functionality. 
+ */ +@RunWith(JUnit4.class) +public class InProcessPipelineRunnerTest implements Serializable { + @Test + public void wordCountShouldSucceed() throws Throwable { + Pipeline p = getPipeline(); + + PCollection<KV<String, Long>> counts = + p.apply(Create.of("foo", "bar", "foo", "baz", "bar", "foo")) + .apply(MapElements.via(new SimpleFunction<String, String>() { + @Override + public String apply(String input) { + return input; + } + })) + .apply(Count.<String>perElement()); + PCollection<String> countStrs = + counts.apply(MapElements.via(new SimpleFunction<KV<String, Long>, String>() { + @Override + public String apply(KV<String, Long> input) { + String str = String.format("%s: %s", input.getKey(), input.getValue()); + return str; + } + })); + + DataflowAssert.that(countStrs).containsInAnyOrder("baz: 1", "bar: 2", "foo: 3"); + + InProcessPipelineResult result = ((InProcessPipelineResult) p.run()); + result.awaitCompletion(); + } + + private Pipeline getPipeline() { + PipelineOptions opts = PipelineOptionsFactory.create(); + opts.setRunner(InProcessPipelineRunner.class); + + Pipeline p = Pipeline.create(opts); + return p; + } +} + http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitorTest.java ---------------------------------------------------------------------- diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitorTest.java new file mode 100644 index 0000000..0aaccc2 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitorTest.java @@ -0,0 +1,189 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package com.google.cloud.dataflow.sdk.runners.inprocess;

import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.not;
import static org.junit.Assert.assertThat;

import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.coders.IterableCoder;
import com.google.cloud.dataflow.sdk.coders.KvCoder;
import com.google.cloud.dataflow.sdk.coders.VarIntCoder;
import com.google.cloud.dataflow.sdk.coders.VoidCoder;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
import com.google.cloud.dataflow.sdk.transforms.Create;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.GroupByKey;
import com.google.cloud.dataflow.sdk.transforms.Keys;
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.transforms.ParDo;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.common.collect.ImmutableSet;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

import java.util.Collections;
import java.util.Set;

/**
 * Tests for {@link KeyedPValueTrackingVisitor}.
 */
@RunWith(JUnit4.class)
public class KeyedPValueTrackingVisitorTest {
  @Rule public ExpectedException thrown = ExpectedException.none();

  private KeyedPValueTrackingVisitor visitor;
  private Pipeline p;

  @Before
  public void setup() {
    PipelineOptions options = PipelineOptionsFactory.create();

    p = Pipeline.create(options);
    // Register both a primitive and a composite transform as producing keyed outputs.
    @SuppressWarnings("rawtypes")
    Set<Class<? extends PTransform>> producesKeyed =
        ImmutableSet.<Class<? extends PTransform>>of(PrimitiveKeyer.class, CompositeKeyer.class);
    visitor = KeyedPValueTrackingVisitor.create(producesKeyed);
  }

  // A registered primitive transform keys its output even from an unkeyed input.
  @Test
  public void primitiveProducesKeyedOutputUnkeyedInputKeyedOutput() {
    PCollection<Integer> keyed =
        p.apply(Create.<Integer>of(1, 2, 3)).apply(new PrimitiveKeyer<Integer>());

    p.traverseTopologically(visitor);
    assertThat(visitor.getKeyedPValues(), hasItem(keyed));
  }

  // Fixed typo in the method name: "Outut" -> "Output".
  @Test
  public void primitiveProducesKeyedOutputKeyedInputKeyedOutput() {
    PCollection<Integer> keyed =
        p.apply(Create.<Integer>of(1, 2, 3))
            .apply("firstKey", new PrimitiveKeyer<Integer>())
            .apply("secondKey", new PrimitiveKeyer<Integer>());

    p.traverseTopologically(visitor);
    assertThat(visitor.getKeyedPValues(), hasItem(keyed));
  }

  // A registered composite transform keys its output even from an unkeyed input.
  @Test
  public void compositeProducesKeyedOutputUnkeyedInputKeyedOutput() {
    PCollection<Integer> keyed =
        p.apply(Create.<Integer>of(1, 2, 3)).apply(new CompositeKeyer<Integer>());

    p.traverseTopologically(visitor);
    assertThat(visitor.getKeyedPValues(), hasItem(keyed));
  }

  // Fixed typo in the method name: "Outut" -> "Output".
  @Test
  public void compositeProducesKeyedOutputKeyedInputKeyedOutput() {
    PCollection<Integer> keyed =
        p.apply(Create.<Integer>of(1, 2, 3))
            .apply("firstKey", new CompositeKeyer<Integer>())
            .apply("secondKey", new CompositeKeyer<Integer>());

    p.traverseTopologically(visitor);
    assertThat(visitor.getKeyedPValues(), hasItem(keyed));
  }

  // A KV-shaped element does not make a value keyed; only registered producers do.
  @Test
  public void noInputUnkeyedOutput() {
    PCollection<KV<Integer, Iterable<Void>>> unkeyed =
        p.apply(
            Create.of(KV.<Integer, Iterable<Void>>of(-1, Collections.<Void>emptyList()))
                .withCoder(KvCoder.of(VarIntCoder.of(), IterableCoder.of(VoidCoder.of()))));

    p.traverseTopologically(visitor);
    assertThat(visitor.getKeyedPValues(), not(hasItem(unkeyed)));
  }

  // An unregistered transform downstream of a keyer produces unkeyed output.
  @Test
  public void keyedInputNotProducesKeyedOutputUnkeyedOutput() {
    PCollection<Integer> onceKeyed =
        p.apply(Create.<Integer>of(1, 2, 3))
            .apply(new PrimitiveKeyer<Integer>())
            .apply(ParDo.of(new IdentityFn<Integer>()));

    p.traverseTopologically(visitor);
    assertThat(visitor.getKeyedPValues(), not(hasItem(onceKeyed)));
  }

  @Test
  public void unkeyedInputNotProducesKeyedOutputUnkeyedOutput() {
    PCollection<Integer> unkeyed =
        p.apply(Create.<Integer>of(1, 2, 3)).apply(ParDo.of(new IdentityFn<Integer>()));

    p.traverseTopologically(visitor);
    assertThat(visitor.getKeyedPValues(), not(hasItem(unkeyed)));
  }

  // The visitor is single-use; a second traversal must fail fast.
  @Test
  public void traverseMultipleTimesThrows() {
    p.apply(
            Create.<KV<Integer, Void>>of(
                    KV.of(1, (Void) null), KV.of(2, (Void) null), KV.of(3, (Void) null))
                .withCoder(KvCoder.of(VarIntCoder.of(), VoidCoder.of())))
        .apply(GroupByKey.<Integer, Void>create())
        .apply(Keys.<Integer>create());

    p.traverseTopologically(visitor);

    thrown.expect(IllegalStateException.class);
    thrown.expectMessage("already been finalized");
    thrown.expectMessage(KeyedPValueTrackingVisitor.class.getSimpleName());
    p.traverseTopologically(visitor);
  }

  @Test
  public void getKeyedPValuesBeforeTraverseThrows() {
    thrown.expect(IllegalStateException.class);
    thrown.expectMessage("completely traversed");
    thrown.expectMessage("getKeyedPValues");
    visitor.getKeyedPValues();
  }

  /** A primitive transform registered as producing keyed output. */
  private static class PrimitiveKeyer<K> extends PTransform<PCollection<K>, PCollection<K>> {
    @Override
    public PCollection<K> apply(PCollection<K> input) {
      return PCollection.<K>createPrimitiveOutputInternal(
              input.getPipeline(), input.getWindowingStrategy(), input.isBounded())
          .setCoder(input.getCoder());
    }
  }

  /** A composite transform registered as producing keyed output. */
  private static class CompositeKeyer<K> extends PTransform<PCollection<K>, PCollection<K>> {
    @Override
    public PCollection<K> apply(PCollection<K> input) {
      return input.apply(new PrimitiveKeyer<K>()).apply(ParDo.of(new IdentityFn<K>()));
    }
  }

  /** Passes every element through unchanged; not registered as a keyed producer. */
  private static class IdentityFn<K> extends DoFn<K, K> {
    @Override
    public void processElement(DoFn<K, K>.ProcessContext c) throws Exception {
      c.output(c.element());
    }
  }
}