Update PipelineTest.testReplacedNames

Validate that the node has actually been replaced (by comparing the class of a
subnode) rather than only checking names. The test is also renamed to
testReplaceWithExistingName.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/1e3bee18
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/1e3bee18
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/1e3bee18

Branch: refs/heads/master
Commit: 1e3bee189b0e4368604816a2c7df600c86233a20
Parents: de7cc05
Author: Uri Silberstein <uri.silberst...@gmail.com>
Authored: Mon Oct 2 16:27:13 2017 +0300
Committer: Thomas Groh <tg...@google.com>
Committed: Wed Oct 18 10:00:00 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/beam/sdk/PipelineTest.java  | 51 +++++++++++---------
 1 file changed, 27 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/1e3bee18/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java
index 2cc3f04..57fdd75 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java
@@ -30,10 +30,9 @@ import static org.junit.Assert.fail;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Iterables;
 import java.util.Collections;
-import java.util.HashSet;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
 import org.apache.beam.sdk.Pipeline.PipelineExecutionException;
 import org.apache.beam.sdk.Pipeline.PipelineVisitor;
 import org.apache.beam.sdk.io.GenerateSequence;
@@ -51,12 +50,13 @@ import org.apache.beam.sdk.testing.NeedsRunner;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.testing.ValidatesRunner;
-import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.Max;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.util.UserCodeException;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
@@ -394,39 +394,40 @@ public class PipelineTest {
   }
 
   @Test
-  public void testReplacedNames() {
+  public void testReplaceWithExistingName() {
     pipeline.enableAbandonedNodeEnforcement(false);
-    final PCollection<String> originalInput = pipeline.apply(Create.of("foo", 
"bar", "baz"));
-    class OriginalTransform extends PTransform<PCollection<String>, 
PCollection<Long>> {
+    final PCollection<Integer> originalInput = pipeline.apply(Create.of(1, 2, 
3));
+    class OriginalTransform extends PTransform<PCollection<Integer>, 
PCollection<Integer>> {
       @Override
-      public PCollection<Long> expand(PCollection<String> input) {
-        return input.apply("custom_name", Count.<String>globally());
+      public PCollection<Integer> expand(PCollection<Integer> input) {
+        return input.apply("custom_name", Sum.integersGlobally());
       }
     }
-    class ReplacementTransform extends PTransform<PCollection<String>, 
PCollection<Long>> {
+    class ReplacementTransform extends PTransform<PCollection<Integer>, 
PCollection<Integer>> {
       @Override
-      public PCollection<Long> expand(PCollection<String> input) {
-        return input.apply("custom_name", Count.<String>globally());
+      public PCollection<Integer> expand(PCollection<Integer> input) {
+        return input.apply("custom_name", Max.integersGlobally());
       }
     }
     class ReplacementOverrideFactory
         implements PTransformOverrideFactory<
-            PCollection<String>, PCollection<Long>, OriginalTransform> {
-      @Override
-      public PTransformReplacement<PCollection<String>, PCollection<Long>> 
getReplacementTransform(
-          AppliedPTransform<PCollection<String>, PCollection<Long>, 
OriginalTransform> transform) {
+            PCollection<Integer>, PCollection<Integer>, OriginalTransform> {
+
+      @Override public PTransformReplacement<PCollection<Integer>, 
PCollection<Integer>>
+      getReplacementTransform(
+          AppliedPTransform<PCollection<Integer>,
+              PCollection<Integer>, OriginalTransform> transform) {
         return PTransformReplacement.of(originalInput, new 
ReplacementTransform());
       }
 
       @Override
       public Map<PValue, ReplacementOutput> mapOutputs(
-          Map<TupleTag<?>, PValue> outputs, PCollection<Long> newOutput) {
+          Map<TupleTag<?>, PValue> outputs, PCollection<Integer> newOutput) {
         return Collections.<PValue, ReplacementOutput>singletonMap(
             newOutput,
             ReplacementOutput.of(
-                TaggedPValue.ofExpandedValue(
-                    Iterables.getOnlyElement(outputs.values())),
-                    TaggedPValue.ofExpandedValue(newOutput)));
+                
TaggedPValue.ofExpandedValue(Iterables.getOnlyElement(outputs.values())),
+                TaggedPValue.ofExpandedValue(newOutput)));
       }
     }
 
@@ -441,24 +442,26 @@ public class PipelineTest {
     pipeline.replaceAll(
         Collections.singletonList(
             PTransformOverride.of(new OriginalMatcher(), new 
ReplacementOverrideFactory())));
-    final Set<String> names = new HashSet<>();
+    final Map<String, Class<?>> nameToTransformClass = new HashMap<>();
     pipeline.traverseTopologically(
         new PipelineVisitor.Defaults() {
           @Override
           public void leaveCompositeTransform(Node node) {
             if (!node.isRootNode()) {
-              names.add(node.getFullName());
+              nameToTransformClass.put(node.getFullName(), 
node.getTransform().getClass());
             }
           }
 
           @Override
           public void visitPrimitiveTransform(Node node) {
-            names.add(node.getFullName());
+            nameToTransformClass.put(node.getFullName(), 
node.getTransform().getClass());
           }
         });
 
-    assertThat(names, hasItem("original_application/custom_name"));
-    assertThat(names, not(hasItem("original_application/custom_name2")));
+    assertThat(nameToTransformClass.keySet(), 
hasItem("original_application/custom_name"));
+    assertThat(nameToTransformClass.keySet(), 
not(hasItem("original_application/custom_name2")));
+    
Assert.assertEquals(nameToTransformClass.get("original_application/custom_name"),
+        Max.integersGlobally().getClass());
   }
 
   static class GenerateSequenceToCreateOverride

Reply via email to