Move responsibility for knowing about keyedness into EvaluationContext

This will allow transform evaluators to inquire about whether
various collections are keyed.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/b26ceaa3
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/b26ceaa3
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/b26ceaa3

Branch: refs/heads/python-sdk
Commit: b26ceaa347c4bc50abfb4c3c138167a25a99cf57
Parents: 81702e6
Author: Kenneth Knowles <k...@google.com>
Authored: Thu Dec 8 13:28:44 2016 -0800
Committer: Kenneth Knowles <k...@google.com>
Committed: Tue Dec 20 11:18:04 2016 -0800

----------------------------------------------------------------------
 .../beam/runners/direct/DirectRunner.java       |  4 +--
 .../beam/runners/direct/EvaluationContext.java  | 26 +++++++++++++++++---
 .../direct/ExecutorServiceParallelExecutor.java |  8 +-----
 .../runners/direct/EvaluationContextTest.java   |  9 ++++++-
 4 files changed, 34 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b26ceaa3/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
index afa43ff..7e6ea15 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
@@ -315,14 +315,14 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> {
             getPipelineOptions(),
             clockSupplier.get(),
             Enforcement.bundleFactoryFor(enabledEnforcements, graph),
-            graph);
+            graph,
+            keyedPValueVisitor.getKeyedPValues());
 
     RootProviderRegistry rootInputProvider = 
RootProviderRegistry.defaultRegistry(context);
     TransformEvaluatorRegistry registry = 
TransformEvaluatorRegistry.defaultRegistry(context);
     PipelineExecutor executor =
         ExecutorServiceParallelExecutor.create(
             options.getTargetParallelism(), graph,
-            keyedPValueVisitor.getKeyedPValues(),
             rootInputProvider,
             registry,
             Enforcement.defaultModelEnforcements(enabledEnforcements),

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b26ceaa3/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java
index 230d91b..cb9ddd8 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java
@@ -27,6 +27,7 @@ import java.util.Collection;
 import java.util.EnumSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
 import javax.annotation.Nullable;
@@ -99,17 +100,28 @@ class EvaluationContext {
 
   private final DirectMetrics metrics;
 
+  private final Set<PValue> keyedPValues;
+
   public static EvaluationContext create(
-      DirectOptions options, Clock clock, BundleFactory bundleFactory, 
DirectGraph graph) {
-    return new EvaluationContext(options, clock, bundleFactory, graph);
+      DirectOptions options,
+      Clock clock,
+      BundleFactory bundleFactory,
+      DirectGraph graph,
+      Set<PValue> keyedPValues) {
+    return new EvaluationContext(options, clock, bundleFactory, graph, 
keyedPValues);
   }
 
   private EvaluationContext(
-      DirectOptions options, Clock clock, BundleFactory bundleFactory, 
DirectGraph graph) {
+      DirectOptions options,
+      Clock clock,
+      BundleFactory bundleFactory,
+      DirectGraph graph,
+      Set<PValue> keyedPValues) {
     this.options = checkNotNull(options);
     this.clock = clock;
     this.bundleFactory = checkNotNull(bundleFactory);
     this.graph = checkNotNull(graph);
+    this.keyedPValues = keyedPValues;
 
     this.watermarkManager = WatermarkManager.create(clock, graph);
     this.sideInputContainer = SideInputContainer.create(this, 
graph.getViews());
@@ -244,6 +256,14 @@ class EvaluationContext {
   }
 
   /**
+   * Returns whether the provided {@link PValue} has been determined to be
+   * keyed.
+   */
+  public boolean isKeyed(PValue pValue) {
+    return keyedPValues.contains(pValue);
+  }
+
+  /**
    * Create a {@link PCollectionViewWriter}, whose elements will be used in 
the provided
    * {@link PCollectionView}.
    */

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b26ceaa3/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java
index a308295..5a653b7 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java
@@ -31,7 +31,6 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.Map;
 import java.util.Queue;
-import java.util.Set;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentLinkedQueue;
@@ -70,7 +69,6 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor {
   private final ExecutorService executorService;
 
   private final DirectGraph graph;
-  private final Set<PValue> keyedPValues;
   private final RootProviderRegistry rootProviderRegistry;
   private final TransformEvaluatorRegistry registry;
   @SuppressWarnings("rawtypes")
@@ -105,7 +103,6 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor {
   public static ExecutorServiceParallelExecutor create(
       int targetParallelism,
       DirectGraph graph,
-      Set<PValue> keyedPValues,
       RootProviderRegistry rootProviderRegistry,
       TransformEvaluatorRegistry registry,
       @SuppressWarnings("rawtypes")
@@ -115,7 +112,6 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor {
     return new ExecutorServiceParallelExecutor(
         targetParallelism,
         graph,
-        keyedPValues,
         rootProviderRegistry,
         registry,
         transformEnforcements,
@@ -125,7 +121,6 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor {
   private ExecutorServiceParallelExecutor(
       int targetParallelism,
       DirectGraph graph,
-      Set<PValue> keyedPValues,
       RootProviderRegistry rootProviderRegistry,
       TransformEvaluatorRegistry registry,
       @SuppressWarnings("rawtypes")
@@ -134,7 +129,6 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor {
     this.targetParallelism = targetParallelism;
     this.executorService = Executors.newFixedThreadPool(targetParallelism);
     this.graph = graph;
-    this.keyedPValues = keyedPValues;
     this.rootProviderRegistry = rootProviderRegistry;
     this.registry = registry;
     this.transformEnforcements = transformEnforcements;
@@ -229,7 +223,7 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor {
   }
 
   private boolean isKeyed(PValue pvalue) {
-    return keyedPValues.contains(pvalue);
+    return evaluationContext.isKeyed(pvalue);
   }
 
   private void scheduleConsumers(ExecutorUpdate update) {

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b26ceaa3/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java
index bf36204..15340da 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java
@@ -105,11 +105,18 @@ public class EvaluationContextTest {
     view = created.apply(View.<Integer>asIterable());
     unbounded = p.apply(CountingInput.unbounded());
 
+    KeyedPValueTrackingVisitor keyedPValueTrackingVisitor = 
KeyedPValueTrackingVisitor.create();
+    p.traverseTopologically(keyedPValueTrackingVisitor);
+
     BundleFactory bundleFactory = ImmutableListBundleFactory.create();
     graph = DirectGraphs.getGraph(p);
     context =
         EvaluationContext.create(
-            runner.getPipelineOptions(), NanosOffsetClock.create(), 
bundleFactory, graph);
+            runner.getPipelineOptions(),
+            NanosOffsetClock.create(),
+            bundleFactory,
+            graph,
+            keyedPValueTrackingVisitor.getKeyedPValues());
 
     createdProducer = graph.getProducer(created);
     downstreamProducer = graph.getProducer(downstream);

Reply via email to