Repository: incubator-beam
Updated Branches:
  refs/heads/master 15e93c58e -> abd9fb3d7


Properly apply Transform Overrides in the Direct Runner

Previously, the direct runner used the override transform when calling
.apply(), but kept the original transform in the pipeline graph; for
example, it used the original transform to look up an evaluator.

This commit instead inserts the node produced by applying the override
as a nested node within the graph (and that node may itself be replaced
recursively by further overrides).

Additionally, this commit makes RootInputProvider and its InputProvider
implementations type-safe.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/67ce5313
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/67ce5313
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/67ce5313

Branch: refs/heads/master
Commit: 67ce53139b51617708cdf037d93c5195608accc5
Parents: 15e93c5
Author: Eugene Kirpichov <[email protected]>
Authored: Wed Nov 16 15:40:08 2016 -0800
Committer: Thomas Groh <[email protected]>
Committed: Wed Nov 16 18:20:36 2016 -0800

----------------------------------------------------------------------
 .../direct/BoundedReadEvaluatorFactory.java     | 28 ++++++++---------
 .../beam/runners/direct/DirectRunner.java       |  7 +++--
 .../beam/runners/direct/EmptyInputProvider.java | 22 +++++++-------
 .../direct/ExecutorServiceParallelExecutor.java | 12 ++++----
 .../beam/runners/direct/RootInputProvider.java  | 16 ++++++----
 .../runners/direct/RootProviderRegistry.java    | 19 ++++++------
 .../direct/TestStreamEvaluatorFactory.java      | 23 +++++++-------
 .../direct/TransformEvaluatorRegistry.java      |  5 +--
 .../direct/UnboundedReadEvaluatorFactory.java   | 32 +++++++++++---------
 9 files changed, 85 insertions(+), 79 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/67ce5313/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java
index 8becb91..66c55cd 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java
@@ -35,6 +35,7 @@ import 
org.apache.beam.runners.direct.DirectRunner.UncommittedBundle;
 import org.apache.beam.runners.direct.StepTransformResult.Builder;
 import org.apache.beam.sdk.io.BoundedSource;
 import org.apache.beam.sdk.io.BoundedSource.BoundedReader;
+import org.apache.beam.sdk.io.Read;
 import org.apache.beam.sdk.io.Read.Bounded;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.transforms.AppliedPTransform;
@@ -174,7 +175,8 @@ final class BoundedReadEvaluatorFactory implements 
TransformEvaluatorFactory {
     abstract BoundedSource<T> getSource();
   }
 
-  static class InputProvider implements RootInputProvider {
+  static class InputProvider<T>
+      implements RootInputProvider<T, BoundedSourceShard<T>, PBegin, 
Read.Bounded<T>> {
     private final EvaluationContext evaluationContext;
 
     InputProvider(EvaluationContext evaluationContext) {
@@ -182,27 +184,21 @@ final class BoundedReadEvaluatorFactory implements 
TransformEvaluatorFactory {
     }
 
     @Override
-    public Collection<CommittedBundle<?>> getInitialInputs(
-        AppliedPTransform<?, ?, ?> transform, int targetParallelism) throws 
Exception {
-      return createInitialSplits((AppliedPTransform) transform, 
targetParallelism);
-    }
-
-    private <OutputT>
-        Collection<CommittedBundle<BoundedSourceShard<OutputT>>> 
createInitialSplits(
-            AppliedPTransform<PBegin, ?, Bounded<OutputT>> transform, int 
targetParallelism)
-            throws Exception {
-      BoundedSource<OutputT> source = transform.getTransform().getSource();
+    public Collection<CommittedBundle<BoundedSourceShard<T>>> getInitialInputs(
+        AppliedPTransform<PBegin, PCollection<T>, Read.Bounded<T>> transform, 
int targetParallelism)
+        throws Exception {
+      BoundedSource<T> source = transform.getTransform().getSource();
       PipelineOptions options = evaluationContext.getPipelineOptions();
       long estimatedBytes = source.getEstimatedSizeBytes(options);
       long bytesPerBundle = estimatedBytes / targetParallelism;
-      List<? extends BoundedSource<OutputT>> bundles =
+      List<? extends BoundedSource<T>> bundles =
           source.splitIntoBundles(bytesPerBundle, options);
-      ImmutableList.Builder<CommittedBundle<BoundedSourceShard<OutputT>>> 
shards =
+      ImmutableList.Builder<CommittedBundle<BoundedSourceShard<T>>> shards =
           ImmutableList.builder();
-      for (BoundedSource<OutputT> bundle : bundles) {
-        CommittedBundle<BoundedSourceShard<OutputT>> inputShard =
+      for (BoundedSource<T> bundle : bundles) {
+        CommittedBundle<BoundedSourceShard<T>> inputShard =
             evaluationContext
-                .<BoundedSourceShard<OutputT>>createRootBundle()
+                .<BoundedSourceShard<T>>createRootBundle()
                 
.add(WindowedValue.valueInGlobalWindow(BoundedSourceShard.of(bundle)))
                 .commit(BoundedWindow.TIMESTAMP_MAX_VALUE);
         shards.add(inputShard);

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/67ce5313/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
index 04c8eb6..cce73c3 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
@@ -289,8 +289,9 @@ public class DirectRunner
     PTransformOverrideFactory overrideFactory = 
defaultTransformOverrides.get(transform.getClass());
     if (overrideFactory != null) {
       PTransform<InputT, OutputT> customTransform = 
overrideFactory.override(transform);
-
-      return super.apply(customTransform, input);
+      if (customTransform != transform) {
+        return Pipeline.applyTransform(transform.getName(), input, 
customTransform);
+      }
     }
     // If there is no override, or we should not apply the override, apply the 
original transform
     return super.apply(transform, input);
@@ -323,7 +324,7 @@ public class DirectRunner
             consumerTrackingVisitor.getStepNames(),
             consumerTrackingVisitor.getViews());
 
-    RootInputProvider rootInputProvider = 
RootProviderRegistry.defaultRegistry(context);
+    RootProviderRegistry rootInputProvider = 
RootProviderRegistry.defaultRegistry(context);
     TransformEvaluatorRegistry registry = 
TransformEvaluatorRegistry.defaultRegistry(context);
     PipelineExecutor executor =
         ExecutorServiceParallelExecutor.create(

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/67ce5313/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EmptyInputProvider.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EmptyInputProvider.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EmptyInputProvider.java
index 1058943..1185130 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EmptyInputProvider.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EmptyInputProvider.java
@@ -21,16 +21,14 @@ import java.util.Collection;
 import java.util.Collections;
 import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
 import org.apache.beam.sdk.transforms.AppliedPTransform;
+import org.apache.beam.sdk.transforms.Flatten;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionList;
 
-/**
- * A {@link RootInputProvider} that provides a singleton empty bundle.
- */
-class EmptyInputProvider implements RootInputProvider {
-  private final EvaluationContext evaluationContext;
-
-  EmptyInputProvider(EvaluationContext evaluationContext) {
-    this.evaluationContext = evaluationContext;
-  }
+/** A {@link RootInputProvider} that provides a singleton empty bundle. */
+class EmptyInputProvider<T>
+    implements RootInputProvider<T, Void, PCollectionList<T>, 
Flatten.FlattenPCollectionList<T>> {
+  EmptyInputProvider() {}
 
   /**
    * {@inheritDoc}.
@@ -38,8 +36,10 @@ class EmptyInputProvider implements RootInputProvider {
    * <p>Returns an empty collection.
    */
   @Override
-  public Collection<CommittedBundle<?>> getInitialInputs(
-      AppliedPTransform<?, ?, ?> transform, int targetParallelism) {
+  public Collection<CommittedBundle<Void>> getInitialInputs(
+      AppliedPTransform<PCollectionList<T>, PCollection<T>, 
Flatten.FlattenPCollectionList<T>>
+          transform,
+      int targetParallelism) {
     return Collections.emptyList();
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/67ce5313/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java
index 0bb3d01..05cdd34 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java
@@ -71,7 +71,7 @@ final class ExecutorServiceParallelExecutor implements 
PipelineExecutor {
 
   private final Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> 
valueToConsumers;
   private final Set<PValue> keyedPValues;
-  private final RootInputProvider rootInputProvider;
+  private final RootProviderRegistry rootProviderRegistry;
   private final TransformEvaluatorRegistry registry;
   @SuppressWarnings("rawtypes")
   private final Map<Class<? extends PTransform>, 
Collection<ModelEnforcementFactory>>
@@ -106,7 +106,7 @@ final class ExecutorServiceParallelExecutor implements 
PipelineExecutor {
       int targetParallelism,
       Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers,
       Set<PValue> keyedPValues,
-      RootInputProvider rootInputProvider,
+      RootProviderRegistry rootProviderRegistry,
       TransformEvaluatorRegistry registry,
       @SuppressWarnings("rawtypes")
           Map<Class<? extends PTransform>, Collection<ModelEnforcementFactory>>
@@ -116,7 +116,7 @@ final class ExecutorServiceParallelExecutor implements 
PipelineExecutor {
         targetParallelism,
         valueToConsumers,
         keyedPValues,
-        rootInputProvider,
+        rootProviderRegistry,
         registry,
         transformEnforcements,
         context);
@@ -126,7 +126,7 @@ final class ExecutorServiceParallelExecutor implements 
PipelineExecutor {
       int targetParallelism,
       Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers,
       Set<PValue> keyedPValues,
-      RootInputProvider rootInputProvider,
+      RootProviderRegistry rootProviderRegistry,
       TransformEvaluatorRegistry registry,
       @SuppressWarnings("rawtypes")
       Map<Class<? extends PTransform>, Collection<ModelEnforcementFactory>> 
transformEnforcements,
@@ -135,7 +135,7 @@ final class ExecutorServiceParallelExecutor implements 
PipelineExecutor {
     this.executorService = Executors.newFixedThreadPool(targetParallelism);
     this.valueToConsumers = valueToConsumers;
     this.keyedPValues = keyedPValues;
-    this.rootInputProvider = rootInputProvider;
+    this.rootProviderRegistry = rootProviderRegistry;
     this.registry = registry;
     this.transformEnforcements = transformEnforcements;
     this.evaluationContext = context;
@@ -172,7 +172,7 @@ final class ExecutorServiceParallelExecutor implements 
PipelineExecutor {
       ConcurrentLinkedQueue<CommittedBundle<?>> pending = new 
ConcurrentLinkedQueue<>();
       try {
         Collection<CommittedBundle<?>> initialInputs =
-            rootInputProvider.getInitialInputs(root, numTargetSplits);
+            rootProviderRegistry.getInitialInputs(root, numTargetSplits);
         pending.addAll(initialInputs);
       } catch (Exception e) {
         throw UserCodeException.wrap(e);

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/67ce5313/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootInputProvider.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootInputProvider.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootInputProvider.java
index 19d0040..c3df103 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootInputProvider.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootInputProvider.java
@@ -23,12 +23,15 @@ import 
org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PInput;
 
 /**
- * Provides {@link CommittedBundle bundles} that will be provided to the
- * {@link PTransform PTransforms} that are at the root of a {@link Pipeline}.
+ * Provides {@link CommittedBundle bundles} that will be provided to the 
{@link PTransform
+ * PTransforms} that are at the root of a {@link Pipeline}.
  */
-interface RootInputProvider {
+interface RootInputProvider<
+    T, ShardT, InputT extends PInput, TransformT extends PTransform<InputT, 
PCollection<T>>> {
   /**
    * Get the initial inputs for the {@link AppliedPTransform}. The {@link 
AppliedPTransform} will be
    * provided with these {@link CommittedBundle bundles} as input when the 
{@link Pipeline} runs.
@@ -39,8 +42,9 @@ interface RootInputProvider {
    *
    * @param transform the {@link AppliedPTransform} to get initial inputs for.
    * @param targetParallelism the target amount of parallelism to obtain from 
the source. Must be
-   *                          greater than or equal to 1.
+   *     greater than or equal to 1.
    */
-  Collection<CommittedBundle<?>> getInitialInputs(
-      AppliedPTransform<?, ?, ?> transform, int targetParallelism) throws 
Exception;
+  Collection<CommittedBundle<ShardT>> getInitialInputs(
+      AppliedPTransform<InputT, PCollection<T>, TransformT> transform, int 
targetParallelism)
+      throws Exception;
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/67ce5313/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootProviderRegistry.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootProviderRegistry.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootProviderRegistry.java
index bb5fcd2..e8a7665 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootProviderRegistry.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootProviderRegistry.java
@@ -24,7 +24,6 @@ import java.util.Collection;
 import java.util.Map;
 import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
 import org.apache.beam.sdk.io.Read;
-import org.apache.beam.sdk.testing.TestStream;
 import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Flatten.FlattenPCollectionList;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -33,25 +32,27 @@ import org.apache.beam.sdk.transforms.PTransform;
  * A {@link RootInputProvider} that delegates to primitive {@link 
RootInputProvider} implementations
  * based on the type of {@link PTransform} of the application.
  */
-class RootProviderRegistry implements RootInputProvider {
+class RootProviderRegistry {
   public static RootProviderRegistry defaultRegistry(EvaluationContext 
context) {
-    ImmutableMap.Builder<Class<? extends PTransform>, RootInputProvider> 
defaultProviders =
-        ImmutableMap.builder();
+    ImmutableMap.Builder<Class<? extends PTransform>, RootInputProvider<?, ?, 
?, ?>>
+        defaultProviders = ImmutableMap.builder();
     defaultProviders
         .put(Read.Bounded.class, new 
BoundedReadEvaluatorFactory.InputProvider(context))
         .put(Read.Unbounded.class, new 
UnboundedReadEvaluatorFactory.InputProvider(context))
-        .put(TestStream.class, new 
TestStreamEvaluatorFactory.InputProvider(context))
-        .put(FlattenPCollectionList.class, new EmptyInputProvider(context));
+        .put(
+            
TestStreamEvaluatorFactory.DirectTestStreamFactory.DirectTestStream.class,
+            new TestStreamEvaluatorFactory.InputProvider(context))
+        .put(FlattenPCollectionList.class, new EmptyInputProvider());
     return new RootProviderRegistry(defaultProviders.build());
   }
 
-  private final Map<Class<? extends PTransform>, RootInputProvider> providers;
+  private final Map<Class<? extends PTransform>, RootInputProvider<?, ?, ?, 
?>> providers;
 
-  private RootProviderRegistry(Map<Class<? extends PTransform>, 
RootInputProvider> providers) {
+  private RootProviderRegistry(
+      Map<Class<? extends PTransform>, RootInputProvider<?, ?, ?, ?>> 
providers) {
     this.providers = providers;
   }
 
-  @Override
   public Collection<CommittedBundle<?>> getInitialInputs(
       AppliedPTransform<?, ?, ?> transform, int targetParallelism) throws 
Exception {
     Class<? extends PTransform> transformClass = 
transform.getTransform().getClass();

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/67ce5313/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
index 58f2fa9..2ab6adf 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
@@ -162,7 +162,7 @@ class TestStreamEvaluatorFactory implements 
TransformEvaluatorFactory {
       return new DirectTestStream<>(transform);
     }
 
-    private static class DirectTestStream<T> extends PTransform<PBegin, 
PCollection<T>> {
+    static class DirectTestStream<T> extends PTransform<PBegin, 
PCollection<T>> {
       private final TestStream<T> original;
 
       private DirectTestStream(TestStream<T> transform) {
@@ -185,7 +185,9 @@ class TestStreamEvaluatorFactory implements 
TransformEvaluatorFactory {
     }
   }
 
-  static class InputProvider implements RootInputProvider {
+  static class InputProvider<T>
+      implements RootInputProvider<
+          T, TestStreamIndex<T>, PBegin, 
DirectTestStreamFactory.DirectTestStream<T>> {
     private final EvaluationContext evaluationContext;
 
     InputProvider(EvaluationContext evaluationContext) {
@@ -193,19 +195,18 @@ class TestStreamEvaluatorFactory implements 
TransformEvaluatorFactory {
     }
 
     @Override
-    public Collection<CommittedBundle<?>> getInitialInputs(
-        AppliedPTransform<?, ?, ?> transform, int targetParallelism) {
-      return createInputBundle((AppliedPTransform) transform);
-    }
-
-    private <T> Collection<CommittedBundle<?>> createInputBundle(
-        AppliedPTransform<PBegin, ?, TestStream<T>> transform) {
+    public Collection<CommittedBundle<TestStreamIndex<T>>> getInitialInputs(
+        AppliedPTransform<PBegin, PCollection<T>, 
DirectTestStreamFactory.DirectTestStream<T>>
+            transform,
+        int targetParallelism) {
       CommittedBundle<TestStreamIndex<T>> initialBundle =
           evaluationContext
               .<TestStreamIndex<T>>createRootBundle()
-              
.add(WindowedValue.valueInGlobalWindow(TestStreamIndex.of(transform.getTransform())))
+              .add(
+                  WindowedValue.valueInGlobalWindow(
+                      TestStreamIndex.of(transform.getTransform().original)))
               .commit(BoundedWindow.TIMESTAMP_MAX_VALUE);
-      return Collections.<CommittedBundle<?>>singleton(initialBundle);
+      return Collections.singleton(initialBundle);
     }
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/67ce5313/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
index f384a14..51502f7 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
@@ -29,7 +29,6 @@ import 
org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupAlsoByWindow;
 import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupByKeyOnly;
 import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
 import org.apache.beam.sdk.io.Read;
-import org.apache.beam.sdk.testing.TestStream;
 import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Flatten.FlattenPCollectionList;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -62,7 +61,9 @@ class TransformEvaluatorRegistry implements 
TransformEvaluatorFactory {
             // Runner-specific primitives used in expansion of GroupByKey
             .put(DirectGroupByKeyOnly.class, new 
GroupByKeyOnlyEvaluatorFactory(ctxt))
             .put(DirectGroupAlsoByWindow.class, new 
GroupAlsoByWindowEvaluatorFactory(ctxt))
-            .put(TestStream.class, new TestStreamEvaluatorFactory(ctxt))
+            .put(
+                
TestStreamEvaluatorFactory.DirectTestStreamFactory.DirectTestStream.class,
+                new TestStreamEvaluatorFactory(ctxt))
             .build();
     return new TransformEvaluatorRegistry(primitives);
   }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/67ce5313/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java
index fb09b3e..24a91cb 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java
@@ -77,8 +77,7 @@ class UnboundedReadEvaluatorFactory implements 
TransformEvaluatorFactory {
 
   private <OutputT> TransformEvaluator<?> createEvaluator(
       AppliedPTransform<PBegin, PCollection<OutputT>, Read.Unbounded<OutputT>> 
application) {
-    return new UnboundedReadEvaluator<>(
-        application, evaluationContext, readerReuseChance);
+    return new UnboundedReadEvaluator<>(application, evaluationContext, 
readerReuseChance);
   }
 
   @Override
@@ -128,8 +127,9 @@ class UnboundedReadEvaluatorFactory implements 
TransformEvaluatorFactory {
           int numElements = 0;
           do {
             if (deduplicator.shouldOutput(reader.getCurrentRecordId())) {
-              
output.add(WindowedValue.timestampedValueInGlobalWindow(reader.getCurrent(),
-                  reader.getCurrentTimestamp()));
+              output.add(
+                  WindowedValue.timestampedValueInGlobalWindow(
+                      reader.getCurrent(), reader.getCurrentTimestamp()));
             }
             numElements++;
           } while (numElements < ARBITRARY_MAX_ELEMENTS && reader.advance());
@@ -251,9 +251,12 @@ class UnboundedReadEvaluatorFactory implements 
TransformEvaluatorFactory {
     }
 
     abstract UnboundedSource<T, CheckpointT> getSource();
+
     abstract UnboundedReadDeduplicator getDeduplicator();
+
     @Nullable
     abstract UnboundedReader<T> getExistingReader();
+
     @Nullable
     abstract CheckpointT getCheckpoint();
 
@@ -262,7 +265,9 @@ class UnboundedReadEvaluatorFactory implements 
TransformEvaluatorFactory {
     }
   }
 
-  static class InputProvider implements RootInputProvider {
+  static class InputProvider<OutputT>
+      implements RootInputProvider<
+          OutputT, UnboundedSourceShard<OutputT, ?>, PBegin, 
Unbounded<OutputT>> {
     private final EvaluationContext evaluationContext;
 
     InputProvider(EvaluationContext evaluationContext) {
@@ -270,13 +275,9 @@ class UnboundedReadEvaluatorFactory implements 
TransformEvaluatorFactory {
     }
 
     @Override
-    public Collection<CommittedBundle<?>> getInitialInputs(
-        AppliedPTransform<?, ?, ?> transform, int targetParallelism) throws 
Exception {
-      return createInitialSplits((AppliedPTransform) transform, 
targetParallelism);
-    }
-
-    private <OutputT> Collection<CommittedBundle<?>> createInitialSplits(
-        AppliedPTransform<PBegin, ?, Unbounded<OutputT>> transform, int 
targetParallelism)
+    public Collection<CommittedBundle<UnboundedSourceShard<OutputT, ?>>> 
getInitialInputs(
+        AppliedPTransform<PBegin, PCollection<OutputT>, Unbounded<OutputT>> 
transform,
+        int targetParallelism)
         throws Exception {
       UnboundedSource<OutputT, ?> source = 
transform.getTransform().getSource();
       List<? extends UnboundedSource<OutputT, ?>> splits =
@@ -286,14 +287,15 @@ class UnboundedReadEvaluatorFactory implements 
TransformEvaluatorFactory {
               ? UnboundedReadDeduplicator.CachedIdDeduplicator.create()
               : NeverDeduplicator.create();
 
-      ImmutableList.Builder<CommittedBundle<?>> initialShards = 
ImmutableList.builder();
+      ImmutableList.Builder<CommittedBundle<UnboundedSourceShard<OutputT, ?>>> 
initialShards =
+          ImmutableList.builder();
       for (UnboundedSource<OutputT, ?> split : splits) {
         UnboundedSourceShard<OutputT, ?> shard =
             UnboundedSourceShard.unstarted(split, deduplicator);
         initialShards.add(
             evaluationContext
-                .<UnboundedSourceShard<?, ?>>createRootBundle()
-                .add(WindowedValue.<UnboundedSourceShard<?, 
?>>valueInGlobalWindow(shard))
+                .<UnboundedSourceShard<OutputT, ?>>createRootBundle()
+                .add(WindowedValue.<UnboundedSourceShard<OutputT, 
?>>valueInGlobalWindow(shard))
                 .commit(BoundedWindow.TIMESTAMP_MAX_VALUE));
       }
       return initialShards.build();

Reply via email to