Repository: beam Updated Branches: refs/heads/master cf9d2211f -> 64102943f
Add Input Reconstruction to PTransformOverrideFactory Inputs are only ever provided as expanded representations. Overrides, however, may be applied to compressed inputs. Add a method to reconstruct the language-specific composite input. Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/078a2ff5 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/078a2ff5 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/078a2ff5 Branch: refs/heads/master Commit: 078a2ff54ecaa0d7d66b2a42fd135a5325722958 Parents: cf9d221 Author: Thomas Groh <[email protected]> Authored: Tue Feb 7 09:42:32 2017 -0800 Committer: Thomas Groh <[email protected]> Committed: Wed Feb 8 16:18:30 2017 -0800 ---------------------------------------------------------------------- ...ectGBKIntoKeyedWorkItemsOverrideFactory.java | 10 ++++ .../direct/DirectGroupByKeyOverrideFactory.java | 10 ++++ .../direct/ParDoMultiOverrideFactory.java | 10 ++++ .../ParDoSingleViaMultiOverrideFactory.java | 12 ++++- .../direct/TestStreamEvaluatorFactory.java | 7 +++ .../runners/direct/ViewEvaluatorFactory.java | 8 +++ .../direct/WriteWithShardingFactory.java | 10 ++++ .../DirectGroupByKeyOverrideFactoryTest.java | 51 ++++++++++++++++++++ .../direct/ParDoMultiOverrideFactoryTest.java | 45 +++++++++++++++++ .../ParDoSingleViaMultiOverrideFactoryTest.java | 45 +++++++++++++++++ .../direct/TestStreamEvaluatorFactoryTest.java | 11 +++++ .../direct/ViewEvaluatorFactoryTest.java | 9 ++++ .../direct/WriteWithShardingFactoryTest.java | 9 ++++ .../sdk/runners/PTransformOverrideFactory.java | 8 +++ 14 files changed, 244 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java ---------------------------------------------------------------------- diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java index ab4c114..caf61db 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java @@ -17,12 +17,16 @@ */ package org.apache.beam.runners.direct; +import com.google.common.collect.Iterables; +import java.util.List; import org.apache.beam.runners.core.KeyedWorkItem; import org.apache.beam.runners.core.SplittableParDo.GBKIntoKeyedWorkItems; +import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.runners.PTransformOverrideFactory; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TaggedPValue; /** * Provides an implementation of {@link SplittableParDo.GBKIntoKeyedWorkItems} for the Direct @@ -37,4 +41,10 @@ class DirectGBKIntoKeyedWorkItemsOverrideFactory<KeyT, InputT> getReplacementTransform(GBKIntoKeyedWorkItems<KeyT, InputT> transform) { return new DirectGroupByKey.DirectGroupByKeyOnly<>(); } + + @Override + public PCollection<KV<KeyT, InputT>> getInput( + List<TaggedPValue> inputs, Pipeline p) { + return (PCollection<KV<KeyT, InputT>>) Iterables.getOnlyElement(inputs).getValue(); + } } http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java 
index 7cf3256..8a5413b 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java @@ -17,11 +17,15 @@ */ package org.apache.beam.runners.direct; +import com.google.common.collect.Iterables; +import java.util.List; +import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.runners.PTransformOverrideFactory; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TaggedPValue; /** A {@link PTransformOverrideFactory} for {@link GroupByKey} PTransforms. */ final class DirectGroupByKeyOverrideFactory<K, V> @@ -32,4 +36,10 @@ final class DirectGroupByKeyOverrideFactory<K, V> GroupByKey<K, V> transform) { return new DirectGroupByKey<>(transform); } + + @Override + public PCollection<KV<K, V>> getInput( + List<TaggedPValue> inputs, Pipeline p) { + return (PCollection<KV<K, V>>) Iterables.getOnlyElement(inputs).getValue(); + } } http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java index ceb35ec..483b7ce 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java @@ -19,10 +19,13 @@ package org.apache.beam.runners.direct; import static com.google.common.base.Preconditions.checkState; +import com.google.common.collect.Iterables; 
+import java.util.List; import org.apache.beam.runners.core.KeyedWorkItem; import org.apache.beam.runners.core.KeyedWorkItemCoder; import org.apache.beam.runners.core.KeyedWorkItems; import org.apache.beam.runners.core.SplittableParDo; +import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; @@ -44,6 +47,7 @@ import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.TaggedPValue; import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.sdk.values.TypedPValue; @@ -77,6 +81,12 @@ class ParDoMultiOverrideFactory<InputT, OutputT> } } + @Override + public PCollection<? extends InputT> getInput( + List<TaggedPValue> inputs, Pipeline p) { + return (PCollection<? extends InputT>) Iterables.getOnlyElement(inputs).getValue(); + } + static class GbkThenStatefulParDo<K, InputT, OutputT> extends PTransform<PCollection<KV<K, InputT>>, PCollectionTuple> { private final ParDo.BoundMulti<KV<K, InputT>, OutputT> underlyingParDo; http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java index 3ae3382..6da5bb4 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java @@ -17,12 +17,16 @@ */ package 
org.apache.beam.runners.direct; +import com.google.common.collect.Iterables; +import java.util.List; +import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.runners.PTransformOverrideFactory; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.ParDo.Bound; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.TaggedPValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; @@ -32,13 +36,19 @@ import org.apache.beam.sdk.values.TupleTagList; */ class ParDoSingleViaMultiOverrideFactory<InputT, OutputT> implements PTransformOverrideFactory< - PCollection<? extends InputT>, PCollection<OutputT>, Bound<InputT, OutputT>>{ + PCollection<? extends InputT>, PCollection<OutputT>, Bound<InputT, OutputT>> { @Override public PTransform<PCollection<? extends InputT>, PCollection<OutputT>> getReplacementTransform( Bound<InputT, OutputT> transform) { return new ParDoSingleViaMulti<>(transform); } + @Override + public PCollection<? extends InputT> getInput( + List<TaggedPValue> inputs, Pipeline p) { + return (PCollection<? extends InputT>) Iterables.getOnlyElement(inputs).getValue(); + } + static class ParDoSingleViaMulti<InputT, OutputT> extends PTransform<PCollection<? 
extends InputT>, PCollection<OutputT>> { private static final String MAIN_OUTPUT_TAG = "main"; http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java index bdf293f..b81d7d5 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java @@ -31,6 +31,7 @@ import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle; +import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.runners.PTransformOverrideFactory; import org.apache.beam.sdk.runners.PipelineRunner; import org.apache.beam.sdk.testing.TestStream; @@ -47,6 +48,7 @@ import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollection.IsBounded; +import org.apache.beam.sdk.values.TaggedPValue; import org.apache.beam.sdk.values.TimestampedValue; import org.joda.time.Duration; import org.joda.time.Instant; @@ -168,6 +170,11 @@ class TestStreamEvaluatorFactory implements TransformEvaluatorFactory { return new DirectTestStream<>(transform); } + @Override + public PBegin getInput(List<TaggedPValue> inputs, Pipeline p) { + return p.begin(); + } + static class DirectTestStream<T> extends PTransform<PBegin, PCollection<T>> { private final TestStream<T> original; 
http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java index fcd8423..817fb33 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java @@ -23,6 +23,7 @@ import java.util.List; import org.apache.beam.runners.direct.CommittedResult.OutputType; import org.apache.beam.runners.direct.DirectRunner.PCollectionViewWriter; import org.apache.beam.runners.direct.StepTransformResult.Builder; +import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.runners.PTransformOverrideFactory; @@ -35,6 +36,7 @@ import org.apache.beam.sdk.transforms.WithKeys; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TaggedPValue; /** * The {@link DirectRunner} {@link TransformEvaluatorFactory} for the @@ -105,6 +107,12 @@ class ViewEvaluatorFactory implements TransformEvaluatorFactory { CreatePCollectionView<ElemT, ViewT> transform) { return new DirectCreatePCollectionView<>(transform); } + + @Override + public PCollection<ElemT> getInput( + List<TaggedPValue> inputs, Pipeline p) { + return (PCollection<ElemT>) Iterables.getOnlyElement(inputs).getValue(); + } } /** http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java 
---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java index fd1c175..9f5f4bd 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java @@ -21,7 +21,10 @@ package org.apache.beam.runners.direct; import static com.google.common.base.Preconditions.checkArgument; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Iterables; +import java.util.List; import java.util.concurrent.ThreadLocalRandom; +import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.io.Write; import org.apache.beam.sdk.io.Write.Bound; import org.apache.beam.sdk.transforms.Count; @@ -39,6 +42,7 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.TaggedPValue; import org.joda.time.Duration; /** @@ -60,6 +64,12 @@ class WriteWithShardingFactory<InputT> return transform; } + @Override + public PCollection<InputT> getInput( + List<TaggedPValue> inputs, Pipeline p) { + return (PCollection<InputT>) Iterables.getOnlyElement(inputs).getValue(); + } + private static class DynamicallyReshardedWrite<T> extends PTransform<PCollection<T>, PDone> { private final transient Write.Bound<T> original; http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactoryTest.java ---------------------------------------------------------------------- diff --git 
a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactoryTest.java new file mode 100644 index 0000000..03f1dda --- /dev/null +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactoryTest.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.direct; + +import static org.junit.Assert.assertThat; + +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.hamcrest.Matchers; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link DirectGroupByKeyOverrideFactory}. 
+ */ +@RunWith(JUnit4.class) +public class DirectGroupByKeyOverrideFactoryTest { + private DirectGroupByKeyOverrideFactory factory = new DirectGroupByKeyOverrideFactory(); + @Test + public void getInputSucceeds() { + TestPipeline p = TestPipeline.create(); + PCollection<KV<String, Integer>> input = + p.apply( + Create.of(KV.of("foo", 1)) + .withCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()))); + PCollection<?> reconstructed = factory.getInput(input.expand(), p); + assertThat(reconstructed, Matchers.<PCollection<?>>equalTo(input)); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactoryTest.java new file mode 100644 index 0000000..4bbf924 --- /dev/null +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactoryTest.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.runners.direct; + +import static org.junit.Assert.assertThat; + +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PCollection; +import org.hamcrest.Matchers; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link ParDoMultiOverrideFactory}. + */ +@RunWith(JUnit4.class) +public class ParDoMultiOverrideFactoryTest { + private ParDoMultiOverrideFactory factory = new ParDoMultiOverrideFactory(); + + @Test + public void getInputSucceeds() { + TestPipeline p = TestPipeline.create(); + PCollection<Integer> input = p.apply(Create.of(1, 2, 3)); + PCollection<?> reconstructed = factory.getInput(input.expand(), p); + assertThat(reconstructed, Matchers.<PCollection<?>>equalTo(input)); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactoryTest.java new file mode 100644 index 0000000..8f170dd --- /dev/null +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactoryTest.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.direct; + +import static org.junit.Assert.assertThat; + +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PCollection; +import org.hamcrest.Matchers; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link ParDoSingleViaMultiOverrideFactory}. + */ +@RunWith(JUnit4.class) +public class ParDoSingleViaMultiOverrideFactoryTest { + private ParDoSingleViaMultiOverrideFactory factory = new ParDoSingleViaMultiOverrideFactory(); + + @Test + public void getInputSucceeds() { + TestPipeline p = TestPipeline.create(); + PCollection<Integer> input = p.apply(Create.of(1, 2, 3)); + PCollection<?> reconstructed = factory.getInput(input.expand(), p); + assertThat(reconstructed, Matchers.<PCollection<?>>equalTo(input)); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java index c5b3b3d..4dc7738 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java +++ 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java @@ -27,15 +27,19 @@ import com.google.common.collect.Iterables; import java.util.Collection; import java.util.Collections; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; +import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory; import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.TestClock; import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.TestStreamIndex; +import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TaggedPValue; import org.apache.beam.sdk.values.TimestampedValue; import org.hamcrest.Matchers; import org.joda.time.Duration; @@ -173,4 +177,11 @@ public class TestStreamEvaluatorFactoryTest { assertThat(fifthResult.getWatermarkHold(), equalTo(BoundedWindow.TIMESTAMP_MAX_VALUE)); assertThat(fifthResult.getUnprocessedElements(), Matchers.emptyIterable()); } + + @Test + public void overrideFactoryGetInputSucceeds() { + DirectTestStreamFactory<?> factory = new DirectTestStreamFactory<>(); + PBegin begin = factory.getInput(Collections.<TaggedPValue>emptyList(), p); + assertThat(begin.getPipeline(), Matchers.<Pipeline>equalTo(p)); + } } http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java index 6baf55a..5b03bcd 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.direct; import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.nullValue; import static org.junit.Assert.assertThat; import static org.mockito.Mockito.mock; @@ -26,6 +27,7 @@ import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; import org.apache.beam.runners.direct.DirectRunner.PCollectionViewWriter; +import org.apache.beam.runners.direct.ViewEvaluatorFactory.ViewOverrideFactory; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VoidCoder; @@ -93,6 +95,13 @@ public class ViewEvaluatorFactoryTest { WindowedValue.valueInGlobalWindow("foo"), WindowedValue.valueInGlobalWindow("bar"))); } + @Test + public void overrideFactoryGetInputSucceeds() { + ViewOverrideFactory<String, String> factory = new ViewOverrideFactory<>(); + PCollection<String> input = p.apply(Create.of("foo", "bar")); + assertThat(factory.getInput(input.expand(), p), equalTo(input)); + } + private static class TestViewWriter<ElemT, ViewT> implements PCollectionViewWriter<ElemT, ViewT> { private Iterable<WindowedValue<ElemT>> latest; http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java index 7432e61..0196a2d 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java @@ -54,7 +54,9 @@ import org.apache.beam.sdk.util.IOChannelUtils; import org.apache.beam.sdk.util.PCollectionViews; import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; +import org.hamcrest.Matchers; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; @@ -257,6 +259,13 @@ public class WriteWithShardingFactoryTest { assertThat(maxKey, equalTo(12L)); } + @Test + public void getInputSucceeds() { + PCollection<String> original = p.apply(Create.of("foo")); + PCollection<?> input = factory.getInput(original.expand(), p); + assertThat(input, Matchers.<PCollection<?>>equalTo(original)); + } + private static class TestSink extends Sink<Object> { @Override public void validate(PipelineOptions options) {} http://git-wip-us.apache.org/repos/asf/beam/blob/078a2ff5/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java index f6e90e2..1d9be66 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java @@ -19,11 +19,14 @@ package org.apache.beam.sdk.runners; +import java.util.List; +import org.apache.beam.sdk.Pipeline; import 
org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.TaggedPValue; /** * Produces {@link PipelineRunner}-specific overrides of {@link PTransform PTransforms}, and @@ -38,4 +41,9 @@ public interface PTransformOverrideFactory< * Returns a {@link PTransform} that produces equivalent output to the provided transform. */ PTransform<InputT, OutputT> getReplacementTransform(TransformT transform); + + /** + * Returns the composite input consumed by replacement transforms, given its expanded form. + */ + InputT getInput(List<TaggedPValue> inputs, Pipeline p); }
