Free PTransform Names if they are being Replaced. Naming is based on what's in the graph, not what once was there.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/79f0d114 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/79f0d114 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/79f0d114 Branch: refs/heads/master Commit: 79f0d114b75752024ca41038b649f72a0882dabd Parents: 6bf4262 Author: Thomas Groh <[email protected]> Authored: Tue Apr 11 18:17:02 2017 -0700 Committer: Thomas Groh <[email protected]> Committed: Fri Apr 14 10:39:52 2017 -0700 ---------------------------------------------------------------------- .../main/java/org/apache/beam/sdk/Pipeline.java | 24 ++++-- .../java/org/apache/beam/sdk/PipelineTest.java | 77 ++++++++++++++++++++ 2 files changed, 93 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/79f0d114/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java index 11d781d..791166e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java @@ -30,7 +30,6 @@ import java.util.Set; import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.options.PipelineOptions.CheckEnabled; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.runners.PTransformOverride; import org.apache.beam.sdk.runners.PTransformOverrideFactory; @@ -229,26 +228,38 @@ public class Pipeline { } private void replace(final PTransformOverride override) { - final Collection<Node> matches = new ArrayList<>(); + final Set<Node> matches = new HashSet<>(); + final Set<Node> freedNodes = new 
HashSet<>(); transforms.visit( new PipelineVisitor.Defaults() { @Override public CompositeBehavior enterCompositeTransform(Node node) { + if (!node.isRootNode() && freedNodes.contains(node.getEnclosingNode())) { + // This node will be freed because its parent will be freed. + freedNodes.add(node); + return CompositeBehavior.ENTER_TRANSFORM; + } if (!node.isRootNode() && override.getMatcher().matches(node.toAppliedPTransform())) { matches.add(node); - // This node will be replaced. It should not be visited. - return CompositeBehavior.DO_NOT_ENTER_TRANSFORM; + // This node will be freed. When we visit any of its children, they will also be freed + freedNodes.add(node); } return CompositeBehavior.ENTER_TRANSFORM; } @Override public void visitPrimitiveTransform(Node node) { - if (override.getMatcher().matches(node.toAppliedPTransform())) { + if (freedNodes.contains(node.getEnclosingNode())) { + freedNodes.add(node); + } else if (override.getMatcher().matches(node.toAppliedPTransform())) { matches.add(node); + freedNodes.add(node); } } }); + for (Node freedNode : freedNodes) { + usedFullNames.remove(freedNode.getFullName()); + } for (Node match : matches) { applyReplacement(match, override.getOverrideFactory()); } @@ -486,9 +497,6 @@ public class Pipeline { void applyReplacement( Node original, PTransformOverrideFactory<InputT, OutputT, TransformT> replacementFactory) { - // Names for top-level transforms have been assigned. Any new collisions are within a node - // and its replacement. 
- getOptions().setStableUniqueNames(CheckEnabled.OFF); PTransform<InputT, OutputT> replacement = replacementFactory.getReplacementTransform((TransformT) original.getTransform()); if (replacement == original.getTransform()) { http://git-wip-us.apache.org/repos/asf/beam/blob/79f0d114/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java index 0a5746b..6ce016d 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java @@ -19,6 +19,7 @@ package org.apache.beam.sdk; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.isA; import static org.hamcrest.Matchers.not; @@ -29,8 +30,10 @@ import static org.junit.Assert.fail; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import org.apache.beam.sdk.Pipeline.PipelineExecutionException; import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.io.CountingInput; @@ -51,6 +54,7 @@ import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.AppliedPTransform; +import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.MapElements; @@ -384,6 +388,79 @@ public class PipelineTest { new UnboundedCountingInputOverride()))); 
} + @Test + public void testReplacedNames() { + final PCollection<String> originalInput = pipeline.apply(Create.of("foo", "bar", "baz")); + class OriginalTransform extends PTransform<PCollection<String>, PCollection<Long>> { + @Override + public PCollection<Long> expand(PCollection<String> input) { + return input.apply("custom_name", Count.<String>globally()); + } + } + class ReplacementTransform extends PTransform<PCollection<String>, PCollection<Long>> { + @Override + public PCollection<Long> expand(PCollection<String> input) { + return input.apply("custom_name", Count.<String>globally()); + } + } + class ReplacementOverrideFactory + implements PTransformOverrideFactory< + PCollection<String>, PCollection<Long>, OriginalTransform> { + + @Override + public PTransform<PCollection<String>, PCollection<Long>> getReplacementTransform( + OriginalTransform transform) { + return new ReplacementTransform(); + } + + @Override + public PCollection<String> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) { + return originalInput; + } + + @Override + public Map<PValue, ReplacementOutput> mapOutputs( + Map<TupleTag<?>, PValue> outputs, PCollection<Long> newOutput) { + return Collections.<PValue, ReplacementOutput>singletonMap( + newOutput, + ReplacementOutput.of( + TaggedPValue.ofExpandedValue( + Iterables.getOnlyElement(outputs.values())), + TaggedPValue.ofExpandedValue(newOutput))); + } + } + + class OriginalMatcher implements PTransformMatcher { + @Override + public boolean matches(AppliedPTransform<?, ?, ?> application) { + return application.getTransform() instanceof OriginalTransform; + } + } + + originalInput.apply("original_application", new OriginalTransform()); + pipeline.replaceAll( + Collections.singletonList( + PTransformOverride.of(new OriginalMatcher(), new ReplacementOverrideFactory()))); + final Set<String> names = new HashSet<>(); + pipeline.traverseTopologically( + new PipelineVisitor.Defaults() { + @Override + public void leaveCompositeTransform(Node 
node) { + if (!node.isRootNode()) { + names.add(node.getFullName()); + } + } + + @Override + public void visitPrimitiveTransform(Node node) { + names.add(node.getFullName()); + } + }); + + assertThat(names, hasItem("original_application/custom_name")); + assertThat(names, not(hasItem("original_application/custom_name2"))); + } + static class BoundedCountingInputOverride implements PTransformOverrideFactory<PBegin, PCollection<Long>, BoundedCountingInput> { @Override
