Update PipelineTest.testReplacedNames (renamed to testReplaceWithExistingName)

Validate that the node has actually been replaced, by comparing the class of a subnode, rather than only checking node names.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/1e3bee18 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/1e3bee18 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/1e3bee18 Branch: refs/heads/master Commit: 1e3bee189b0e4368604816a2c7df600c86233a20 Parents: de7cc05 Author: Uri Silberstein <uri.silberst...@gmail.com> Authored: Mon Oct 2 16:27:13 2017 +0300 Committer: Thomas Groh <tg...@google.com> Committed: Wed Oct 18 10:00:00 2017 -0700 ---------------------------------------------------------------------- .../java/org/apache/beam/sdk/PipelineTest.java | 51 +++++++++++--------- 1 file changed, 27 insertions(+), 24 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/1e3bee18/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java index 2cc3f04..57fdd75 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java @@ -30,10 +30,9 @@ import static org.junit.Assert.fail; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import java.util.Collections; -import java.util.HashSet; +import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import org.apache.beam.sdk.Pipeline.PipelineExecutionException; import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.io.GenerateSequence; @@ -51,12 +50,13 @@ import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.ValidatesRunner; -import 
org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.Max; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.SimpleFunction; +import org.apache.beam.sdk.transforms.Sum; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; @@ -394,39 +394,40 @@ public class PipelineTest { } @Test - public void testReplacedNames() { + public void testReplaceWithExistingName() { pipeline.enableAbandonedNodeEnforcement(false); - final PCollection<String> originalInput = pipeline.apply(Create.of("foo", "bar", "baz")); - class OriginalTransform extends PTransform<PCollection<String>, PCollection<Long>> { + final PCollection<Integer> originalInput = pipeline.apply(Create.of(1, 2, 3)); + class OriginalTransform extends PTransform<PCollection<Integer>, PCollection<Integer>> { @Override - public PCollection<Long> expand(PCollection<String> input) { - return input.apply("custom_name", Count.<String>globally()); + public PCollection<Integer> expand(PCollection<Integer> input) { + return input.apply("custom_name", Sum.integersGlobally()); } } - class ReplacementTransform extends PTransform<PCollection<String>, PCollection<Long>> { + class ReplacementTransform extends PTransform<PCollection<Integer>, PCollection<Integer>> { @Override - public PCollection<Long> expand(PCollection<String> input) { - return input.apply("custom_name", Count.<String>globally()); + public PCollection<Integer> expand(PCollection<Integer> input) { + return input.apply("custom_name", Max.integersGlobally()); } } class ReplacementOverrideFactory implements PTransformOverrideFactory< - PCollection<String>, PCollection<Long>, OriginalTransform> { - @Override - public PTransformReplacement<PCollection<String>, PCollection<Long>> 
getReplacementTransform( - AppliedPTransform<PCollection<String>, PCollection<Long>, OriginalTransform> transform) { + PCollection<Integer>, PCollection<Integer>, OriginalTransform> { + + @Override public PTransformReplacement<PCollection<Integer>, PCollection<Integer>> + getReplacementTransform( + AppliedPTransform<PCollection<Integer>, + PCollection<Integer>, OriginalTransform> transform) { return PTransformReplacement.of(originalInput, new ReplacementTransform()); } @Override public Map<PValue, ReplacementOutput> mapOutputs( - Map<TupleTag<?>, PValue> outputs, PCollection<Long> newOutput) { + Map<TupleTag<?>, PValue> outputs, PCollection<Integer> newOutput) { return Collections.<PValue, ReplacementOutput>singletonMap( newOutput, ReplacementOutput.of( - TaggedPValue.ofExpandedValue( - Iterables.getOnlyElement(outputs.values())), - TaggedPValue.ofExpandedValue(newOutput))); + TaggedPValue.ofExpandedValue(Iterables.getOnlyElement(outputs.values())), + TaggedPValue.ofExpandedValue(newOutput))); } } @@ -441,24 +442,26 @@ public class PipelineTest { pipeline.replaceAll( Collections.singletonList( PTransformOverride.of(new OriginalMatcher(), new ReplacementOverrideFactory()))); - final Set<String> names = new HashSet<>(); + final Map<String, Class<?>> nameToTransformClass = new HashMap<>(); pipeline.traverseTopologically( new PipelineVisitor.Defaults() { @Override public void leaveCompositeTransform(Node node) { if (!node.isRootNode()) { - names.add(node.getFullName()); + nameToTransformClass.put(node.getFullName(), node.getTransform().getClass()); } } @Override public void visitPrimitiveTransform(Node node) { - names.add(node.getFullName()); + nameToTransformClass.put(node.getFullName(), node.getTransform().getClass()); } }); - assertThat(names, hasItem("original_application/custom_name")); - assertThat(names, not(hasItem("original_application/custom_name2"))); + assertThat(nameToTransformClass.keySet(), hasItem("original_application/custom_name")); + 
assertThat(nameToTransformClass.keySet(), not(hasItem("original_application/custom_name2"))); + Assert.assertEquals(nameToTransformClass.get("original_application/custom_name"), + Max.integersGlobally().getClass()); } static class GenerateSequenceToCreateOverride