Repository: beam
Updated Branches:
  refs/heads/master cf9d2211f -> 64102943f


Add Input Reconstruction to PTransformOverrideFactory

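Inputs are only ever provided to runners in their expanded representation
(a List<TaggedPValue>). Overrides, however, may be applied to transforms
whose input is a composite value (for example a PBegin or a single
PCollection). Add PTransformOverrideFactory#getInput, which reconstructs
that language-specific composite input from its expanded form.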


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/078a2ff5
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/078a2ff5
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/078a2ff5

Branch: refs/heads/master
Commit: 078a2ff54ecaa0d7d66b2a42fd135a5325722958
Parents: cf9d221
Author: Thomas Groh <[email protected]>
Authored: Tue Feb 7 09:42:32 2017 -0800
Committer: Thomas Groh <[email protected]>
Committed: Wed Feb 8 16:18:30 2017 -0800

----------------------------------------------------------------------
 ...ectGBKIntoKeyedWorkItemsOverrideFactory.java | 10 ++++
 .../direct/DirectGroupByKeyOverrideFactory.java | 10 ++++
 .../direct/ParDoMultiOverrideFactory.java       | 10 ++++
 .../ParDoSingleViaMultiOverrideFactory.java     | 12 ++++-
 .../direct/TestStreamEvaluatorFactory.java      |  7 +++
 .../runners/direct/ViewEvaluatorFactory.java    |  8 +++
 .../direct/WriteWithShardingFactory.java        | 10 ++++
 .../DirectGroupByKeyOverrideFactoryTest.java    | 51 ++++++++++++++++++++
 .../direct/ParDoMultiOverrideFactoryTest.java   | 45 +++++++++++++++++
 .../ParDoSingleViaMultiOverrideFactoryTest.java | 45 +++++++++++++++++
 .../direct/TestStreamEvaluatorFactoryTest.java  | 11 +++++
 .../direct/ViewEvaluatorFactoryTest.java        |  9 ++++
 .../direct/WriteWithShardingFactoryTest.java    |  9 ++++
 .../sdk/runners/PTransformOverrideFactory.java  |  8 +++
 14 files changed, 244 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java
index ab4c114..caf61db 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java
@@ -17,12 +17,16 @@
  */
 package org.apache.beam.runners.direct;
 
+import com.google.common.collect.Iterables;
+import java.util.List;
 import org.apache.beam.runners.core.KeyedWorkItem;
 import org.apache.beam.runners.core.SplittableParDo.GBKIntoKeyedWorkItems;
+import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TaggedPValue;
 
 /**
  * Provides an implementation of {@link SplittableParDo.GBKIntoKeyedWorkItems} 
for the Direct
@@ -37,4 +41,10 @@ class DirectGBKIntoKeyedWorkItemsOverrideFactory<KeyT, 
InputT>
       getReplacementTransform(GBKIntoKeyedWorkItems<KeyT, InputT> transform) {
     return new DirectGroupByKey.DirectGroupByKeyOnly<>();
   }
+
+  @Override
+  public PCollection<KV<KeyT, InputT>> getInput(
+      List<TaggedPValue> inputs, Pipeline p) {
+    return (PCollection<KV<KeyT, InputT>>) 
Iterables.getOnlyElement(inputs).getValue();
+  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java
index 7cf3256..8a5413b 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java
@@ -17,11 +17,15 @@
  */
 package org.apache.beam.runners.direct;
 
+import com.google.common.collect.Iterables;
+import java.util.List;
+import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TaggedPValue;
 
 /** A {@link PTransformOverrideFactory} for {@link GroupByKey} PTransforms. */
 final class DirectGroupByKeyOverrideFactory<K, V>
@@ -32,4 +36,10 @@ final class DirectGroupByKeyOverrideFactory<K, V>
       GroupByKey<K, V> transform) {
     return new DirectGroupByKey<>(transform);
   }
+
+  @Override
+  public PCollection<KV<K, V>> getInput(
+      List<TaggedPValue> inputs, Pipeline p) {
+    return (PCollection<KV<K, V>>) Iterables.getOnlyElement(inputs).getValue();
+  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
index ceb35ec..483b7ce 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
@@ -19,10 +19,13 @@ package org.apache.beam.runners.direct;
 
 import static com.google.common.base.Preconditions.checkState;
 
+import com.google.common.collect.Iterables;
+import java.util.List;
 import org.apache.beam.runners.core.KeyedWorkItem;
 import org.apache.beam.runners.core.KeyedWorkItemCoder;
 import org.apache.beam.runners.core.KeyedWorkItems;
 import org.apache.beam.runners.core.SplittableParDo;
+import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
@@ -44,6 +47,7 @@ import org.apache.beam.sdk.util.WindowingStrategy;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.TaggedPValue;
 import org.apache.beam.sdk.values.TupleTagList;
 import org.apache.beam.sdk.values.TypedPValue;
 
@@ -77,6 +81,12 @@ class ParDoMultiOverrideFactory<InputT, OutputT>
     }
   }
 
+  @Override
+  public PCollection<? extends InputT> getInput(
+      List<TaggedPValue> inputs, Pipeline p) {
+    return (PCollection<? extends InputT>) 
Iterables.getOnlyElement(inputs).getValue();
+  }
+
   static class GbkThenStatefulParDo<K, InputT, OutputT>
       extends PTransform<PCollection<KV<K, InputT>>, PCollectionTuple> {
     private final ParDo.BoundMulti<KV<K, InputT>, OutputT> underlyingParDo;

http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java
index 3ae3382..6da5bb4 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java
@@ -17,12 +17,16 @@
  */
 package org.apache.beam.runners.direct;
 
+import com.google.common.collect.Iterables;
+import java.util.List;
+import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.ParDo.Bound;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.TaggedPValue;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
 
@@ -32,13 +36,19 @@ import org.apache.beam.sdk.values.TupleTagList;
  */
 class ParDoSingleViaMultiOverrideFactory<InputT, OutputT>
     implements PTransformOverrideFactory<
-        PCollection<? extends InputT>, PCollection<OutputT>, Bound<InputT, 
OutputT>>{
+        PCollection<? extends InputT>, PCollection<OutputT>, Bound<InputT, 
OutputT>> {
   @Override
   public PTransform<PCollection<? extends InputT>, PCollection<OutputT>> 
getReplacementTransform(
       Bound<InputT, OutputT> transform) {
     return new ParDoSingleViaMulti<>(transform);
   }
 
+  @Override
+  public PCollection<? extends InputT> getInput(
+      List<TaggedPValue> inputs, Pipeline p) {
+    return (PCollection<? extends InputT>) 
Iterables.getOnlyElement(inputs).getValue();
+  }
+
   static class ParDoSingleViaMulti<InputT, OutputT>
       extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> {
     private static final String MAIN_OUTPUT_TAG = "main";

http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
index bdf293f..b81d7d5 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
@@ -31,6 +31,7 @@ import java.util.concurrent.atomic.AtomicReference;
 import javax.annotation.Nullable;
 import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
 import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle;
+import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.runners.PipelineRunner;
 import org.apache.beam.sdk.testing.TestStream;
@@ -47,6 +48,7 @@ import org.apache.beam.sdk.util.WindowingStrategy;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollection.IsBounded;
+import org.apache.beam.sdk.values.TaggedPValue;
 import org.apache.beam.sdk.values.TimestampedValue;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
@@ -168,6 +170,11 @@ class TestStreamEvaluatorFactory implements 
TransformEvaluatorFactory {
       return new DirectTestStream<>(transform);
     }
 
+    @Override
+    public PBegin getInput(List<TaggedPValue> inputs, Pipeline p) {
+      return p.begin();
+    }
+
     static class DirectTestStream<T> extends PTransform<PBegin, 
PCollection<T>> {
       private final TestStream<T> original;
 

http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java
index fcd8423..817fb33 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java
@@ -23,6 +23,7 @@ import java.util.List;
 import org.apache.beam.runners.direct.CommittedResult.OutputType;
 import org.apache.beam.runners.direct.DirectRunner.PCollectionViewWriter;
 import org.apache.beam.runners.direct.StepTransformResult.Builder;
+import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
@@ -35,6 +36,7 @@ import org.apache.beam.sdk.transforms.WithKeys;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TaggedPValue;
 
 /**
  * The {@link DirectRunner} {@link TransformEvaluatorFactory} for the
@@ -105,6 +107,12 @@ class ViewEvaluatorFactory implements 
TransformEvaluatorFactory {
         CreatePCollectionView<ElemT, ViewT> transform) {
       return new DirectCreatePCollectionView<>(transform);
     }
+
+    @Override
+    public PCollection<ElemT> getInput(
+        List<TaggedPValue> inputs, Pipeline p) {
+      return (PCollection<ElemT>) Iterables.getOnlyElement(inputs).getValue();
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
index fd1c175..9f5f4bd 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
@@ -21,7 +21,10 @@ package org.apache.beam.runners.direct;
 import static com.google.common.base.Preconditions.checkArgument;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Iterables;
+import java.util.List;
 import java.util.concurrent.ThreadLocalRandom;
+import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.io.Write;
 import org.apache.beam.sdk.io.Write.Bound;
 import org.apache.beam.sdk.transforms.Count;
@@ -39,6 +42,7 @@ import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollection.IsBounded;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.sdk.values.TaggedPValue;
 import org.joda.time.Duration;
 
 /**
@@ -60,6 +64,12 @@ class WriteWithShardingFactory<InputT>
     return transform;
   }
 
+  @Override
+  public PCollection<InputT> getInput(
+      List<TaggedPValue> inputs, Pipeline p) {
+    return (PCollection<InputT>) Iterables.getOnlyElement(inputs).getValue();
+  }
+
   private static class DynamicallyReshardedWrite<T> extends 
PTransform<PCollection<T>, PDone> {
     private final transient Write.Bound<T> original;
 

http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactoryTest.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactoryTest.java
 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactoryTest.java
new file mode 100644
index 0000000..03f1dda
--- /dev/null
+++ 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactoryTest.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import static org.junit.Assert.assertThat;
+
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.hamcrest.Matchers;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link DirectGroupByKeyOverrideFactory}.
+ */
+@RunWith(JUnit4.class)
+public class DirectGroupByKeyOverrideFactoryTest {
+  private DirectGroupByKeyOverrideFactory factory = new 
DirectGroupByKeyOverrideFactory();
+  @Test
+  public void getInputSucceeds() {
+    TestPipeline p = TestPipeline.create();
+    PCollection<KV<String, Integer>> input =
+        p.apply(
+            Create.of(KV.of("foo", 1))
+                .withCoder(KvCoder.of(StringUtf8Coder.of(), 
VarIntCoder.of())));
+    PCollection<?> reconstructed = factory.getInput(input.expand(), p);
+    assertThat(reconstructed, Matchers.<PCollection<?>>equalTo(input));
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactoryTest.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactoryTest.java
 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactoryTest.java
new file mode 100644
index 0000000..4bbf924
--- /dev/null
+++ 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactoryTest.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import static org.junit.Assert.assertThat;
+
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.values.PCollection;
+import org.hamcrest.Matchers;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link ParDoMultiOverrideFactory}.
+ */
+@RunWith(JUnit4.class)
+public class ParDoMultiOverrideFactoryTest {
+  private ParDoMultiOverrideFactory factory = new ParDoMultiOverrideFactory();
+
+  @Test
+  public void getInputSucceeds() {
+    TestPipeline p = TestPipeline.create();
+    PCollection<Integer> input = p.apply(Create.of(1, 2, 3));
+    PCollection<?> reconstructed = factory.getInput(input.expand(), p);
+    assertThat(reconstructed, Matchers.<PCollection<?>>equalTo(input));
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactoryTest.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactoryTest.java
 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactoryTest.java
new file mode 100644
index 0000000..8f170dd
--- /dev/null
+++ 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactoryTest.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import static org.junit.Assert.assertThat;
+
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.values.PCollection;
+import org.hamcrest.Matchers;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link ParDoSingleViaMultiOverrideFactory}.
+ */
+@RunWith(JUnit4.class)
+public class ParDoSingleViaMultiOverrideFactoryTest {
+  private ParDoSingleViaMultiOverrideFactory factory = new 
ParDoSingleViaMultiOverrideFactory();
+
+  @Test
+  public void getInputSucceeds() {
+    TestPipeline p = TestPipeline.create();
+    PCollection<Integer> input = p.apply(Create.of(1, 2, 3));
+    PCollection<?> reconstructed = factory.getInput(input.expand(), p);
+    assertThat(reconstructed, Matchers.<PCollection<?>>equalTo(input));
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java
 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java
index c5b3b3d..4dc7738 100644
--- 
a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java
+++ 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java
@@ -27,15 +27,19 @@ import com.google.common.collect.Iterables;
 import java.util.Collection;
 import java.util.Collections;
 import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
+import 
org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory;
 import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.TestClock;
 import 
org.apache.beam.runners.direct.TestStreamEvaluatorFactory.TestStreamIndex;
+import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.testing.TestStream;
 import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TaggedPValue;
 import org.apache.beam.sdk.values.TimestampedValue;
 import org.hamcrest.Matchers;
 import org.joda.time.Duration;
@@ -173,4 +177,11 @@ public class TestStreamEvaluatorFactoryTest {
     assertThat(fifthResult.getWatermarkHold(), 
equalTo(BoundedWindow.TIMESTAMP_MAX_VALUE));
     assertThat(fifthResult.getUnprocessedElements(), Matchers.emptyIterable());
   }
+
+  @Test
+  public void overrideFactoryGetInputSucceeds() {
+    DirectTestStreamFactory<?> factory = new DirectTestStreamFactory<>();
+    PBegin begin = factory.getInput(Collections.<TaggedPValue>emptyList(), p);
+    assertThat(begin.getPipeline(), Matchers.<Pipeline>equalTo(p));
+  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java
 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java
index 6baf55a..5b03bcd 100644
--- 
a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java
+++ 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java
@@ -18,6 +18,7 @@
 package org.apache.beam.runners.direct;
 
 import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.nullValue;
 import static org.junit.Assert.assertThat;
 import static org.mockito.Mockito.mock;
@@ -26,6 +27,7 @@ import static org.mockito.Mockito.when;
 import com.google.common.collect.ImmutableList;
 import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
 import org.apache.beam.runners.direct.DirectRunner.PCollectionViewWriter;
+import org.apache.beam.runners.direct.ViewEvaluatorFactory.ViewOverrideFactory;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VoidCoder;
@@ -93,6 +95,13 @@ public class ViewEvaluatorFactoryTest {
             WindowedValue.valueInGlobalWindow("foo"), 
WindowedValue.valueInGlobalWindow("bar")));
   }
 
+  @Test
+  public void overrideFactoryGetInputSucceeds() {
+    ViewOverrideFactory<String, String> factory = new ViewOverrideFactory<>();
+    PCollection<String> input = p.apply(Create.of("foo", "bar"));
+    assertThat(factory.getInput(input.expand(), p), equalTo(input));
+  }
+
   private static class TestViewWriter<ElemT, ViewT> implements 
PCollectionViewWriter<ElemT, ViewT> {
     private Iterable<WindowedValue<ElemT>> latest;
 

http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
index 7432e61..0196a2d 100644
--- 
a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
+++ 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
@@ -54,7 +54,9 @@ import org.apache.beam.sdk.util.IOChannelUtils;
 import org.apache.beam.sdk.util.PCollectionViews;
 import org.apache.beam.sdk.util.WindowingStrategy;
 import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
+import org.hamcrest.Matchers;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
@@ -257,6 +259,13 @@ public class WriteWithShardingFactoryTest {
     assertThat(maxKey, equalTo(12L));
   }
 
+  @Test
+  public void getInputSucceeds() {
+    PCollection<String> original = p.apply(Create.of("foo"));
+    PCollection<?> input = factory.getInput(original.expand(), p);
+    assertThat(input, Matchers.<PCollection<?>>equalTo(original));
+  }
+
   private static class TestSink extends Sink<Object> {
     @Override
     public void validate(PipelineOptions options) {}

http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java
index f6e90e2..1d9be66 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java
@@ -19,11 +19,14 @@
 
 package org.apache.beam.sdk.runners;
 
+import java.util.List;
+import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.annotations.Experimental.Kind;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
+import org.apache.beam.sdk.values.TaggedPValue;
 
 /**
  * Produces {@link PipelineRunner}-specific overrides of {@link PTransform 
PTransforms}, and
@@ -38,4 +41,9 @@ public interface PTransformOverrideFactory<
    * Returns a {@link PTransform} that produces equivalent output to the 
provided transform.
    */
   PTransform<InputT, OutputT> getReplacementTransform(TransformT transform);
+
+  /**
+   * Returns the composite input consumed by replacement transforms, rebuilt from its expansion.
+   */
+  InputT getInput(List<TaggedPValue> inputs, Pipeline p);
 }

Reply via email to