http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index 684dc14..4eec6b8 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -61,6 +61,7 @@ import java.util.TreeSet; import org.apache.beam.runners.core.construction.DeduplicatedFlattenFactory; import org.apache.beam.runners.core.construction.EmptyFlattenAsCreateFactory; import org.apache.beam.runners.core.construction.PTransformMatchers; +import org.apache.beam.runners.core.construction.PTransformReplacements; import org.apache.beam.runners.core.construction.ReplacementOutputs; import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory; import org.apache.beam.runners.core.construction.UnboundedReadFromBoundedSource; @@ -96,6 +97,7 @@ import org.apache.beam.sdk.runners.PipelineRunner; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.runners.TransformHierarchy.Node; import org.apache.beam.sdk.transforms.Aggregator; +import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.Combine.GroupedValues; import org.apache.beam.sdk.transforms.DoFn; @@ -390,25 +392,29 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { } private static class ReflectiveOneToOneOverrideFactory< - InputT extends PValue, - OutputT extends PValue, - TransformT extends PTransform<InputT, OutputT>> - extends SingleInputOutputOverrideFactory<InputT, OutputT, TransformT> { - private final Class<PTransform<InputT, OutputT>> replacement; + InputT, OutputT, TransformT extends PTransform<PCollection<InputT>, PCollection<OutputT>>> + extends SingleInputOutputOverrideFactory< + PCollection<InputT>, PCollection<OutputT>, TransformT> { + private final Class<PTransform<PCollection<InputT>, PCollection<OutputT>>> replacement; private final DataflowRunner runner; private ReflectiveOneToOneOverrideFactory( - Class<PTransform<InputT, OutputT>> replacement, DataflowRunner runner) { + Class<PTransform<PCollection<InputT>, PCollection<OutputT>>> replacement, + DataflowRunner runner) { this.replacement = replacement; this.runner = runner; } @Override - public PTransform<InputT, OutputT> getReplacementTransform(TransformT transform) { - return InstanceBuilder.ofType(replacement) - .withArg(DataflowRunner.class, runner) - .withArg((Class<PTransform<InputT, OutputT>>) transform.getClass(), transform) - .build(); + public PTransformReplacement<PCollection<InputT>, PCollection<OutputT>> getReplacementTransform( + AppliedPTransform<PCollection<InputT>, PCollection<OutputT>, TransformT> transform) { + PTransform<PCollection<InputT>, PCollection<OutputT>> rep = + InstanceBuilder.ofType(replacement) + .withArg(DataflowRunner.class, runner) + .withArg( + (Class<TransformT>) transform.getTransform().getClass(), transform.getTransform()) + .build(); + return PTransformReplacement.of(PTransformReplacements.getSingletonMainInput(transform), rep); } } @@ -423,19 +429,18 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { this.replacement = replacement; this.runner = runner; } - @Override - public PTransform<PBegin, PCollection<T>> getReplacementTransform( - PTransform<PInput, PCollection<T>> transform) { - return InstanceBuilder.ofType(replacement) - .withArg(DataflowRunner.class, runner) - .withArg( - (Class<? super PTransform<PInput, PCollection<T>>>) transform.getClass(), transform) - .build(); - } @Override - public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) { - return p.begin(); + public PTransformReplacement<PBegin, PCollection<T>> getReplacementTransform( + AppliedPTransform<PBegin, PCollection<T>, PTransform<PInput, PCollection<T>>> transform) { + PTransform<PInput, PCollection<T>> original = transform.getTransform(); + return PTransformReplacement.of( + transform.getPipeline().begin(), + InstanceBuilder.ofType(replacement) + .withArg(DataflowRunner.class, runner) + .withArg( + (Class<? super PTransform<PInput, PCollection<T>>>) original.getClass(), original) + .build()); } @Override @@ -805,13 +810,11 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { } @Override - public PTransform<PCollection<T>, PDone> getReplacementTransform(Write<T> transform) { - return new BatchWrite<>(runner, transform); - } - - @Override - public PCollection<T> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) { - return (PCollection<T>) Iterables.getOnlyElement(inputs.values()); + public PTransformReplacement<PCollection<T>, PDone> getReplacementTransform( + AppliedPTransform<PCollection<T>, PDone, Write<T>> transform) { + return PTransformReplacement.of( + PTransformReplacements.getSingletonMainInput(transform), + new BatchWrite<>(runner, transform.getTransform())); } @Override @@ -1295,15 +1298,15 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { PCollection<KV<K, Iterable<InputT>>>, PCollection<KV<K, OutputT>>, Combine.GroupedValues<K, InputT, OutputT>> { @Override - public PTransform<PCollection<KV<K, Iterable<InputT>>>, PCollection<KV<K, OutputT>>> - getReplacementTransform(GroupedValues<K, InputT, OutputT> transform) { - return new CombineGroupedValues<>(transform); - } - - @Override - public PCollection<KV<K, Iterable<InputT>>> getInput( - Map<TupleTag<?>, PValue> inputs, Pipeline p) { - return (PCollection<KV<K, Iterable<InputT>>>) Iterables.getOnlyElement(inputs.values()); + public PTransformReplacement<PCollection<KV<K, Iterable<InputT>>>, PCollection<KV<K, OutputT>>> + getReplacementTransform( + AppliedPTransform< + PCollection<KV<K, Iterable<InputT>>>, PCollection<KV<K, OutputT>>, + GroupedValues<K, InputT, OutputT>> + transform) { + return PTransformReplacement.of( + PTransformReplacements.getSingletonMainInput(transform), + new CombineGroupedValues<>(transform.getTransform())); } @Override @@ -1322,14 +1325,11 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { } @Override - public PTransform<PCollection<T>, PDone> getReplacementTransform( - PubsubUnboundedSink<T> transform) { - return new StreamingPubsubIOWrite<>(runner, transform); - } - - @Override - public PCollection<T> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) { - return (PCollection<T>) Iterables.getOnlyElement(inputs.values()); + public PTransformReplacement<PCollection<T>, PDone> getReplacementTransform( + AppliedPTransform<PCollection<T>, PDone, PubsubUnboundedSink<T>> transform) { + return PTransformReplacement.of( + PTransformReplacements.getSingletonMainInput(transform), + new StreamingPubsubIOWrite<>(runner, transform.getTransform())); } @Override
http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java index db50cc2..2e50cb5 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java @@ -20,12 +20,15 @@ package org.apache.beam.runners.dataflow; import java.util.List; import org.apache.beam.runners.core.construction.ForwardingPTransform; +import org.apache.beam.runners.core.construction.PTransformReplacements; import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory; import org.apache.beam.sdk.common.runner.v1.RunnerApi.DisplayData; import org.apache.beam.sdk.runners.PTransformOverrideFactory; +import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.ParDo.SingleOutput; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; @@ -38,9 +41,15 @@ public class PrimitiveParDoSingleFactory<InputT, OutputT> extends SingleInputOutputOverrideFactory< PCollection<? extends InputT>, PCollection<OutputT>, ParDo.SingleOutput<InputT, OutputT>> { @Override - public PTransform<PCollection<? extends InputT>, PCollection<OutputT>> getReplacementTransform( - ParDo.SingleOutput<InputT, OutputT> transform) { - return new ParDoSingle<>(transform); + public PTransformReplacement<PCollection<? extends InputT>, PCollection<OutputT>> + getReplacementTransform( + AppliedPTransform< + PCollection<? extends InputT>, PCollection<OutputT>, + SingleOutput<InputT, OutputT>> + transform) { + return PTransformReplacement.of( + PTransformReplacements.getSingletonMainInput(transform), + new ParDoSingle<>(transform.getTransform())); } /** http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReshuffleOverrideFactory.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReshuffleOverrideFactory.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReshuffleOverrideFactory.java index 2e6455d..aa9d9f8 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReshuffleOverrideFactory.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReshuffleOverrideFactory.java @@ -18,8 +18,10 @@ package org.apache.beam.runners.dataflow; +import org.apache.beam.runners.core.construction.PTransformReplacements; import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory; import org.apache.beam.sdk.runners.PTransformOverrideFactory; +import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; @@ -43,9 +45,13 @@ class ReshuffleOverrideFactory<K, V> extends SingleInputOutputOverrideFactory< PCollection<KV<K, V>>, PCollection<KV<K, V>>, Reshuffle<K, V>> { @Override - public PTransform<PCollection<KV<K, V>>, PCollection<KV<K, V>>> getReplacementTransform( - Reshuffle<K, V> transform) { - return new ReshuffleWithOnlyTrigger<>(); + public PTransformReplacement<PCollection<KV<K, V>>, PCollection<KV<K, V>>> + getReplacementTransform( + AppliedPTransform<PCollection<KV<K, V>>, PCollection<KV<K, V>>, Reshuffle<K, V>> + transform) { + return PTransformReplacement.of( + PTransformReplacements.getSingletonMainInput(transform), + new ReshuffleWithOnlyTrigger<K, V>()); } private static class ReshuffleWithOnlyTrigger<K, V> http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java index c407517..eb385de 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java @@ -20,11 +20,13 @@ package org.apache.beam.runners.dataflow; import java.util.ArrayList; import java.util.List; +import org.apache.beam.runners.core.construction.PTransformReplacements; import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory; import org.apache.beam.runners.dataflow.DataflowRunner.StreamingPCollectionViewWriterFn; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.PTransform; @@ -42,9 +44,15 @@ class StreamingViewOverrides { extends SingleInputOutputOverrideFactory< PCollection<ElemT>, PCollectionView<ViewT>, CreatePCollectionView<ElemT, ViewT>> { @Override - public PTransform<PCollection<ElemT>, PCollectionView<ViewT>> getReplacementTransform( - final CreatePCollectionView<ElemT, ViewT> transform) { - return new StreamingCreatePCollectionView<>(transform.getView()); + public PTransformReplacement<PCollection<ElemT>, PCollectionView<ViewT>> + getReplacementTransform( + AppliedPTransform< + PCollection<ElemT>, PCollectionView<ViewT>, CreatePCollectionView<ElemT, ViewT>> + transform) { + StreamingCreatePCollectionView<ElemT, ViewT> streamingView = + new StreamingCreatePCollectionView<>(transform.getTransform().getView()); + return PTransformReplacement.of( + PTransformReplacements.getSingletonMainInput(transform), streamingView); } private static class StreamingCreatePCollectionView<ElemT, ViewT> http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java index bff46ea..e320036 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java @@ -27,10 +27,11 @@ import java.io.Serializable; import java.util.List; import org.apache.beam.runners.dataflow.PrimitiveParDoSingleFactory.ParDoSingle; import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.runners.PTransformOverrideFactory.PTransformReplacement; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Sum; import org.apache.beam.sdk.transforms.View; @@ -64,17 +65,27 @@ public class PrimitiveParDoSingleFactoryTest implements Serializable { public void getReplacementTransformPopulateDisplayData() { ParDo.SingleOutput<Integer, Long> originalTransform = ParDo.of(new ToLongFn()); DisplayData originalDisplayData = DisplayData.from(originalTransform); - - PTransform<PCollection<? extends Integer>, PCollection<Long>> replacement = - factory.getReplacementTransform(originalTransform); - DisplayData replacementDisplayData = DisplayData.from(replacement); + PCollection<? extends Integer> input = pipeline.apply(Create.of(1, 2, 3)); + AppliedPTransform< + PCollection<? extends Integer>, PCollection<Long>, ParDo.SingleOutput<Integer, Long>> + application = + AppliedPTransform.of( + "original", + input.expand(), + input.apply(originalTransform).expand(), + originalTransform, + pipeline); + + PTransformReplacement<PCollection<? extends Integer>, PCollection<Long>> replacement = + factory.getReplacementTransform(application); + DisplayData replacementDisplayData = DisplayData.from(replacement.getTransform()); assertThat(replacementDisplayData, equalTo(originalDisplayData)); DisplayData primitiveDisplayData = Iterables.getOnlyElement( DisplayDataEvaluator.create() - .displayDataForPrimitiveTransforms(replacement, VarIntCoder.of())); + .displayDataForPrimitiveTransforms(replacement.getTransform(), VarIntCoder.of())); assertThat(primitiveDisplayData, equalTo(replacementDisplayData)); } @@ -91,9 +102,21 @@ public class PrimitiveParDoSingleFactoryTest implements Serializable { ParDo.SingleOutput<Integer, Long> originalTransform = ParDo.of(new ToLongFn()).withSideInputs(sideLong, sideStrings); - PTransform<PCollection<? extends Integer>, PCollection<Long>> replacementTransform = - factory.getReplacementTransform(originalTransform); - ParDoSingle<Integer, Long> parDoSingle = (ParDoSingle<Integer, Long>) replacementTransform; + PCollection<? extends Integer> input = pipeline.apply(Create.of(1, 2, 3)); + AppliedPTransform< + PCollection<? extends Integer>, PCollection<Long>, ParDo.SingleOutput<Integer, Long>> + application = + AppliedPTransform.of( + "original", + input.expand(), + input.apply(originalTransform).expand(), + originalTransform, + pipeline); + + PTransformReplacement<PCollection<? extends Integer>, PCollection<Long>> replacementTransform = + factory.getReplacementTransform(application); + ParDoSingle<Integer, Long> parDoSingle = + (ParDoSingle<Integer, Long>) replacementTransform.getTransform(); assertThat(parDoSingle.getSideInputs(), containsInAnyOrder(sideStrings, sideLong)); } @@ -101,9 +124,21 @@ public class PrimitiveParDoSingleFactoryTest implements Serializable { public void getReplacementTransformGetFn() { DoFn<Integer, Long> originalFn = new ToLongFn(); ParDo.SingleOutput<Integer, Long> originalTransform = ParDo.of(originalFn); - PTransform<PCollection<? extends Integer>, PCollection<Long>> replacementTransform = - factory.getReplacementTransform(originalTransform); - ParDoSingle<Integer, Long> parDoSingle = (ParDoSingle<Integer, Long>) replacementTransform; + PCollection<? extends Integer> input = pipeline.apply(Create.of(1, 2, 3)); + AppliedPTransform< + PCollection<? extends Integer>, PCollection<Long>, ParDo.SingleOutput<Integer, Long>> + application = + AppliedPTransform.of( + "original", + input.expand(), + input.apply(originalTransform).expand(), + originalTransform, + pipeline); + + PTransformReplacement<PCollection<? extends Integer>, PCollection<Long>> replacementTransform = + factory.getReplacementTransform(application); + ParDoSingle<Integer, Long> parDoSingle = + (ParDoSingle<Integer, Long>) replacementTransform.getTransform(); assertThat(parDoSingle.getFn(), equalTo(originalTransform.getFn())); assertThat(parDoSingle.getFn(), equalTo(originalFn)); http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java index aacb942..61fcaa9 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java @@ -46,6 +46,7 @@ import org.apache.beam.sdk.runners.PTransformOverride; import org.apache.beam.sdk.runners.PTransformOverrideFactory; import org.apache.beam.sdk.runners.PipelineRunner; import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.util.ValueWithRecordId; @@ -244,14 +245,11 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> { implements PTransformOverrideFactory< PBegin, PCollection<T>, BoundedReadFromUnboundedSource<T>> { @Override - public PTransform<PBegin, PCollection<T>> getReplacementTransform( - BoundedReadFromUnboundedSource<T> transform) { - return new AdaptedBoundedAsUnbounded<>(transform); - } - - @Override - public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) { - return p.begin(); + public PTransformReplacement<PBegin, PCollection<T>> getReplacementTransform( + AppliedPTransform<PBegin, PCollection<T>, BoundedReadFromUnboundedSource<T>> transform) { + return PTransformReplacement.of( + transform.getPipeline().begin(), + new AdaptedBoundedAsUnbounded<T>(transform.getTransform())); } @Override http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java index 791166e..1ff4c30 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java @@ -33,11 +33,13 @@ import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.runners.PTransformOverride; import org.apache.beam.sdk.runners.PTransformOverrideFactory; +import org.apache.beam.sdk.runners.PTransformOverrideFactory.PTransformReplacement; import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput; import org.apache.beam.sdk.runners.PipelineRunner; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.runners.TransformHierarchy.Node; import org.apache.beam.sdk.transforms.Aggregator; +import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.util.UserCodeException; @@ -497,17 +499,18 @@ public class Pipeline { void applyReplacement( Node original, PTransformOverrideFactory<InputT, OutputT, TransformT> replacementFactory) { - PTransform<InputT, OutputT> replacement = - replacementFactory.getReplacementTransform((TransformT) original.getTransform()); - if (replacement == original.getTransform()) { + PTransformReplacement<InputT, OutputT> replacement = + replacementFactory.getReplacementTransform( + (AppliedPTransform<InputT, OutputT, TransformT>) original.toAppliedPTransform()); + if (replacement.getTransform() == original.getTransform()) { return; } - InputT originalInput = replacementFactory.getInput(original.getInputs(), this); + InputT originalInput = replacement.getInput(); LOG.debug("Replacing {} with {}", original, replacement); - transforms.replaceNode(original, originalInput, replacement); + transforms.replaceNode(original, originalInput, replacement.getTransform()); try { - OutputT newOutput = replacement.expand(originalInput); + OutputT newOutput = replacement.getTransform().expand(originalInput); Map<PValue, ReplacementOutput> originalToReplacement = replacementFactory.mapOutputs(original.getOutputs(), newOutput); // Ensure the internal TransformHierarchy data structures are consistent. http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java index 57cba50..786c61c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java @@ -21,9 +21,9 @@ package org.apache.beam.sdk.runners; import com.google.auto.value.AutoValue; import java.util.Map; -import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; @@ -41,14 +41,11 @@ public interface PTransformOverrideFactory< OutputT extends POutput, TransformT extends PTransform<? super InputT, OutputT>> { /** - * Returns a {@link PTransform} that produces equivalent output to the provided transform. + * Returns a {@link PTransform} that produces equivalent output to the provided {@link + * AppliedPTransform transform}. */ - PTransform<InputT, OutputT> getReplacementTransform(TransformT transform); - - /** - * Returns the composite type that replacement transforms consumed from an equivalent expansion. - */ - InputT getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p); + PTransformReplacement<InputT, OutputT> getReplacementTransform( + AppliedPTransform<InputT, OutputT, TransformT> transform); /** * Returns a {@link Map} from the expanded values in {@code newOutput} to the values produced by @@ -56,7 +53,25 @@ public interface PTransformOverrideFactory< */ Map<PValue, ReplacementOutput> mapOutputs(Map<TupleTag<?>, PValue> outputs, OutputT newOutput); - /** A mapping between original {@link TaggedPValue} outputs and their replacements. */ + /** + * A {@link PTransform} that replaces an {@link AppliedPTransform}, and the input required to + * do so. The input must be constructed from the expanded form, as the transform may not have + * originally been applied within this process or from within a Java SDK. + */ + @AutoValue + abstract class PTransformReplacement<InputT extends PInput, OutputT extends POutput> { + public static <InputT extends PInput, OutputT extends POutput> + PTransformReplacement<InputT, OutputT> of( + InputT input, PTransform<InputT, OutputT> transform) { + return new AutoValue_PTransformOverrideFactory_PTransformReplacement(input, transform); + } + public abstract InputT getInput(); + public abstract PTransform<InputT, OutputT> getTransform(); + } + + /** + * A mapping between original {@link TaggedPValue} outputs and their replacements. + */ @AutoValue abstract class ReplacementOutput { public static ReplacementOutput of(TaggedPValue original, TaggedPValue replacement) { http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java index 8d99a62..bdb61b8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java @@ -31,6 +31,11 @@ import org.apache.beam.sdk.values.TupleTag; * * <p>For internal use. * + * <p>Inputs and outputs are stored in their expanded forms, as the condensed form of a composite + * {@link PInput} or {@link POutput} is a language-specific concept, and {@link AppliedPTransform} + * represents a possibly cross-language transform for which no appropriate composite type exists + * in the Java SDK. + * * @param <InputT> transform input type * @param <OutputT> transform output type * @param <TransformT> transform type http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java index 6ce016d..75cabf2 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java @@ -406,16 +406,10 @@ public class PipelineTest { class ReplacementOverrideFactory implements PTransformOverrideFactory< PCollection<String>, PCollection<Long>, OriginalTransform> { - @Override - public PTransform<PCollection<String>, PCollection<Long>> getReplacementTransform( - OriginalTransform transform) { - return new ReplacementTransform(); - } - - @Override - public PCollection<String> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) { - return originalInput; + public PTransformReplacement<PCollection<String>, PCollection<Long>> getReplacementTransform( + AppliedPTransform<PCollection<String>, PCollection<Long>, OriginalTransform> transform) { + return PTransformReplacement.of(originalInput, new ReplacementTransform()); } @Override @@ -464,14 +458,9 @@ public class PipelineTest { static class BoundedCountingInputOverride implements PTransformOverrideFactory<PBegin, PCollection<Long>, BoundedCountingInput> { @Override - public PTransform<PBegin, PCollection<Long>> getReplacementTransform( - BoundedCountingInput transform) { - return Create.of(0L); - } - - @Override - public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) { - return p.begin(); + public PTransformReplacement<PBegin, PCollection<Long>> getReplacementTransform( + AppliedPTransform<PBegin, PCollection<Long>, BoundedCountingInput> transform) { + return PTransformReplacement.of(transform.getPipeline().begin(), Create.of(0L)); } @Override @@ -489,15 +478,11 @@ public class PipelineTest { } static class UnboundedCountingInputOverride implements PTransformOverrideFactory<PBegin, PCollection<Long>, UnboundedCountingInput> { - @Override - public PTransform<PBegin, PCollection<Long>> getReplacementTransform( - UnboundedCountingInput transform) { - return CountingInput.upTo(100L); - } @Override - public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) { - return p.begin(); + public PTransformReplacement<PBegin, PCollection<Long>> getReplacementTransform( + AppliedPTransform<PBegin, PCollection<Long>, UnboundedCountingInput> transform) { + return PTransformReplacement.of(transform.getPipeline().begin(), CountingInput.upTo(100L)); } @Override