This is an automated email from the ASF dual-hosted git repository. janl pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 377f1ac [BEAM-8550] @RequiresTimeSortedInput: working with legacy flink and spark new 041f7af Merge pull request #8774 from je-ik/requires-time-sorted-input-draft: [BEAM-8550] Requires time sorted input 377f1ac is described below commit 377f1ac7ebbc4253299e7efbdb3ad58d0c9e14c5 Author: Jan Lukavsky <je...@seznam.cz> AuthorDate: Thu Jan 30 13:10:31 2020 +0100 [BEAM-8550] @RequiresTimeSortedInput: working with legacy flink and spark --- .gitignore | 1 + .../pipeline/src/main/proto/beam_runner_api.proto | 4 +- .../translation/operators/ApexParDoOperator.java | 8 +- .../core/construction/ParDoTranslation.java | 8 + .../runners/core/construction/SplittableParDo.java | 5 + .../org/apache/beam/runners/core/DoFnRunners.java | 60 ++++- .../apache/beam/runners/core/SimpleDoFnRunner.java | 4 +- .../beam/runners/core/StatefulDoFnRunner.java | 172 ++++++++++--- .../SimplePushbackSideInputDoFnRunnerTest.java | 26 +- .../beam/runners/core/StatefulDoFnRunnerTest.java | 285 ++++++++++++++++++--- .../apache/beam/runners/direct/ParDoEvaluator.java | 32 ++- .../runners/direct/ParDoMultiOverrideFactory.java | 73 +++--- .../beam/runners/direct/QuiescenceDriver.java | 2 +- .../direct/StatefulParDoEvaluatorFactoryTest.java | 149 +++++++++-- .../FlinkBatchPortablePipelineTranslator.java | 11 + .../flink/FlinkBatchTransformTranslators.java | 28 +- .../FlinkStreamingPortablePipelineTranslator.java | 1 - .../flink/FlinkStreamingTransformTranslators.java | 8 - .../utils/FlinkPortableRunnerUtils.java | 58 +++++ .../wrappers/streaming/DoFnOperator.java | 59 +++-- .../streaming/ExecutableStageDoFnOperator.java | 114 +++++---- .../wrappers/streaming/SplittableDoFnOperator.java | 6 +- .../wrappers/streaming/WindowDoFnOperator.java | 4 +- .../runners/flink/FlinkPipelineOptionsTest.java | 2 - .../wrappers/streaming/DoFnOperatorTest.java | 21 -- .../streaming/ExecutableStageDoFnOperatorTest.java | 6 +- 
.../dataflow/PrimitiveParDoSingleFactory.java | 5 + .../runners/samza/runtime/SamzaDoFnRunners.java | 5 +- .../beam/runners/spark/coders/CoderHelpers.java | 47 ++++ .../spark/translation/TransformTranslator.java | 175 +++++++++++-- .../spark/translation/TransformTranslatorTest.java | 106 ++++++++ .../apache/beam/sdk/runners/AppliedPTransform.java | 9 + .../sdk/testing/UsesRequiresTimeSortedInput.java | 27 ++ .../java/org/apache/beam/sdk/transforms/DoFn.java | 27 ++ .../beam/sdk/transforms/reflect/DoFnSignature.java | 8 + .../sdk/transforms/reflect/DoFnSignatures.java | 2 + .../org/apache/beam/sdk/transforms/ParDoTest.java | 191 +++++++++++++- 37 files changed, 1492 insertions(+), 257 deletions(-) diff --git a/.gitignore b/.gitignore index 5732b9c..f030006 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ # Ignore files generated by the Gradle build process. **/.gradle/**/* **/.gogradle/**/* +**/.nb-gradle/**/* **/gogradle.lock **/build/**/* .test-infra/**/vendor/**/* diff --git a/model/pipeline/src/main/proto/beam_runner_api.proto b/model/pipeline/src/main/proto/beam_runner_api.proto index 57c5295..81e4d2d 100644 --- a/model/pipeline/src/main/proto/beam_runner_api.proto +++ b/model/pipeline/src/main/proto/beam_runner_api.proto @@ -175,7 +175,6 @@ message StandardPTransforms { enum Primitives { // Represents Beam's parallel do operation. // Payload: ParDoPayload. - // TODO(BEAM-3595): Change this to beam:transform:pardo:v1. PAR_DO = 0 [(beam_urn) = "beam:transform:pardo:v1"]; // Represents Beam's flatten operation. @@ -398,6 +397,9 @@ message ParDoPayload { // (Optional) A mapping of local timer family names to timer specifications. map<string, TimerFamilySpec> timer_family_specs = 9; + + // Whether this stage requires time sorted input + bool requires_time_sorted_input = 10; } // Parameters that a UDF might require. 
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java index 4841c6a..8df7997 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java @@ -511,7 +511,13 @@ public class ApexParDoOperator<InputT, OutputT> extends BaseOperator doFnRunner = DoFnRunners.defaultStatefulDoFnRunner( - doFn, doFnRunner, windowingStrategy, cleanupTimer, stateCleaner); + doFn, + inputCoder, + doFnRunner, + stepContext, + windowingStrategy, + cleanupTimer, + stateCleaner); } pushbackDoFnRunner = diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java index 7e6ba7b..19a272a 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java @@ -284,6 +284,11 @@ public class ParDoTranslation { } @Override + public boolean isRequiresTimeSortedInput() { + return signature.processElement().requiresTimeSortedInput(); + } + + @Override public String translateRestrictionCoderId(SdkComponents newComponents) { return restrictionCoderId; } @@ -756,6 +761,8 @@ public class ParDoTranslation { boolean isSplittable(); + boolean isRequiresTimeSortedInput(); + String translateRestrictionCoderId(SdkComponents newComponents); } @@ -770,6 +777,7 @@ public class ParDoTranslation { .putAllTimerFamilySpecs(parDo.translateTimerFamilySpecs(components)) .putAllSideInputs(parDo.translateSideInputs(components)) .setSplittable(parDo.isSplittable()) + 
.setRequiresTimeSortedInput(parDo.isRequiresTimeSortedInput()) .setRestrictionCoderId(parDo.translateRestrictionCoderId(components)) .build(); } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java index 2b700d5..a48c3a8 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java @@ -412,6 +412,11 @@ public class SplittableParDo<InputT, OutputT, RestrictionT> } @Override + public boolean isRequiresTimeSortedInput() { + return false; + } + + @Override public String translateRestrictionCoderId(SdkComponents newComponents) { return restrictionCoderId; } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java index 6496561..88ba954 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java @@ -28,6 +28,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; @@ -58,7 +59,7 @@ public class DoFnRunners { TupleTag<OutputT> mainOutputTag, List<TupleTag<?>> additionalOutputTags, StepContext stepContext, - @Nullable Coder<InputT> inputCoder, + Coder<InputT> inputCoder, Map<TupleTag<?>, Coder<?>> outputCoders, WindowingStrategy<?, ?> windowingStrategy, 
DoFnSchemaInformation doFnSchemaInformation, @@ -95,16 +96,69 @@ public class DoFnRunners { * Returns an implementation of {@link DoFnRunner} that handles late data dropping and garbage * collection for stateful {@link DoFn DoFns}. * - * <p>It registers a timer by TimeInternals, and clean all states by StateInternals. + * <p>It registers a timer by TimeInternals, and clean all states by StateInternals. It also + * correctly handles {@link DoFn.RequiresTimeSortedInput} if the provided {@link DoFn} requires + * this. */ public static <InputT, OutputT, W extends BoundedWindow> DoFnRunner<InputT, OutputT> defaultStatefulDoFnRunner( DoFn<InputT, OutputT> fn, + Coder<InputT> inputCoder, DoFnRunner<InputT, OutputT> doFnRunner, + StepContext stepContext, WindowingStrategy<?, ?> windowingStrategy, CleanupTimer<InputT> cleanupTimer, StateCleaner<W> stateCleaner) { - return new StatefulDoFnRunner<>(doFnRunner, windowingStrategy, cleanupTimer, stateCleaner); + + return defaultStatefulDoFnRunner( + fn, + inputCoder, + doFnRunner, + stepContext, + windowingStrategy, + cleanupTimer, + stateCleaner, + false); + } + + /** + * Returns an implementation of {@link DoFnRunner} that handles late data dropping and garbage + * collection for stateful {@link DoFn DoFns}. + * + * <p>It registers a timer by TimeInternals, and clean all states by StateInternals. If {@code + * requiresTimeSortedInputSupported} is {@code true} then it also handles {@link + * DoFn.RequiresTimeSortedInput} if the provided {@link DoFn} requires this. If {@code + * requiresTimeSortedInputSupported} is {@code false} and the provided {@link DoFn} has {@link + * DoFn.RequiresTimeSortedInput} this method will throw {@link UnsupportedOperationException}. 
+ */ + public static <InputT, OutputT, W extends BoundedWindow> + DoFnRunner<InputT, OutputT> defaultStatefulDoFnRunner( + DoFn<InputT, OutputT> fn, + Coder<InputT> inputCoder, + DoFnRunner<InputT, OutputT> doFnRunner, + StepContext stepContext, + WindowingStrategy<?, ?> windowingStrategy, + CleanupTimer<InputT> cleanupTimer, + StateCleaner<W> stateCleaner, + boolean requiresTimeSortedInputSupported) { + + boolean doFnRequiresTimeSortedInput = + DoFnSignatures.signatureForDoFn(doFnRunner.getFn()) + .processElement() + .requiresTimeSortedInput(); + + if (doFnRequiresTimeSortedInput && !requiresTimeSortedInputSupported) { + throw new UnsupportedOperationException( + "DoFn.RequiresTimeSortedInput not currently supported by this runner."); + } + return new StatefulDoFnRunner<>( + doFnRunner, + inputCoder, + stepContext, + windowingStrategy, + cleanupTimer, + stateCleaner, + doFnRequiresTimeSortedInput); } public static <InputT, OutputT, RestrictionT> diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java index 0b41602..10a13fd 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java @@ -131,9 +131,7 @@ public class SimpleDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Out this.invoker = DoFnInvokers.invokerFor(fn); this.sideInputReader = sideInputReader; this.schemaCoder = - (inputCoder != null && inputCoder instanceof SchemaCoder) - ? (SchemaCoder<InputT>) inputCoder - : null; + (inputCoder instanceof SchemaCoder) ? 
(SchemaCoder<InputT>) inputCoder : null; this.outputCoders = outputCoders; if (outputCoders != null && !outputCoders.isEmpty()) { Coder<OutputT> outputCoder = (Coder<OutputT>) outputCoders.get(mainOutputTag); diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/StatefulDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/StatefulDoFnRunner.java index f6170ce..63cee4c 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/StatefulDoFnRunner.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/StatefulDoFnRunner.java @@ -17,22 +17,32 @@ */ package org.apache.beam.runners.core; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; import java.util.Map; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.InstantCoder; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.state.BagState; import org.apache.beam.sdk.state.State; import org.apache.beam.sdk.state.StateSpec; import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.WatermarkHoldState; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.reflect.DoFnSignature; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.NonMergingWindowFn; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.util.WindowTracing; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects; import org.joda.time.Instant; /** @@ -47,24 +57,49 @@ public class StatefulDoFnRunner<InputT, 
OutputT, W extends BoundedWindow> implements DoFnRunner<InputT, OutputT> { public static final String DROPPED_DUE_TO_LATENESS_COUNTER = "StatefulParDoDropped"; + private static final String SORT_BUFFER_STATE = "sortBuffer"; + private static final String SORT_BUFFER_MIN_STAMP = "sortBufferMinStamp"; + private static final String SORT_FLUSH_TIMER = "__StatefulParDoSortFlushTimerId"; + private static final String SORT_FLUSH_WATERMARK_HOLD = "flushWatermarkHold"; private final DoFnRunner<InputT, OutputT> doFnRunner; + private final StepContext stepContext; private final WindowingStrategy<?, ?> windowingStrategy; private final Counter droppedDueToLateness = Metrics.counter(StatefulDoFnRunner.class, DROPPED_DUE_TO_LATENESS_COUNTER); private final CleanupTimer<InputT> cleanupTimer; private final StateCleaner stateCleaner; + private final boolean requiresTimeSortedInput; + private final Coder<BoundedWindow> windowCoder; + private final StateTag<BagState<WindowedValue<InputT>>> sortBufferTag; + private final StateTag<ValueState<Instant>> sortBufferMinStampTag = + StateTags.makeSystemTagInternal(StateTags.value(SORT_BUFFER_MIN_STAMP, InstantCoder.of())); + private final StateTag<WatermarkHoldState> watermarkHold = + StateTags.watermarkStateInternal(SORT_FLUSH_WATERMARK_HOLD, TimestampCombiner.LATEST); public StatefulDoFnRunner( DoFnRunner<InputT, OutputT> doFnRunner, + Coder<InputT> inputCoder, + StepContext stepContext, WindowingStrategy<?, ?> windowingStrategy, CleanupTimer<InputT> cleanupTimer, - StateCleaner<W> stateCleaner) { + StateCleaner<W> stateCleaner, + boolean requiresTimeSortedInput) { this.doFnRunner = doFnRunner; + this.stepContext = stepContext; this.windowingStrategy = windowingStrategy; this.cleanupTimer = cleanupTimer; this.stateCleaner = stateCleaner; + this.requiresTimeSortedInput = requiresTimeSortedInput; WindowFn<?, ?> windowFn = windowingStrategy.getWindowFn(); + @SuppressWarnings("unchecked") + Coder<BoundedWindow> untypedCoder = 
(Coder<BoundedWindow>) windowFn.windowCoder(); + this.windowCoder = untypedCoder; + + this.sortBufferTag = + StateTags.makeSystemTagInternal( + StateTags.bag(SORT_BUFFER_STATE, WindowedValue.getFullCoder(inputCoder, windowCoder))); + rejectMergingWindowFn(windowFn); } @@ -75,6 +110,10 @@ public class StatefulDoFnRunner<InputT, OutputT, W extends BoundedWindow> } } + public List<StateTag<?>> getSystemStateTags() { + return Arrays.asList(sortBufferTag, sortBufferMinStampTag, watermarkHold); + } + @Override public DoFn<InputT, OutputT> getFn() { return doFnRunner.getFn(); @@ -86,35 +125,76 @@ public class StatefulDoFnRunner<InputT, OutputT, W extends BoundedWindow> } @Override + public void finishBundle() { + doFnRunner.finishBundle(); + } + + @Override public void processElement(WindowedValue<InputT> input) { // StatefulDoFnRunner always observes windows, so we need to explode for (WindowedValue<InputT> value : input.explodeWindows()) { - BoundedWindow window = value.getWindows().iterator().next(); - if (isLate(window)) { // The element is too late for this window. 
- droppedDueToLateness.inc(); - WindowTracing.debug( - "StatefulDoFnRunner.processElement: Dropping element at {}; window:{} " - + "since too far behind inputWatermark:{}", - input.getTimestamp(), - window, - cleanupTimer.currentInputWatermarkTime()); + reportDroppedElement(value, window); + } else if (requiresTimeSortedInput) { + processElementOrdered(window, value); } else { - cleanupTimer.setForWindow(value.getValue(), window); - doFnRunner.processElement(value); + processElementUnordered(window, value); } } } + private void processElementUnordered(BoundedWindow window, WindowedValue<InputT> value) { + cleanupTimer.setForWindow(value.getValue(), window); + doFnRunner.processElement(value); + } + + private void processElementOrdered(BoundedWindow window, WindowedValue<InputT> value) { + + StateInternals stateInternals = stepContext.stateInternals(); + TimerInternals timerInternals = stepContext.timerInternals(); + + Instant outputWatermark = + MoreObjects.firstNonNull( + timerInternals.currentOutputWatermarkTime(), BoundedWindow.TIMESTAMP_MIN_VALUE); + + if (!outputWatermark.isAfter( + value.getTimestamp().plus(windowingStrategy.getAllowedLateness()))) { + + StateNamespace namespace = StateNamespaces.window(windowCoder, window); + BagState<WindowedValue<InputT>> sortBuffer = stateInternals.state(namespace, sortBufferTag); + ValueState<Instant> minStampState = stateInternals.state(namespace, sortBufferMinStampTag); + sortBuffer.add(value); + Instant minStamp = + MoreObjects.firstNonNull(minStampState.read(), BoundedWindow.TIMESTAMP_MAX_VALUE); + if (value.getTimestamp().isBefore(minStamp)) { + minStamp = value.getTimestamp(); + minStampState.write(minStamp); + setupFlushTimerAndWatermarkHold(namespace, minStamp); + } + } else { + reportDroppedElement(value, window); + } + } + private boolean isLate(BoundedWindow window) { Instant gcTime = LateDataUtils.garbageCollectionTime(window, windowingStrategy); - Instant inputWM = cleanupTimer.currentInputWatermarkTime(); + 
Instant inputWM = stepContext.timerInternals().currentInputWatermarkTime(); return gcTime.isBefore(inputWM); } + private void reportDroppedElement(WindowedValue<InputT> value, BoundedWindow window) { + droppedDueToLateness.inc(); + WindowTracing.debug( + "StatefulDoFnRunner.processElement: Dropping element at {}; window:{} " + + "since too far behind inputWatermark:{}", + value.getTimestamp(), + window, + stepContext.timerInternals().currentInputWatermarkTime()); + } + @Override public void onTimer( String timerId, @@ -123,12 +203,14 @@ public class StatefulDoFnRunner<InputT, OutputT, W extends BoundedWindow> Instant timestamp, Instant outputTimestamp, TimeDomain timeDomain) { - if (cleanupTimer.isForWindow(timerId, window, timestamp, timeDomain)) { + if (timerId.equals(SORT_FLUSH_TIMER)) { + onSortFlushTimer(window, stepContext.timerInternals().currentInputWatermarkTime()); + } else if (cleanupTimer.isForWindow(timerId, window, timestamp, timeDomain)) { stateCleaner.clearForWindow(window); // There should invoke the onWindowExpiration of DoFn } else { // An event-time timer can never be late because we don't allow setting timers after GC time. - // Ot can happen that a processing-time time fires for a late window, we need to ignore + // It can happen that a processing-time timer fires for a late window, we need to ignore // this. 
if (!timeDomain.equals(TimeDomain.EVENT_TIME) && isLate(window)) { // don't increment the dropped counter, only do that for elements @@ -137,16 +219,57 @@ public class StatefulDoFnRunner<InputT, OutputT, W extends BoundedWindow> + "since window is too far behind inputWatermark:{}", timestamp, window, - cleanupTimer.currentInputWatermarkTime()); + stepContext.timerInternals().currentInputWatermarkTime()); } else { doFnRunner.onTimer(timerId, timerFamilyId, window, timestamp, outputTimestamp, timeDomain); } } } - @Override - public void finishBundle() { - doFnRunner.finishBundle(); + // this needs to be optimized (Sorted Map State) + private void onSortFlushTimer(BoundedWindow window, Instant timestamp) { + StateInternals stateInternals = stepContext.stateInternals(); + StateNamespace namespace = StateNamespaces.window(windowCoder, window); + BagState<WindowedValue<InputT>> sortBuffer = stateInternals.state(namespace, sortBufferTag); + ValueState<Instant> minStampState = stateInternals.state(namespace, sortBufferMinStampTag); + List<WindowedValue<InputT>> keep = new ArrayList<>(); + List<WindowedValue<InputT>> flush = new ArrayList<>(); + Instant newMinStamp = BoundedWindow.TIMESTAMP_MAX_VALUE; + for (WindowedValue<InputT> e : sortBuffer.read()) { + if (!e.getTimestamp().isAfter(timestamp)) { + flush.add(e); + } else { + keep.add(e); + if (e.getTimestamp().isBefore(newMinStamp)) { + newMinStamp = e.getTimestamp(); + } + } + } + flush.stream() + .sorted(Comparator.comparing(WindowedValue::getTimestamp)) + .forEachOrdered(e -> processElementUnordered(window, e)); + sortBuffer.clear(); + keep.forEach(sortBuffer::add); + minStampState.write(newMinStamp); + if (newMinStamp.isBefore(BoundedWindow.TIMESTAMP_MAX_VALUE)) { + setupFlushTimerAndWatermarkHold(namespace, newMinStamp); + } else { + clearWatermarkHold(namespace); + } + } + + private void setupFlushTimerAndWatermarkHold(StateNamespace namespace, Instant flush) { + WatermarkHoldState watermark = 
stepContext.stateInternals().state(namespace, watermarkHold); + stepContext + .timerInternals() + .setTimer( + namespace, SORT_FLUSH_TIMER, SORT_FLUSH_TIMER, flush, flush, TimeDomain.EVENT_TIME); + watermark.clear(); + watermark.add(flush); + } + + private void clearWatermarkHold(StateNamespace namespace) { + stepContext.stateInternals().state(namespace, watermarkHold).clear(); } /** @@ -158,12 +281,6 @@ public class StatefulDoFnRunner<InputT, OutputT, W extends BoundedWindow> */ public interface CleanupTimer<InputT> { - /** - * Return the current, local input watermark timestamp for this computation in the {@link - * TimeDomain#EVENT_TIME} time domain. - */ - Instant currentInputWatermarkTime(); - /** Set the garbage collect time of the window to timer. */ void setForWindow(InputT value, BoundedWindow window); @@ -203,11 +320,6 @@ public class StatefulDoFnRunner<InputT, OutputT, W extends BoundedWindow> } @Override - public Instant currentInputWatermarkTime() { - return timerInternals.currentInputWatermarkTime(); - } - - @Override public void setForWindow(InputT input, BoundedWindow window) { Instant gcTime = LateDataUtils.garbageCollectionTime(window, windowingStrategy); // make sure this fires after any window.maxTimestamp() timers diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/SimplePushbackSideInputDoFnRunnerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/SimplePushbackSideInputDoFnRunnerTest.java index 8f88d41..16ba0b1 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/SimplePushbackSideInputDoFnRunnerTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/SimplePushbackSideInputDoFnRunnerTest.java @@ -34,6 +34,8 @@ import java.util.List; import org.apache.beam.runners.core.TimerInternals.TimerData; import org.apache.beam.runners.core.metrics.MetricsContainerImpl; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import 
org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.metrics.MetricName; import org.apache.beam.sdk.metrics.MetricsEnvironment; @@ -109,7 +111,7 @@ public class SimplePushbackSideInputDoFnRunnerTest { .apply(Window.into(new IdentitySideInputWindowFn())) .apply(Sum.integersGlobally().asSingletonView()); - underlying = new TestDoFnRunner<>(); + underlying = new TestDoFnRunner<>(VarIntCoder.of()); DoFn<KV<String, Integer>, Integer> fn = new MyDoFn(); @@ -125,13 +127,30 @@ public class SimplePushbackSideInputDoFnRunnerTest { statefulRunner = DoFnRunners.defaultStatefulDoFnRunner( fn, + KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()), getDoFnRunner(fn), + asStepContext(stateInternals, timerInternals), WINDOWING_STRATEGY, new StatefulDoFnRunner.TimeInternalsCleanupTimer(timerInternals, WINDOWING_STRATEGY), new StatefulDoFnRunner.StateInternalsStateCleaner<>( fn, stateInternals, (Coder) WINDOWING_STRATEGY.getWindowFn().windowCoder())); } + private StepContext asStepContext(StateInternals stateInternals, TimerInternals timerInternals) { + return new StepContext() { + + @Override + public StateInternals stateInternals() { + return stateInternals; + } + + @Override + public TimerInternals timerInternals() { + return timerInternals; + } + }; + } + private SimplePushbackSideInputDoFnRunner<Integer, Integer> createRunner( ImmutableList<PCollectionView<?>> views) { SimplePushbackSideInputDoFnRunner<Integer, Integer> runner = @@ -297,11 +316,16 @@ public class SimplePushbackSideInputDoFnRunnerTest { } private static class TestDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, OutputT> { + private final Coder<InputT> inputCoder; List<WindowedValue<InputT>> inputElems; List<TimerData> firedTimers; private boolean started = false; private boolean finished = false; + TestDoFnRunner(Coder<InputT> inputCoder) { + this.inputCoder = inputCoder; + } + @Override public DoFn<InputT, OutputT> getFn() { return null; 
diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/StatefulDoFnRunnerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/StatefulDoFnRunnerTest.java index be4e321..d9d9bad 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/StatefulDoFnRunnerTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/StatefulDoFnRunnerTest.java @@ -22,9 +22,17 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.when; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; +import java.util.List; +import java.util.function.BiFunction; +import javax.annotation.Nullable; +import org.apache.beam.runners.core.DoFnRunners.OutputManager; import org.apache.beam.runners.core.metrics.MetricsContainerImpl; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.metrics.MetricName; import org.apache.beam.sdk.metrics.MetricsEnvironment; @@ -39,6 +47,7 @@ import org.apache.beam.sdk.transforms.windowing.IntervalWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects; import org.joda.time.Duration; @@ -67,6 +76,8 @@ public class StatefulDoFnRunnerTest { private static final IntervalWindow WINDOW_2 = new IntervalWindow(new Instant(10), new Instant(20)); + private final TupleTag<Integer> outputTag = new TupleTag<>(); + @Mock StepContext mockStepContext; private InMemoryStateInternals<String> stateInternals; @@ -79,7 +90,6 @@ public class StatefulDoFnRunnerTest { @Before public void setup() { 
MockitoAnnotations.initMocks(this); - when(mockStepContext.timerInternals()).thenReturn(timerInternals); stateInternals = new InMemoryStateInternals<>("hello"); timerInternals = new InMemoryTimerInternals(); @@ -89,23 +99,50 @@ public class StatefulDoFnRunnerTest { } @Test - public void testLateDropping() throws Exception { + public void testLateDroppingUnordered() throws Exception { + testLateDropping(false); + } + + @Test + public void testLateDroppingOrdered() throws Exception { + testLateDropping(true); + } + + @Test + public void testGargageCollectUnordered() throws Exception { + testGarbageCollect(false); + } + + @Test + public void testGargageCollectOrdered() throws Exception { + testGarbageCollect(true); + } + + @Test + public void testOutputUnordered() throws Exception { + testOutput(false); + } + + @Test + public void testOutputOrdered() throws Exception { + testOutput(true); + } + + @Test(expected = UnsupportedOperationException.class) + public void testOutputOrderedUnsupported() throws Exception { + testOutput(true, (fn, output) -> createStatefulDoFnRunner(fn, output, false)); + } + + private void testLateDropping(boolean ordered) throws Exception { MetricsContainerImpl container = new MetricsContainerImpl("any"); MetricsEnvironment.setCurrentContainer(container); timerInternals.advanceInputWatermark(new Instant(BoundedWindow.TIMESTAMP_MAX_VALUE)); timerInternals.advanceOutputWatermark(new Instant(BoundedWindow.TIMESTAMP_MAX_VALUE)); - DoFn<KV<String, Integer>, Integer> fn = new MyDoFn(); + MyDoFn fn = MyDoFn.create(ordered); - DoFnRunner<KV<String, Integer>, Integer> runner = - DoFnRunners.defaultStatefulDoFnRunner( - fn, - getDoFnRunner(fn), - WINDOWING_STRATEGY, - new StatefulDoFnRunner.TimeInternalsCleanupTimer(timerInternals, WINDOWING_STRATEGY), - new StatefulDoFnRunner.StateInternalsStateCleaner<>( - fn, stateInternals, (Coder) WINDOWING_STRATEGY.getWindowFn().windowCoder())); + DoFnRunner<KV<String, Integer>, Integer> runner = 
createStatefulDoFnRunner(fn); runner.startBundle(); @@ -126,21 +163,13 @@ public class StatefulDoFnRunnerTest { runner.finishBundle(); } - @Test - public void testGarbageCollect() throws Exception { + private void testGarbageCollect(boolean ordered) throws Exception { timerInternals.advanceInputWatermark(new Instant(1L)); - MyDoFn fn = new MyDoFn(); - StateTag<ValueState<Integer>> stateTag = StateTags.tagForSpec(fn.stateId, fn.intState); + MyDoFn fn = MyDoFn.create(ordered); + StateTag<ValueState<Integer>> stateTag = StateTags.tagForSpec(MyDoFn.STATE_ID, fn.intState()); - DoFnRunner<KV<String, Integer>, Integer> runner = - DoFnRunners.defaultStatefulDoFnRunner( - fn, - getDoFnRunner(fn), - WINDOWING_STRATEGY, - new StatefulDoFnRunner.TimeInternalsCleanupTimer(timerInternals, WINDOWING_STRATEGY), - new StatefulDoFnRunner.StateInternalsStateCleaner<>( - fn, stateInternals, (Coder) WINDOWING_STRATEGY.getWindowFn().windowCoder())); + DoFnRunner<KV<String, Integer>, Integer> runner = createStatefulDoFnRunner(fn); Instant elementTime = new Instant(1); @@ -148,6 +177,11 @@ public class StatefulDoFnRunnerTest { runner.processElement( WindowedValue.of(KV.of("hello", 1), elementTime, WINDOW_1, PaneInfo.NO_FIRING)); + if (ordered) { + // move forward in time so that the input might get flushed + advanceInputWatermark(timerInternals, elementTime.plus(1), runner); + } + assertEquals(1, (int) stateInternals.state(windowNamespace(WINDOW_1), stateTag).read()); // second element, key is hello, WINDOW_2 @@ -159,6 +193,11 @@ public class StatefulDoFnRunnerTest { WindowedValue.of( KV.of("hello", 1), elementTime.plus(WINDOW_SIZE), WINDOW_2, PaneInfo.NO_FIRING)); + if (ordered) { + // move forward in time to so that the input might get flushed + advanceInputWatermark(timerInternals, elementTime.plus(1 + WINDOW_SIZE), runner); + } + assertEquals(2, (int) stateInternals.state(windowNamespace(WINDOW_2), stateTag).read()); // advance watermark past end of WINDOW_1 + allowed lateness @@ 
-194,14 +233,139 @@ public class StatefulDoFnRunnerTest { stateInternals.state(windowNamespace(WINDOW_2), stateTag))); } + private void testOutput(boolean ordered) throws Exception { + testOutput(ordered, this::createStatefulDoFnRunner); + } + + private void testOutput( + boolean ordered, + BiFunction<MyDoFn, OutputManager, DoFnRunner<KV<String, Integer>, Integer>> runnerFactory) + throws Exception { + + timerInternals.advanceInputWatermark(new Instant(1L)); + + MyDoFn fn = MyDoFn.create(ordered); + StateTag<ValueState<Integer>> stateTag = StateTags.tagForSpec(MyDoFn.STATE_ID, fn.intState()); + + List<KV<TupleTag<?>, WindowedValue<?>>> outputs = new ArrayList<>(); + OutputManager output = asOutputManager(outputs); + DoFnRunner<KV<String, Integer>, Integer> runner = runnerFactory.apply(fn, output); + + Instant elementTime = new Instant(5); + + // write two elements, with descending timestamps + runner.processElement( + WindowedValue.of(KV.of("hello", 1), elementTime, WINDOW_1, PaneInfo.NO_FIRING)); + runner.processElement( + WindowedValue.of(KV.of("hello", 2), elementTime.minus(1), WINDOW_1, PaneInfo.NO_FIRING)); + + if (ordered) { + // move forward in time to so that the input might get flushed + advanceInputWatermark(timerInternals, elementTime.plus(1), runner); + } + + assertEquals(3, (int) stateInternals.state(windowNamespace(WINDOW_1), stateTag).read()); + assertEquals(2, outputs.size()); + if (ordered) { + assertEquals( + Arrays.asList( + KV.of( + outputTag, + WindowedValue.of(2, elementTime.minus(1), WINDOW_1, PaneInfo.NO_FIRING)), + KV.of(outputTag, WindowedValue.of(3, elementTime, WINDOW_1, PaneInfo.NO_FIRING))), + outputs); + } else { + assertEquals( + Arrays.asList( + KV.of(outputTag, WindowedValue.of(1, elementTime, WINDOW_1, PaneInfo.NO_FIRING)), + KV.of( + outputTag, + WindowedValue.of(3, elementTime.minus(1), WINDOW_1, PaneInfo.NO_FIRING))), + outputs); + } + outputs.clear(); + + // another window + elementTime = elementTime.plus(WINDOW_SIZE); + 
runner.processElement( + WindowedValue.of(KV.of("hello", 1), elementTime, WINDOW_2, PaneInfo.NO_FIRING)); + + runner.processElement( + WindowedValue.of(KV.of("hello", 2), elementTime.minus(1), WINDOW_2, PaneInfo.NO_FIRING)); + + runner.processElement( + WindowedValue.of(KV.of("hello", 3), elementTime.minus(2), WINDOW_2, PaneInfo.NO_FIRING)); + + if (ordered) { + // move forward in time so that the input might get flushed + advanceInputWatermark(timerInternals, elementTime.plus(1), runner); + } + + assertEquals(6, (int) stateInternals.state(windowNamespace(WINDOW_2), stateTag).read()); + assertEquals(3, outputs.size()); + if (ordered) { + assertEquals( + Arrays.asList( + KV.of( + outputTag, + WindowedValue.of(3, elementTime.minus(2), WINDOW_2, PaneInfo.NO_FIRING)), + KV.of( + outputTag, + WindowedValue.of(5, elementTime.minus(1), WINDOW_2, PaneInfo.NO_FIRING)), + KV.of(outputTag, WindowedValue.of(6, elementTime, WINDOW_2, PaneInfo.NO_FIRING))), + outputs); + } else { + assertEquals( + Arrays.asList( + KV.of(outputTag, WindowedValue.of(1, elementTime, WINDOW_2, PaneInfo.NO_FIRING)), + KV.of( + outputTag, + WindowedValue.of(3, elementTime.minus(1), WINDOW_2, PaneInfo.NO_FIRING)), + KV.of( + outputTag, + WindowedValue.of(6, elementTime.minus(2), WINDOW_2, PaneInfo.NO_FIRING))), + outputs); + } + } + + private DoFnRunner createStatefulDoFnRunner(DoFn<KV<String, Integer>, Integer> fn) { + return createStatefulDoFnRunner(fn, null); + } + + private DoFnRunner createStatefulDoFnRunner( + DoFn<KV<String, Integer>, Integer> fn, OutputManager outputManager) { + return createStatefulDoFnRunner(fn, outputManager, true); + } + + private DoFnRunner createStatefulDoFnRunner( + DoFn<KV<String, Integer>, Integer> fn, + OutputManager outputManager, + boolean supportTimeSortedInput) { + return DoFnRunners.defaultStatefulDoFnRunner( + fn, + KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()), + getDoFnRunner(fn, outputManager), + mockStepContext, + WINDOWING_STRATEGY, + new 
StatefulDoFnRunner.TimeInternalsCleanupTimer(timerInternals, WINDOWING_STRATEGY), + new StatefulDoFnRunner.StateInternalsStateCleaner<>( + fn, stateInternals, (Coder) WINDOWING_STRATEGY.getWindowFn().windowCoder()), + supportTimeSortedInput); + } + private DoFnRunner<KV<String, Integer>, Integer> getDoFnRunner( DoFn<KV<String, Integer>, Integer> fn) { + return getDoFnRunner(fn, null); + } + + private DoFnRunner<KV<String, Integer>, Integer> getDoFnRunner( + DoFn<KV<String, Integer>, Integer> fn, @Nullable OutputManager outputManager) { return new SimpleDoFnRunner<>( null, fn, NullSideInputReader.empty(), - null, - null, + MoreObjects.firstNonNull(outputManager, discardingOutputManager()), + outputTag, Collections.emptyList(), mockStepContext, null, @@ -211,6 +375,15 @@ public class StatefulDoFnRunnerTest { Collections.emptyMap()); } + private OutputManager discardingOutputManager() { + return new OutputManager() { + @Override + public <T> void output(TupleTag<T> tag, WindowedValue<T> output) { + // discard + } + }; + } + private static void advanceInputWatermark( InMemoryTimerInternals timerInternals, Instant newInputWatermark, DoFnRunner<?, ?> toTrigger) throws Exception { @@ -230,17 +403,67 @@ public class StatefulDoFnRunnerTest { } } - private static class MyDoFn extends DoFn<KV<String, Integer>, Integer> { + private static OutputManager asOutputManager(List<KV<TupleTag<?>, WindowedValue<?>>> outputs) { + return new OutputManager() { + @Override + public <T> void output(TupleTag<T> tag, WindowedValue<T> output) { + outputs.add(KV.of(tag, output)); + } + }; + } + + private abstract static class MyDoFn extends DoFn<KV<String, Integer>, Integer> { + + static final String STATE_ID = "foo"; + + static MyDoFn create(boolean sorted) { + return sorted ? 
new MyDoFnSorted() : new MyDoFnUnsorted(); + } + + abstract StateSpec<ValueState<Integer>> intState(); + + public void processElement(ProcessContext c, ValueState<Integer> state) { + KV<String, Integer> elem = c.element(); + Integer currentValue = MoreObjects.firstNonNull(state.read(), 0); + int updated = currentValue + elem.getValue(); + state.write(updated); + c.output(updated); + } + } - public final String stateId = "foo"; + private static class MyDoFnUnsorted extends MyDoFn { - @StateId(stateId) + @StateId(STATE_ID) public final StateSpec<ValueState<Integer>> intState = StateSpecs.value(VarIntCoder.of()); @ProcessElement - public void processElement(ProcessContext c, @StateId(stateId) ValueState<Integer> state) { - Integer currentValue = MoreObjects.firstNonNull(state.read(), 0); - state.write(currentValue + 1); + @Override + public void processElement(ProcessContext c, @StateId(STATE_ID) ValueState<Integer> state) { + super.processElement(c, state); + } + + @Override + StateSpec<ValueState<Integer>> intState() { + return intState; + } + } + + private static class MyDoFnSorted extends MyDoFn { + + @StateId(STATE_ID) + public final StateSpec<ValueState<Integer>> intState = StateSpecs.value(VarIntCoder.of()); + + @RequiresTimeSortedInput + @ProcessElement + @Override + public void processElement( + ProcessContext c, @StateId(MyDoFn.STATE_ID) ValueState<Integer> state) { + super.processElement(c, state); + } + + @Override + StateSpec<ValueState<Integer>> intState() { + return intState; } } } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java index 5f41175..7986fd1 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java @@ -23,13 +23,14 @@ import java.util.HashMap; import java.util.List; import 
java.util.Map; import java.util.stream.Collectors; -import javax.annotation.Nullable; import org.apache.beam.runners.core.DoFnRunner; import org.apache.beam.runners.core.DoFnRunners; import org.apache.beam.runners.core.DoFnRunners.OutputManager; +import org.apache.beam.runners.core.KeyedWorkItemCoder; import org.apache.beam.runners.core.PushbackSideInputDoFnRunner; import org.apache.beam.runners.core.ReadyCheckingSideInputReader; import org.apache.beam.runners.core.SimplePushbackSideInputDoFnRunner; +import org.apache.beam.runners.core.StatefulDoFnRunner; import org.apache.beam.runners.core.TimerInternals.TimerData; import org.apache.beam.runners.direct.DirectExecutionContext.DirectStepContext; import org.apache.beam.runners.local.StructuralKey; @@ -38,6 +39,7 @@ import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.util.WindowedValue; @@ -59,7 +61,7 @@ class ParDoEvaluator<InputT> implements TransformEvaluator<InputT> { TupleTag<OutputT> mainOutputTag, List<TupleTag<?>> additionalOutputTags, DirectStepContext stepContext, - @Nullable Coder<InputT> inputCoder, + Coder<InputT> inputCoder, Map<TupleTag<?>, Coder<?>> outputCoders, WindowingStrategy<?, ? 
extends BoundedWindow> windowingStrategy, DoFnSchemaInformation doFnSchemaInformation, @@ -75,7 +77,7 @@ class ParDoEvaluator<InputT> implements TransformEvaluator<InputT> { mainOutputTag, additionalOutputTags, stepContext, - schemaCoder, + inputCoder, outputCoders, windowingStrategy, doFnSchemaInformation, @@ -89,11 +91,33 @@ class ParDoEvaluator<InputT> implements TransformEvaluator<InputT> { mainOutputTag, additionalOutputTags, stepContext, - schemaCoder, + inputCoder, outputCoders, windowingStrategy, doFnSchemaInformation, sideInputMapping); + if (DoFnSignatures.signatureForDoFn(fn).usesState()) { + // the coder specified on the input PCollection doesn't match type + // of elements processed by the StatefulDoFnRunner + // that is internal detail of how DirectRunner processes stateful DoFns + @SuppressWarnings("unchecked") + final KeyedWorkItemCoder<?, InputT> keyedWorkItemCoder = + (KeyedWorkItemCoder<?, InputT>) inputCoder; + underlying = + DoFnRunners.defaultStatefulDoFnRunner( + fn, + keyedWorkItemCoder.getElementCoder(), + underlying, + stepContext, + windowingStrategy, + new StatefulDoFnRunner.TimeInternalsCleanupTimer<>( + stepContext.timerInternals(), windowingStrategy), + new StatefulDoFnRunner.StateInternalsStateCleaner<>( + fn, + stepContext.stateInternals(), + windowingStrategy.getWindowFn().windowCoder()), + true); + } return SimplePushbackSideInputDoFnRunner.create(underlying, sideInputs, sideInputReader); }; } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java index f30ee41..c0c05e4 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java @@ -155,6 +155,15 @@ public class ParDoMultiOverrideFactory<InputT, OutputT> @Override public 
PCollectionTuple expand(PCollection<KV<K, InputT>> input) { + PCollection<KeyedWorkItem<K, KV<K, InputT>>> adjustedInput = groupToKeyedWorkItem(input); + + return applyStatefulParDo(adjustedInput); + } + + @VisibleForTesting + PCollection<KeyedWorkItem<K, KV<K, InputT>>> groupToKeyedWorkItem( + PCollection<KV<K, InputT>> input) { + WindowingStrategy<?, ?> inputWindowingStrategy = input.getWindowingStrategy(); // A KvCoder is required since this goes through GBK. Further, WindowedValueCoder @@ -165,42 +174,46 @@ public class ParDoMultiOverrideFactory<InputT, OutputT> ParDo.class.getSimpleName(), KvCoder.class.getSimpleName(), input.getCoder()); + KvCoder<K, InputT> kvCoder = (KvCoder<K, InputT>) input.getCoder(); Coder<K> keyCoder = kvCoder.getKeyCoder(); Coder<? extends BoundedWindow> windowCoder = inputWindowingStrategy.getWindowFn().windowCoder(); - PCollection<KeyedWorkItem<K, KV<K, InputT>>> adjustedInput = - input - // Stash the original timestamps, etc, for when it is fed to the user's DoFn - .apply("Reify timestamps", ParDo.of(new ReifyWindowedValueFn<>())) - .setCoder(KvCoder.of(keyCoder, WindowedValue.getFullCoder(kvCoder, windowCoder))) - - // We are going to GBK to gather keys and windows but otherwise do not want - // to alter the flow of data. 
This entails: - // - trigger as fast as possible - // - maintain the full timestamps of elements - // - ensure this GBK holds to the minimum of those timestamps (via TimestampCombiner) - // - discard past panes as it is "just a stream" of elements - .apply( - Window.<KV<K, WindowedValue<KV<K, InputT>>>>configure() - .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1))) - .discardingFiredPanes() - .withAllowedLateness(inputWindowingStrategy.getAllowedLateness()) - .withTimestampCombiner(TimestampCombiner.EARLIEST)) - - // A full GBK to group by key _and_ window - .apply("Group by key", GroupByKey.create()) - - // Adapt to KeyedWorkItem; that is how this runner delivers timers - .apply("To KeyedWorkItem", ParDo.of(new ToKeyedWorkItem<>())) - .setCoder(KeyedWorkItemCoder.of(keyCoder, kvCoder, windowCoder)) - - // Because of the intervening GBK, we may have abused the windowing strategy - // of the input, which should be transferred to the output in a straightforward manner - // according to what ParDo already does. - .setWindowingStrategyInternal(inputWindowingStrategy); + return input + // Stash the original timestamps, etc, for when it is fed to the user's DoFn + .apply("Reify timestamps", ParDo.of(new ReifyWindowedValueFn<>())) + .setCoder(KvCoder.of(keyCoder, WindowedValue.getFullCoder(kvCoder, windowCoder))) + + // We are going to GBK to gather keys and windows but otherwise do not want + // to alter the flow of data. 
This entails: + // - trigger as fast as possible + // - maintain the full timestamps of elements + // - ensure this GBK holds to the minimum of those timestamps (via TimestampCombiner) + // - discard past panes as it is "just a stream" of elements + .apply( + Window.<KV<K, WindowedValue<KV<K, InputT>>>>configure() + .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1))) + .discardingFiredPanes() + .withAllowedLateness(inputWindowingStrategy.getAllowedLateness()) + .withTimestampCombiner(TimestampCombiner.EARLIEST)) + + // A full GBK to group by key _and_ window + .apply("Group by key", GroupByKey.create()) + + // Adapt to KeyedWorkItem; that is how this runner delivers timers + .apply("To KeyedWorkItem", ParDo.of(new ToKeyedWorkItem<>())) + .setCoder(KeyedWorkItemCoder.of(keyCoder, kvCoder, windowCoder)) + + // Because of the intervening GBK, we may have abused the windowing strategy + // of the input, which should be transferred to the output in a straightforward manner + // according to what ParDo already does. 
+ .setWindowingStrategyInternal(inputWindowingStrategy); + } + @VisibleForTesting + PCollectionTuple applyStatefulParDo( + PCollection<KeyedWorkItem<K, KV<K, InputT>>> adjustedInput) { return adjustedInput // Explode the resulting iterable into elements that are exactly the ones from // the input diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/QuiescenceDriver.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/QuiescenceDriver.java index ca0ad61..cbd2eeb 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/QuiescenceDriver.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/QuiescenceDriver.java @@ -164,7 +164,7 @@ class QuiescenceDriver implements ExecutionDriver { transformTimers.getKey(), (PCollection) Iterables.getOnlyElement( - transformTimers.getExecutable().getInputs().values())) + transformTimers.getExecutable().getMainInputs().values())) .add(WindowedValue.valueInGlobalWindow(work)) .commit(evaluationContext.now()); outstandingWork.incrementAndGet(); diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java index 04d03e8..fd80854 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java @@ -44,6 +44,7 @@ import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo; import org.apache.beam.runners.direct.WatermarkManager.TimerUpdate; import org.apache.beam.runners.direct.WatermarkManager.TransformWatermarks; import org.apache.beam.runners.local.StructuralKey; +import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import 
org.apache.beam.sdk.options.PipelineOptions; @@ -52,13 +53,17 @@ import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.state.StateSpec; import org.apache.beam.sdk.state.StateSpecs; import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.transforms.windowing.Window; @@ -67,8 +72,10 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.joda.time.Duration; @@ -135,25 +142,29 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable { .apply(Window.into(FixedWindows.of(Duration.millis(10)))); TupleTag<Integer> mainOutput = new TupleTag<>(); - PCollection<Integer> produced = - input - .apply( - new ParDoMultiOverrideFactory.GbkThenStatefulParDo<>( - new DoFn<KV<String, Integer>, Integer>() { - 
@StateId(stateId) - private final StateSpec<ValueState<String>> spec = - StateSpecs.value(StringUtf8Coder.of()); + final ParDoMultiOverrideFactory.GbkThenStatefulParDo<String, Integer, Integer> + gbkThenStatefulParDo; + gbkThenStatefulParDo = + new ParDoMultiOverrideFactory.GbkThenStatefulParDo<>( + new DoFn<KV<String, Integer>, Integer>() { + @StateId(stateId) + private final StateSpec<ValueState<String>> spec = + StateSpecs.value(StringUtf8Coder.of()); + + @ProcessElement + public void process(ProcessContext c) {} + }, + mainOutput, + TupleTagList.empty(), + Collections.emptyList(), + DoFnSchemaInformation.create(), + Collections.emptyMap()); + + final PCollection<KeyedWorkItem<String, KV<String, Integer>>> grouped = + gbkThenStatefulParDo.groupToKeyedWorkItem(input); - @ProcessElement - public void process(ProcessContext c) {} - }, - mainOutput, - TupleTagList.empty(), - Collections.emptyList(), - DoFnSchemaInformation.create(), - Collections.emptyMap())) - .get(mainOutput) - .setCoder(VarIntCoder.of()); + PCollection<Integer> produced = + gbkThenStatefulParDo.applyStatefulParDo(grouped).get(mainOutput).setCoder(VarIntCoder.of()); StatefulParDoEvaluatorFactory<String, Integer, Integer> factory = new StatefulParDoEvaluatorFactory<>(mockEvaluationContext, options); @@ -188,15 +199,35 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable { // A single bundle with some elements in the global window; it should register cleanup for the // global window state merely by having the evaluator created. The cleanup logic does not // depend on the window. 
- CommittedBundle<KV<String, Integer>> inputBundle = + CommittedBundle<KeyedWorkItem<String, KV<String, Integer>>> inputBundle = BUNDLE_FACTORY - .createBundle(input) + .createBundle(grouped) .add( WindowedValue.of( - KV.of("hello", 1), new Instant(3), firstWindow, PaneInfo.NO_FIRING)) + KeyedWorkItems.<String, KV<String, Integer>>elementsWorkItem( + "hello", + Collections.singleton( + WindowedValue.of( + KV.of("hello", 1), + new Instant(3), + firstWindow, + PaneInfo.NO_FIRING))), + new Instant(3), + firstWindow, + PaneInfo.NO_FIRING)) .add( WindowedValue.of( - KV.of("hello", 2), new Instant(11), secondWindow, PaneInfo.NO_FIRING)) + KeyedWorkItems.<String, KV<String, Integer>>elementsWorkItem( + "hello", + Collections.singleton( + WindowedValue.of( + KV.of("hello", 2), + new Instant(11), + secondWindow, + PaneInfo.NO_FIRING))), + new Instant(11), + secondWindow, + PaneInfo.NO_FIRING)) .commit(Instant.now()); // Merely creating the evaluator should suffice to register the cleanup callback @@ -340,4 +371,78 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable { } assertThat(pushedBackInts, containsInAnyOrder(1, 13, 15)); } + + @Test + public void testRequiresTimeSortedInput() { + Instant now = Instant.ofEpochMilli(0); + PCollection<KV<String, Integer>> input = + pipeline.apply( + Create.timestamped( + TimestampedValue.of(KV.of("", 1), now.plus(2)), + TimestampedValue.of(KV.of("", 2), now.plus(1)), + TimestampedValue.of(KV.of("", 3), now))); + PCollection<String> result = input.apply(ParDo.of(statefulConcat())); + PAssert.that(result).containsInAnyOrder("3", "3:2", "3:2:1"); + pipeline.run(); + } + + @Test + public void testRequiresTimeSortedInputWithLateData() { + Instant now = Instant.ofEpochMilli(0); + PCollection<KV<String, Integer>> input = + pipeline.apply( + TestStream.create(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of())) + .addElements(TimestampedValue.of(KV.of("", 1), now.plus(2))) + .addElements(TimestampedValue.of(KV.of("", 2), 
now.plus(1))) + .advanceWatermarkTo(now.plus(1)) + .addElements(TimestampedValue.of(KV.of("", 3), now)) + .advanceWatermarkToInfinity()); + PCollection<String> result = input.apply(ParDo.of(statefulConcat())); + PAssert.that(result).containsInAnyOrder("2", "2:1"); + pipeline.run(); + } + + @Test + public void testRequiresTimeSortedInputWithLateDataAndAllowedLateness() { + Instant now = Instant.ofEpochMilli(0); + PCollection<KV<String, Integer>> input = + pipeline + .apply( + TestStream.create(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of())) + .addElements(TimestampedValue.of(KV.of("", 1), now.plus(2))) + .addElements(TimestampedValue.of(KV.of("", 2), now.plus(1))) + .advanceWatermarkTo(now.plus(1)) + .addElements(TimestampedValue.of(KV.of("", 3), now)) + .advanceWatermarkToInfinity()) + .apply( + Window.<KV<String, Integer>>into(new GlobalWindows()) + .withAllowedLateness(Duration.millis(2))); + PCollection<String> result = input.apply(ParDo.of(statefulConcat())); + PAssert.that(result).containsInAnyOrder("3", "3:2", "3:2:1"); + pipeline.run(); + } + + private static DoFn<KV<String, Integer>, String> statefulConcat() { + + final String stateId = "sum"; + + return new DoFn<KV<String, Integer>, String>() { + + @StateId(stateId) + final StateSpec<ValueState<String>> stateSpec = StateSpecs.value(); + + @ProcessElement + @RequiresTimeSortedInput + public void processElement( + ProcessContext context, @StateId(stateId) ValueState<String> state) { + String current = MoreObjects.firstNonNull(state.read(), ""); + if (!current.isEmpty()) { + current += ":"; + } + current += context.element().getValue(); + context.output(current); + state.write(current); + } + }; + } } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java index 84a2e05..90d6086 100644 --- 
a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.flink; import static org.apache.beam.runners.core.construction.ExecutableStageTranslation.generateNameFromStagePayload; +import static org.apache.beam.runners.flink.translation.utils.FlinkPortableRunnerUtils.requiresTimeSortedInput; import static org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.createOutputMap; import static org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.getWindowingStrategy; import static org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.instantiateCoder; @@ -82,6 +83,7 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterable import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.apache.flink.api.common.JobExecutionResult; +import org.apache.flink.api.common.operators.Order; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; @@ -93,6 +95,7 @@ import org.apache.flink.api.java.operators.GroupReduceOperator; import org.apache.flink.api.java.operators.Grouping; import org.apache.flink.api.java.operators.MapPartitionOperator; import org.apache.flink.api.java.operators.SingleInputUdfOperator; +import org.apache.flink.api.java.operators.UnsortedGrouping; /** * A translator that translates bounded portable pipelines into executable Flink pipelines. 
@@ -355,8 +358,16 @@ public class FlinkBatchPortablePipelineTranslator Grouping<WindowedValue<InputT>> groupedInput = inputDataSet.groupBy(new KvKeySelector<>(keyCoder)); + boolean requiresTimeSortedInput = requiresTimeSortedInput(stagePayload, false); + if (requiresTimeSortedInput) { + groupedInput = + ((UnsortedGrouping<WindowedValue<InputT>>) groupedInput) + .sortGroup(WindowedValue::getTimestamp, Order.ASCENDING); + } + taggedDataset = new GroupReduceOperator<>(groupedInput, typeInformation, function, operatorName); + } else { taggedDataset = new MapPartitionOperator<>(inputDataSet, typeInformation, function, operatorName); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java index 28351d5..97ac1b5 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java @@ -61,6 +61,7 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.Reshuffle; import org.apache.beam.sdk.transforms.join.RawUnionValue; import org.apache.beam.sdk.transforms.join.UnionCoder; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; @@ -78,8 +79,10 @@ import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.apache.flink.api.common.functions.RichGroupReduceFunction; +import org.apache.flink.api.common.operators.Order; import org.apache.flink.api.common.typeinfo.TypeInformation; import 
org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.operators.DataSource; import org.apache.flink.api.java.operators.FlatMapOperator; import org.apache.flink.api.java.operators.GroupCombineOperator; @@ -88,6 +91,7 @@ import org.apache.flink.api.java.operators.Grouping; import org.apache.flink.api.java.operators.MapOperator; import org.apache.flink.api.java.operators.MapPartitionOperator; import org.apache.flink.api.java.operators.SingleInputUdfOperator; +import org.joda.time.Instant; /** * Translators for transforming {@link PTransform PTransforms} to Flink {@link DataSet DataSets}. @@ -507,8 +511,9 @@ class FlinkBatchTransformTranslators { } catch (IOException e) { throw new RuntimeException(e); } + DoFnSignature signature = DoFnSignatures.signatureForDoFn(doFn); checkState( - !DoFnSignatures.signatureForDoFn(doFn).processElement().isSplittable(), + !signature.processElement().isSplittable(), "Not expected to directly translate splittable DoFn, should have been overridden: %s", doFn); DataSet<WindowedValue<InputT>> inputDataSet = @@ -613,8 +618,16 @@ class FlinkBatchTransformTranslators { // Based on the fact that the signature is stateful, DoFnSignatures ensures // that it is also keyed. 
- Grouping<WindowedValue<InputT>> grouping = - inputDataSet.groupBy(new KvKeySelector(inputCoder.getKeyCoder())); + Coder<Object> keyCoder = (Coder) inputCoder.getKeyCoder(); + final Grouping<WindowedValue<InputT>> grouping; + if (signature.processElement().requiresTimeSortedInput()) { + grouping = + inputDataSet + .groupBy((KeySelector) new KvKeySelector<>(keyCoder)) + .sortGroup(new KeyWithValueTimestampSelector<>(), Order.ASCENDING); + } else { + grouping = inputDataSet.groupBy((KeySelector) new KvKeySelector<>(keyCoder)); + } outputDataSet = new GroupReduceOperator(grouping, typeInformation, doFnWrapper, fullName); @@ -665,6 +678,15 @@ class FlinkBatchTransformTranslators { } } + private static class KeyWithValueTimestampSelector<K, V> + implements KeySelector<WindowedValue<KV<K, V>>, Instant> { + + @Override + public Instant getKey(WindowedValue<KV<K, V>> in) throws Exception { + return in.getTimestamp(); + } + } + private static class FlattenPCollectionTranslatorBatch<T> implements FlinkBatchPipelineTranslator.BatchTransformTranslator< PTransform<PCollectionList<T>, PCollection<T>>> { diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java index cbc437b..5494749 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java @@ -692,7 +692,6 @@ public class FlinkStreamingPortablePipelineTranslator new ExecutableStageDoFnOperator<>( transform.getUniqueName(), windowedInputCoder, - null, Collections.emptyMap(), mainOutputTag, additionalOutputTags, diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index d98a601..4efdc34 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -453,7 +453,6 @@ class FlinkStreamingTransformTranslators { Map<TupleTag<?>, Coder<WindowedValue<?>>> tagsToCoders, Map<TupleTag<?>, Integer> tagsToIds, Coder<WindowedValue<InputT>> windowedInputCoder, - Coder<InputT> inputCoder, Map<TupleTag<?>, Coder<?>> outputCoders, Coder keyCoder, KeySelector<WindowedValue<InputT>, ?> keySelector, @@ -505,7 +504,6 @@ class FlinkStreamingTransformTranslators { SingleOutputStreamOperator<WindowedValue<OutputT>> outputStream; Coder<WindowedValue<InputT>> windowedInputCoder = context.getWindowedInputCoder(input); - Coder<InputT> inputCoder = context.getInputCoder(input); Map<TupleTag<?>, Coder<?>> outputCoders = context.getOutputCoders(); DataStream<WindowedValue<InputT>> inputDataStream = context.getInputDataStream(input); @@ -546,7 +544,6 @@ class FlinkStreamingTransformTranslators { tagsToCoders, tagsToIds, windowedInputCoder, - inputCoder, outputCoders, keyCoder, keySelector, @@ -574,7 +571,6 @@ class FlinkStreamingTransformTranslators { tagsToCoders, tagsToIds, windowedInputCoder, - inputCoder, outputCoders, keyCoder, keySelector, @@ -694,7 +690,6 @@ class FlinkStreamingTransformTranslators { tagsToCoders, tagsToIds, windowedInputCoder, - inputCoder, outputCoders1, keyCoder, keySelector, @@ -705,7 +700,6 @@ class FlinkStreamingTransformTranslators { doFn1, stepName, windowedInputCoder, - inputCoder, outputCoders1, mainOutputTag1, additionalOutputTags1, @@ -756,7 +750,6 @@ class FlinkStreamingTransformTranslators { tagsToCoders, tagsToIds, windowedInputCoder, - inputCoder, outputCoders1, keyCoder, keySelector, @@ -767,7 +760,6 @@ class FlinkStreamingTransformTranslators { doFn, stepName, 
windowedInputCoder, - inputCoder, outputCoders1, mainOutputTag, additionalOutputTags, diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/FlinkPortableRunnerUtils.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/FlinkPortableRunnerUtils.java new file mode 100644 index 0000000..3e00545 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/FlinkPortableRunnerUtils.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.utils; + +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.runners.core.construction.PTransformTranslation; +import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.InvalidProtocolBufferException; + +/** + * Various utilities related to portability. Helps share code between portable batch and streaming + * translators.
+ */ +public class FlinkPortableRunnerUtils { + + public static boolean requiresTimeSortedInput( + RunnerApi.ExecutableStagePayload payload, boolean streaming) { + + boolean requiresTimeSortedInput = + payload.getComponents().getTransformsMap().values().stream() + .filter(t -> t.getSpec().getUrn().equals(PTransformTranslation.PAR_DO_TRANSFORM_URN)) + .anyMatch( + t -> { + try { + return RunnerApi.ParDoPayload.parseFrom(t.getSpec().getPayload()) + .getRequiresTimeSortedInput(); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + }); + + if (streaming && requiresTimeSortedInput) { + // until https://issues.apache.org/jira/browse/BEAM-8460 is resolved, we must + // throw UnsupportedOperationException here to prevent data loss. + throw new UnsupportedOperationException( + "https://issues.apache.org/jira/browse/BEAM-8460 blocks this feature for now."); + } + + return requiresTimeSortedInput; + } + + /** Do not construct. */ + private FlinkPortableRunnerUtils() {} +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java index fc4896e..37ed3c3 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java @@ -158,8 +158,6 @@ public class DoFnOperator<InputT, OutputT> extends AbstractStreamOperator<Window private final Coder<WindowedValue<InputT>> windowedInputCoder; - private final Coder<InputT> inputCoder; - private final Map<TupleTag<?>, Coder<?>> outputCoders; protected final Coder<?> keyCoder; @@ -220,7 +218,6 @@ public class DoFnOperator<InputT, OutputT> extends AbstractStreamOperator<Window DoFn<InputT, OutputT> doFn, String stepName, Coder<WindowedValue<InputT>> inputWindowedCoder, - 
Coder<InputT> inputCoder, Map<TupleTag<?>, Coder<?>> outputCoders, TupleTag<OutputT> mainOutputTag, List<TupleTag<?>> additionalOutputTags, @@ -236,7 +233,6 @@ public class DoFnOperator<InputT, OutputT> extends AbstractStreamOperator<Window this.doFn = doFn; this.stepName = stepName; this.windowedInputCoder = inputWindowedCoder; - this.inputCoder = inputCoder; this.outputCoders = outputCoders; this.mainOutputTag = mainOutputTag; this.additionalOutputTags = additionalOutputTags; @@ -294,7 +290,7 @@ public class DoFnOperator<InputT, OutputT> extends AbstractStreamOperator<Window // stateful DoFn runner because ProcessFn, which is used for executing a Splittable DoFn // doesn't play by the normal DoFn rules and WindowDoFnOperator uses LateDataDroppingDoFnRunner protected DoFnRunner<InputT, OutputT> createWrappingDoFnRunner( - DoFnRunner<InputT, OutputT> wrappedRunner) { + DoFnRunner<InputT, OutputT> wrappedRunner, StepContext stepContext) { if (keyCoder != null) { StatefulDoFnRunner.CleanupTimer cleanupTimer = @@ -310,7 +306,14 @@ public class DoFnOperator<InputT, OutputT> extends AbstractStreamOperator<Window doFn, keyedStateInternals, windowCoder); return DoFnRunners.defaultStatefulDoFnRunner( - doFn, wrappedRunner, windowingStrategy, cleanupTimer, stateCleaner); + doFn, + getInputCoder(), + wrappedRunner, + stepContext, + windowingStrategy, + cleanupTimer, + stateCleaner, + true /* requiresTimeSortedInput is supported */); } else { return doFnRunner; @@ -375,17 +378,6 @@ public class DoFnOperator<InputT, OutputT> extends AbstractStreamOperator<Window keyedStateInternals = new FlinkStateInternals<>((KeyedStateBackend) getKeyedStateBackend(), keyCoder); - if (doFn != null) { - DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); - FlinkStateInternals.EarlyBinder earlyBinder = - new FlinkStateInternals.EarlyBinder(getKeyedStateBackend()); - for (DoFnSignature.StateDeclaration value : signature.stateDeclarations().values()) { - StateSpec<?> spec = - 
(StateSpec<?>) signature.stateDeclarations().get(value.id()).field().get(doFn); - spec.bind(value.id(), earlyBinder); - } - } - if (timerService == null) { timerService = getInternalTimerService("beam-timer", new CoderTypeSerializer<>(timerCoder), this); @@ -417,6 +409,7 @@ public class DoFnOperator<InputT, OutputT> extends AbstractStreamOperator<Window doFnInvoker.invokeSetup(); FlinkPipelineOptions options = serializedOptions.get().as(FlinkPipelineOptions.class); + StepContext stepContext = new FlinkStepContext(); doFnRunner = DoFnRunners.simpleRunner( options, @@ -425,8 +418,8 @@ public class DoFnOperator<InputT, OutputT> extends AbstractStreamOperator<Window outputManager, mainOutputTag, additionalOutputTags, - new FlinkStepContext(), - inputCoder, + stepContext, + getInputCoder(), outputCoders, windowingStrategy, doFnSchemaInformation, @@ -444,7 +437,8 @@ public class DoFnOperator<InputT, OutputT> extends AbstractStreamOperator<Window getOperatorStateBackend(), getKeyedStateBackend()); } - doFnRunner = createWrappingDoFnRunner(doFnRunner); + doFnRunner = createWrappingDoFnRunner(doFnRunner, stepContext); + earlyBindStateIfNeeded(); if (!options.getDisableMetrics()) { flinkMetricContainer = new FlinkMetricContainer(getRuntimeContext()); @@ -470,6 +464,26 @@ public class DoFnOperator<InputT, OutputT> extends AbstractStreamOperator<Window } } + private void earlyBindStateIfNeeded() throws IllegalArgumentException, IllegalAccessException { + if (keyCoder != null) { + if (doFn != null) { + DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); + FlinkStateInternals.EarlyBinder earlyBinder = + new FlinkStateInternals.EarlyBinder(getKeyedStateBackend()); + for (DoFnSignature.StateDeclaration value : signature.stateDeclarations().values()) { + StateSpec<?> spec = + (StateSpec<?>) signature.stateDeclarations().get(value.id()).field().get(doFn); + spec.bind(value.id(), earlyBinder); + } + if (doFnRunner instanceof StatefulDoFnRunner) { + 
((StatefulDoFnRunner<InputT, OutputT, BoundedWindow>) doFnRunner) + .getSystemStateTags() + .forEach(tag -> tag.getSpec().bind(tag.getId(), earlyBinder)); + } + } + } + } + @Override public void dispose() throws Exception { try { @@ -851,6 +865,11 @@ public class DoFnOperator<InputT, OutputT> extends AbstractStreamOperator<Window this.currentOutputWatermark = currentOutputWatermark; } + @SuppressWarnings("unchecked") + Coder<InputT> getInputCoder() { + return (Coder<InputT>) Iterables.getOnlyElement(windowedInputCoder.getCoderArguments()); + } + /** Factory for creating an {@link BufferedOutputManager} from a Flink {@link Output}. */ interface OutputManagerFactory<OutputT> extends Serializable { BufferedOutputManager<OutputT> create( diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java index 1029eb7..242efc6 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.flink.translation.wrappers.streaming; +import static org.apache.beam.runners.flink.translation.utils.FlinkPortableRunnerUtils.requiresTimeSortedInput; import static org.apache.flink.util.Preconditions.checkNotNull; import java.io.IOException; @@ -53,6 +54,7 @@ import org.apache.beam.runners.core.StateNamespaces; import org.apache.beam.runners.core.StateTag; import org.apache.beam.runners.core.StateTags; import org.apache.beam.runners.core.StatefulDoFnRunner; +import org.apache.beam.runners.core.StepContext; import org.apache.beam.runners.core.TimerInternals; import org.apache.beam.runners.core.construction.Timer; import 
org.apache.beam.runners.core.construction.graph.ExecutableStage; @@ -124,6 +126,8 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I /** A lock which has to be acquired when concurrently accessing state and timers. */ private final ReentrantLock stateBackendLock; + private final boolean isStateful; + private transient ExecutableStageContext stageContext; private transient StateRequestHandler stateRequestHandler; private transient BundleProgressHandler progressHandler; @@ -141,7 +145,6 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I public ExecutableStageDoFnOperator( String stepName, Coder<WindowedValue<InputT>> windowedInputCoder, - Coder<InputT> inputCoder, Map<TupleTag<?>, Coder<?>> outputCoders, TupleTag<OutputT> mainOutputTag, List<TupleTag<?>> additionalOutputTags, @@ -161,7 +164,6 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I new NoOpDoFn(), stepName, windowedInputCoder, - inputCoder, outputCoders, mainOutputTag, additionalOutputTags, @@ -174,6 +176,7 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I keySelector, DoFnSchemaInformation.create(), Collections.emptyMap()); + this.isStateful = payload.getUserStatesCount() > 0 || payload.getTimersCount() > 0; this.payload = payload; this.jobInfo = jobInfo; this.contextFactory = contextFactory; @@ -242,7 +245,7 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I } final StateRequestHandler userStateRequestHandler; - if (executableStage.getUserStates().size() > 0) { + if (!executableStage.getUserStates().isEmpty()) { if (keyedStateInternals == null) { throw new IllegalStateException("Input must be keyed when user state is used"); } @@ -316,8 +319,7 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I @Override public Iterable<V> get(ByteString key, W window) { - try { - stateBackendLock.lock(); + try (Locker locker 
= Locker.locked(stateBackendLock)) { prepareStateBackend(key); StateNamespace namespace = StateNamespaces.window(windowCoder, window); if (LOG.isDebugEnabled()) { @@ -332,15 +334,12 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I stateInternals.state(namespace, StateTags.bag(userStateId, valueCoder)); return bagState.read(); - } finally { - stateBackendLock.unlock(); } } @Override public void append(ByteString key, W window, Iterator<V> values) { - try { - stateBackendLock.lock(); + try (Locker locker = Locker.locked(stateBackendLock)) { prepareStateBackend(key); StateNamespace namespace = StateNamespaces.window(windowCoder, window); if (LOG.isDebugEnabled()) { @@ -356,15 +355,12 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I while (values.hasNext()) { bagState.add(values.next()); } - } finally { - stateBackendLock.unlock(); } } @Override public void clear(ByteString key, W window) { - try { - stateBackendLock.lock(); + try (Locker locker = Locker.locked(stateBackendLock)) { prepareStateBackend(key); StateNamespace namespace = StateNamespaces.window(windowCoder, window); if (LOG.isDebugEnabled()) { @@ -378,8 +374,6 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I BagState<V> bagState = stateInternals.state(namespace, StateTags.bag(userStateId, valueCoder)); bagState.clear(); - } finally { - stateBackendLock.unlock(); } } @@ -449,8 +443,7 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I ByteBuffer encodedKey = (ByteBuffer) keySelector.getKey(timerElement); // We have to synchronize to ensure the state backend is not concurrently accessed by the // state requests - try { - stateBackendLock.lock(); + try (Locker locker = Locker.locked(stateBackendLock)) { getKeyedStateBackend().setCurrentKey(encodedKey); if (timerData.getTimestamp().isAfter(BoundedWindow.TIMESTAMP_MAX_VALUE)) { timerInternals.deleteTimer( @@ -458,8 +451,6 @@ 
public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I } else { timerInternals.setTimer(timerData); } - } finally { - stateBackendLock.unlock(); } } catch (Exception e) { throw new RuntimeException("Couldn't set timer", e); @@ -471,12 +462,9 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I final ByteBuffer encodedKey = timer.getKey(); // We have to synchronize to ensure the state backend is not concurrently accessed by the state // requests - try { - stateBackendLock.lock(); + try (Locker locker = Locker.locked(stateBackendLock)) { getKeyedStateBackend().setCurrentKey(encodedKey); super.fireTimer(timer); - } finally { - stateBackendLock.unlock(); } } @@ -508,9 +496,10 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I @Override protected DoFnRunner<InputT, OutputT> createWrappingDoFnRunner( - DoFnRunner<InputT, OutputT> wrappedRunner) { + DoFnRunner<InputT, OutputT> wrappedRunner, StepContext stepContext) { sdkHarnessRunner = new SdkHarnessDoFnRunner<>( + wrappedRunner.getFn(), executableStage.getInputPCollection().getId(), stageBundleFactory, stateRequestHandler, @@ -521,7 +510,7 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I this::setTimer, () -> FlinkKeyUtils.decodeKey(getCurrentKey(), keyCoder)); - return ensureStateCleanup(sdkHarnessRunner); + return ensureStateDoFnRunner(sdkHarnessRunner, payload, stepContext); } @Override @@ -581,6 +570,7 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I private static class SdkHarnessDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, OutputT> { + private final DoFn<InputT, OutputT> doFn; private final String mainInput; private final LinkedBlockingQueue<KV<String, OutputT>> outputQueue; private final StageBundleFactory stageBundleFactory; @@ -607,6 +597,7 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I private volatile 
FnDataReceiver<WindowedValue<?>> mainInputReceiver; public SdkHarnessDoFnRunner( + DoFn<InputT, OutputT> doFn, String mainInput, StageBundleFactory stageBundleFactory, StateRequestHandler stateRequestHandler, @@ -616,6 +607,8 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I Coder<BoundedWindow> windowCoder, BiConsumer<WindowedValue<InputT>, TimerInternals.TimerData> timerRegistration, Supplier<Object> keyForTimer) { + + this.doFn = doFn; this.mainInput = mainInput; this.stageBundleFactory = stageBundleFactory; this.stateRequestHandler = stateRequestHandler; @@ -762,15 +755,16 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I @Override public DoFn<InputT, OutputT> getFn() { - throw new UnsupportedOperationException(); + return doFn; } } - private DoFnRunner<InputT, OutputT> ensureStateCleanup( - SdkHarnessDoFnRunner<InputT, OutputT> sdkHarnessRunner) { - if (keyCoder == null) { - // There won't be any state to clean up - // (stateful functions have to be keyed) + private DoFnRunner<InputT, OutputT> ensureStateDoFnRunner( + SdkHarnessDoFnRunner<InputT, OutputT> sdkHarnessRunner, + RunnerApi.ExecutableStagePayload payload, + StepContext stepContext) { + + if (!isStateful) { return sdkHarnessRunner; } // Takes care of state cleanup via StatefulDoFnRunner @@ -794,19 +788,33 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I new StateCleaner(userStates, windowCoder, () -> stateBackend.getCurrentKey()); return new StatefulDoFnRunner<InputT, OutputT, BoundedWindow>( - sdkHarnessRunner, windowingStrategy, cleanupTimer, stateCleaner) { + sdkHarnessRunner, + getInputCoder(), + stepContext, + windowingStrategy, + cleanupTimer, + stateCleaner, + requiresTimeSortedInput(payload, true)) { + + @Override + public void processElement(WindowedValue<InputT> input) { + try (Locker locker = Locker.locked(stateBackendLock)) { + @SuppressWarnings({"unchecked", "rawtypes"}) + final 
ByteBuffer key = + FlinkKeyUtils.encodeKey(((KV) input.getValue()).getKey(), (Coder) keyCoder); + getKeyedStateBackend().setCurrentKey(key); + super.processElement(input); + } + } + @Override public void finishBundle() { // Before cleaning up state, first finish bundle for all underlying DoFnRunners super.finishBundle(); // execute cleanup after the bundle is complete if (!stateCleaner.cleanupQueue.isEmpty()) { - try { - stateBackendLock.lock(); - stateCleaner.cleanupState( - keyedStateInternals, (key) -> stateBackend.setCurrentKey(key)); - } finally { - stateBackendLock.unlock(); + try (Locker locker = Locker.locked(stateBackendLock)) { + stateCleaner.cleanupState(keyedStateInternals, stateBackend::setCurrentKey); } } } @@ -839,11 +847,6 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I } @Override - public Instant currentInputWatermarkTime() { - return timerInternals.currentInputWatermarkTime(); - } - - @Override public void setForWindow(InputT input, BoundedWindow window) { Preconditions.checkNotNull(input, "Null input passed to CleanupTimer"); // make sure this fires after any window.maxTimestamp() timers @@ -851,8 +854,7 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I // needs to match the encoding in prepareStateBackend for state request handler final ByteBuffer key = FlinkKeyUtils.encodeKey(((KV) input).getKey(), keyCoder); // Ensure the state backend is not concurrently accessed by the state requests - try { - stateBackendLock.lock(); + try (Locker locker = Locker.locked(stateBackendLock)) { keyedStateBackend.setCurrentKey(key); timerInternals.setTimer( StateNamespaces.window(windowCoder, window), @@ -861,8 +863,6 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I gcTime, window.maxTimestamp(), TimeDomain.EVENT_TIME); - } finally { - stateBackendLock.unlock(); } } @@ -938,4 +938,24 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends 
DoFnOperator<I @ProcessElement public void doNothing(ProcessContext context) {} } + + private static class Locker implements AutoCloseable { + + public static Locker locked(Lock lock) { + Locker locker = new Locker(lock); + lock.lock(); + return locker; + } + + private final Lock lock; + + Locker(Lock lock) { + this.lock = lock; + } + + @Override + public void close() { + lock.unlock(); + } + } } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java index e955c19..4796567 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java @@ -35,6 +35,7 @@ import org.apache.beam.runners.core.OutputWindowedValue; import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems.ProcessFn; import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateInternalsFactory; +import org.apache.beam.runners.core.StepContext; import org.apache.beam.runners.core.TimerInternals; import org.apache.beam.runners.core.TimerInternalsFactory; import org.apache.beam.sdk.coders.Coder; @@ -68,7 +69,6 @@ public class SplittableDoFnOperator<InputT, OutputT, RestrictionT> DoFn<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> doFn, String stepName, Coder<WindowedValue<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>>> windowedInputCoder, - Coder<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> inputCoder, Map<TupleTag<?>, Coder<?>> outputCoders, TupleTag<OutputT> mainOutputTag, List<TupleTag<?>> additionalOutputTags, @@ -83,7 +83,6 @@ public class SplittableDoFnOperator<InputT, OutputT, RestrictionT> doFn, stepName, windowedInputCoder, - inputCoder, outputCoders, 
mainOutputTag, additionalOutputTags, @@ -101,7 +100,8 @@ public class SplittableDoFnOperator<InputT, OutputT, RestrictionT> @Override protected DoFnRunner<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> createWrappingDoFnRunner( - DoFnRunner<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> wrappedRunner) { + DoFnRunner<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> wrappedRunner, + StepContext stepContext) { // don't wrap in anything because we don't need state cleanup because ProcessFn does // all that return wrappedRunner; diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java index 8b4cb24..6101379 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java @@ -31,6 +31,7 @@ import org.apache.beam.runners.core.KeyedWorkItem; import org.apache.beam.runners.core.KeyedWorkItems; import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateInternalsFactory; +import org.apache.beam.runners.core.StepContext; import org.apache.beam.runners.core.SystemReduceFn; import org.apache.beam.runners.core.TimerInternalsFactory; import org.apache.beam.sdk.coders.Coder; @@ -69,7 +70,6 @@ public class WindowDoFnOperator<K, InputT, OutputT> null, stepName, windowedInputCoder, - null, Collections.emptyMap(), mainOutputTag, additionalOutputTags, @@ -88,7 +88,7 @@ public class WindowDoFnOperator<K, InputT, OutputT> @Override protected DoFnRunner<KeyedWorkItem<K, InputT>, KV<K, OutputT>> createWrappingDoFnRunner( - DoFnRunner<KeyedWorkItem<K, InputT>, KV<K, OutputT>> wrappedRunner) { + DoFnRunner<KeyedWorkItem<K, InputT>, KV<K, OutputT>> wrappedRunner, 
StepContext stepContext) { // When the doFn is this, we know it came from WindowDoFnOperator and // InputT = KeyedWorkItem<K, V> // OutputT = KV<K, V> diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java index 48c4ed5..0976406 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java @@ -102,7 +102,6 @@ public class FlinkPipelineOptionsTest { new TestDoFn(), "stepName", coder, - null, Collections.emptyMap(), mainTag, Collections.emptyList(), @@ -129,7 +128,6 @@ public class FlinkPipelineOptionsTest { new TestDoFn(), "stepName", coder, - null, Collections.emptyMap(), mainTag, Collections.emptyList(), diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java index 235a2e3..e4ca6b0 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java @@ -134,7 +134,6 @@ public class DoFnOperatorTest { new IdentityDoFn<>(), "stepName", coder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ -195,7 +194,6 @@ public class DoFnOperatorTest { new MultiOutputDoFn(additionalOutput1, additionalOutput2), "stepName", coder, - null, Collections.emptyMap(), mainOutput, ImmutableList.of(additionalOutput1, additionalOutput2), @@ -311,7 +309,6 @@ public class DoFnOperatorTest { fn, "stepName", inputCoder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ -405,7 +402,6 @@ public class DoFnOperatorTest { fn, "stepName", 
inputCoder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ -514,7 +510,6 @@ public class DoFnOperatorTest { fn, "stepName", coder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ -620,7 +615,6 @@ public class DoFnOperatorTest { new IdentityDoFn<>(), "stepName", coder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ -792,7 +786,6 @@ public class DoFnOperatorTest { new IdentityDoFn<>(), "stepName", coder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ -831,7 +824,6 @@ public class DoFnOperatorTest { new IdentityDoFn<>(), "stepName", coder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ -929,7 +921,6 @@ public class DoFnOperatorTest { new IdentityDoFn<>(), "stepName", coder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ -969,7 +960,6 @@ public class DoFnOperatorTest { new IdentityDoFn<>(), "stepName", coder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ -1145,7 +1135,6 @@ public class DoFnOperatorTest { fn, "stepName", inputCoder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ -1193,7 +1182,6 @@ public class DoFnOperatorTest { doFn, "stepName", windowedValueCoder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ -1243,7 +1231,6 @@ public class DoFnOperatorTest { doFn, "stepName", windowedValueCoder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ -1341,7 +1328,6 @@ public class DoFnOperatorTest { doFn, "stepName", windowedValueCoder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ -1396,7 +1382,6 @@ public class DoFnOperatorTest { doFn, "stepName", windowedValueCoder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ -1461,7 +1446,6 @@ public class DoFnOperatorTest { new IdentityDoFn(), "stepName", windowedValueCoder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ 
-1574,7 +1558,6 @@ public class DoFnOperatorTest { doFn, "stepName", windowedValueCoder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ -1690,7 +1673,6 @@ public class DoFnOperatorTest { doFn, "stepName", windowedValueCoder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ -1804,7 +1786,6 @@ public class DoFnOperatorTest { doFn, "stepName", windowedValueCoder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ -1847,7 +1828,6 @@ public class DoFnOperatorTest { }, "stepName", windowedValueCoder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), @@ -1937,7 +1917,6 @@ public class DoFnOperatorTest { doFn, "stepName", windowedValueCoder, - null, Collections.emptyMap(), outputTag, Collections.emptyList(), diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java index 3c5f44b..29a59d5 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java @@ -813,8 +813,7 @@ public class ExecutableStageDoFnOperatorTest { ExecutableStageDoFnOperator<Integer, Integer> operator = new ExecutableStageDoFnOperator<>( "transform", - null, - null, + WindowedValue.getValueOnlyCoder(VarIntCoder.of()), Collections.emptyMap(), mainOutput, ImmutableList.of(additionalOutput), @@ -860,7 +859,7 @@ public class ExecutableStageDoFnOperatorTest { DoFnOperator.MultiOutputOutputManagerFactory<Integer> outputManagerFactory, WindowingStrategy windowingStrategy, @Nullable Coder keyCoder, - @Nullable Coder windowedInputCoder) { + Coder windowedInputCoder) { FlinkExecutableStageContextFactory contextFactory = 
Mockito.mock(FlinkExecutableStageContextFactory.class); @@ -877,7 +876,6 @@ public class ExecutableStageDoFnOperatorTest { new ExecutableStageDoFnOperator<>( "transform", windowedInputCoder, - null, Collections.emptyMap(), mainOutput, additionalOutputs, diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java index eb42be2..e910e01 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java @@ -263,6 +263,11 @@ public class PrimitiveParDoSingleFactory<InputT, OutputT> } @Override + public boolean isRequiresTimeSortedInput() { + return signature.processElement().requiresTimeSortedInput(); + } + + @Override public String translateRestrictionCoderId(SdkComponents newComponents) { if (signature.processElement().isSplittable()) { Coder<?> restrictionCoder = diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java index 99439a2..6cbf269 100644 --- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java +++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java @@ -95,6 +95,7 @@ public class SamzaDoFnRunners { timerInternals = timerInternalsFactory.timerInternalsForKey(null); } + final StepContext stepContext = createStepContext(stateInternals, timerInternals); final DoFnRunner<InT, FnOutT> underlyingRunner = DoFnRunners.simpleRunner( pipelineOptions, @@ -103,7 +104,7 @@ public class SamzaDoFnRunners { outputManager, mainOutputTag, sideOutputTags, - createStepContext(stateInternals, timerInternals), + 
stepContext, inputCoder, outputCoders, windowingStrategy, @@ -120,7 +121,9 @@ public class SamzaDoFnRunners { final DoFnRunner<InT, FnOutT> statefulDoFnRunner = DoFnRunners.defaultStatefulDoFnRunner( doFn, + inputCoder, doFnRunnerWithMetrics, + stepContext, windowingStrategy, new StatefulDoFnRunner.TimeInternalsCleanupTimer(timerInternals, windowingStrategy), createStateCleaner(doFn, windowingStrategy, keyedInternals.stateInternals())); diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/CoderHelpers.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/CoderHelpers.java index 55d8c3f..45b5d1b 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/CoderHelpers.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/CoderHelpers.java @@ -23,6 +23,7 @@ import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -33,6 +34,7 @@ import org.apache.beam.runners.spark.util.ByteArray; import org.apache.beam.sdk.coders.Coder; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.PairFunction; +import org.joda.time.Instant; import scala.Tuple2; /** Serialization utility class. */ @@ -58,6 +60,29 @@ public final class CoderHelpers { } /** + * Utility method for serializing an object using the specified coder, appending timestamp + * representation. This is useful when sorting by timestamp + * + * @param value Value to serialize. + * @param coder Coder to serialize with. + * @param timestamp timestamp to be bundled into key's ByteArray representation + * @param <T> type of value that is serialized + * @return Byte array representing serialized object. 
+ */ + public static <T> byte[] toByteArrayWithTs(T value, Coder<T> coder, Instant timestamp) { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try { + coder.encode(value, baos); + ByteBuffer buf = ByteBuffer.allocate(8); + buf.asLongBuffer().put(timestamp.getMillis()); + baos.write(buf.array()); + } catch (IOException e) { + throw new IllegalStateException("Error encoding value: " + value, e); + } + return baos.toByteArray(); + } + + /** * Utility method for serializing a Iterable of values using the specified coder. * * @param values Values to serialize. @@ -144,6 +169,28 @@ public final class CoderHelpers { } /** + * A function wrapper for converting a key-value pair to a byte array pair, where the key in + * resulting ByteArray contains (key, timestamp). + * + * @param keyCoder Coder to serialize keys. + * @param valueCoder Coder to serialize values. + * @param timestamp timestamp of the input Tuple2 + * @param <K> The type of the key being serialized. + * @param <V> The type of the value being serialized. + * @return A function that accepts a key-value pair and returns a pair of byte arrays. + */ + public static <K, V> PairFunction<Tuple2<K, V>, ByteArray, byte[]> toByteFunctionWithTs( + final Coder<K> keyCoder, + final Coder<V> valueCoder, + Function<Tuple2<K, V>, Instant> timestamp) { + + return kv -> + new Tuple2<>( + new ByteArray(toByteArrayWithTs(kv._1(), keyCoder, timestamp.call(kv))), + toByteArray(kv._2(), valueCoder)); + } + + /** * A function for converting a byte array pair to a key-value pair. * * @param <K> The type of the key being deserialized. 
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java index 4494b5d..94ed36c 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -21,8 +21,10 @@ import static org.apache.beam.runners.spark.translation.TranslationUtils.canAvoi import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; +import java.util.Arrays; import java.util.Collection; import java.util.HashMap; +import java.util.Iterator; import java.util.Map; import javax.annotation.Nullable; import org.apache.beam.runners.core.SystemReduceFn; @@ -33,6 +35,7 @@ import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.runners.spark.io.SourceRDD; import org.apache.beam.runners.spark.metrics.MetricsAccumulator; import org.apache.beam.runners.spark.metrics.MetricsContainerStepMapAccumulator; +import org.apache.beam.runners.spark.util.ByteArray; import org.apache.beam.runners.spark.util.SideInputBroadcast; import org.apache.beam.runners.spark.util.SparkCompat; import org.apache.beam.sdk.coders.CannotProvideCoderException; @@ -63,14 +66,19 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.FluentIterable; +import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; import org.apache.spark.HashPartitioner; import org.apache.spark.Partitioner; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.storage.StorageLevel; +import scala.Tuple2; /** Supports translation between a Beam transform, and Spark's operations on RDDs. */ public final class TransformTranslator { @@ -395,7 +403,8 @@ public final class TransformTranslator { windowingStrategy.getWindowFn().windowCoder(), (JavaRDD) inRDD, getPartitioner(context), - (MultiDoFnFunction) multiDoFnFunction); + (MultiDoFnFunction) multiDoFnFunction, + signature.processElement().requiresTimeSortedInput()); } else { all = inRDD.mapPartitionsToPair(multiDoFnFunction); } @@ -439,27 +448,159 @@ public final class TransformTranslator { Coder<? 
extends BoundedWindow> windowCoder, JavaRDD<WindowedValue<KV<K, V>>> kvInRDD, Partitioner partitioner, - MultiDoFnFunction<KV<K, V>, OutputT> doFnFunction) { + MultiDoFnFunction<KV<K, V>, OutputT> doFnFunction, + boolean requiresSortedInput) { Coder<K> keyCoder = kvCoder.getKeyCoder(); final WindowedValue.WindowedValueCoder<V> wvCoder = WindowedValue.FullWindowedValueCoder.of(kvCoder.getValueCoder(), windowCoder); - JavaRDD<KV<K, Iterable<WindowedValue<V>>>> groupRDD = - GroupCombineFunctions.groupByKeyOnly(kvInRDD, keyCoder, wvCoder, partitioner); - - return groupRDD - .map( - input -> { - final K key = input.getKey(); - Iterable<WindowedValue<V>> value = input.getValue(); - return FluentIterable.from(value) - .transform( - windowedValue -> - windowedValue.withValue(KV.of(key, windowedValue.getValue()))) - .iterator(); - }) - .flatMapToPair(doFnFunction); + if (!requiresSortedInput) { + return GroupCombineFunctions.groupByKeyOnly(kvInRDD, keyCoder, wvCoder, partitioner) + .map( + input -> { + final K key = input.getKey(); + Iterable<WindowedValue<V>> value = input.getValue(); + return FluentIterable.from(value) + .transform( + windowedValue -> + windowedValue.withValue(KV.of(key, windowedValue.getValue()))) + .iterator(); + }) + .flatMapToPair(doFnFunction); + } + + JavaPairRDD<ByteArray, byte[]> pairRDD = + kvInRDD + .map(new ReifyTimestampsAndWindowsFunction<>()) + .mapToPair(TranslationUtils.toPairFunction()) + .mapToPair( + CoderHelpers.toByteFunctionWithTs(keyCoder, wvCoder, in -> in._2().getTimestamp())); + + JavaPairRDD<ByteArray, byte[]> sorted = + pairRDD.repartitionAndSortWithinPartitions(keyPrefixPartitionerFrom(partitioner)); + + return sorted.mapPartitionsToPair(wrapDoFnFromSortedRDD(doFnFunction, keyCoder, wvCoder)); + } + + private static Partitioner keyPrefixPartitionerFrom(Partitioner partitioner) { + return new Partitioner() { + @Override + public int numPartitions() { + return partitioner.numPartitions(); + } + + @Override + public int 
getPartition(Object o) { + ByteArray b = (ByteArray) o; + return partitioner.getPartition( + new ByteArray(Arrays.copyOfRange(b.getValue(), 0, b.getValue().length - 8))); + } + }; + } + + private static <K, V, OutputT> + PairFlatMapFunction<Iterator<Tuple2<ByteArray, byte[]>>, TupleTag<?>, WindowedValue<?>> + wrapDoFnFromSortedRDD( + MultiDoFnFunction<KV<K, V>, OutputT> doFnFunction, + Coder<K> keyCoder, + Coder<WindowedValue<V>> wvCoder) { + + return (Iterator<Tuple2<ByteArray, byte[]>> in) -> { + Iterator<Iterator<Tuple2<TupleTag<?>, WindowedValue<?>>>> mappedGroups; + mappedGroups = + Iterators.transform( + splitBySameKey(in, keyCoder, wvCoder), + group -> { + try { + return doFnFunction.call(group); + } catch (Exception ex) { + throw new RuntimeException(ex); + } + }); + return flatten(mappedGroups); + }; + } + + @VisibleForTesting + static <T> Iterator<T> flatten(final Iterator<Iterator<T>> toFlatten) { + + return new AbstractIterator<T>() { + + @Nullable Iterator<T> current = null; + + @Override + protected T computeNext() { + while (true) { + if (current == null) { + if (toFlatten.hasNext()) { + current = toFlatten.next(); + } else { + return endOfData(); + } + } + if (current.hasNext()) { + return current.next(); + } + current = null; + } + } + }; + } + + @VisibleForTesting + static <K, V> Iterator<Iterator<WindowedValue<KV<K, V>>>> splitBySameKey( + Iterator<Tuple2<ByteArray, byte[]>> in, Coder<K> keyCoder, Coder<WindowedValue<V>> wvCoder) { + + return new AbstractIterator<Iterator<WindowedValue<KV<K, V>>>>() { + + @Nullable Tuple2<ByteArray, byte[]> read = null; + + @Override + protected Iterator<WindowedValue<KV<K, V>>> computeNext() { + readNext(); + if (read != null) { + byte[] value = read._1().getValue(); + byte[] keyPart = Arrays.copyOfRange(value, 0, value.length - 8); + K key = CoderHelpers.fromByteArray(keyPart, keyCoder); + return createIteratorForKey(keyPart, key); + } + return endOfData(); + } + + private void readNext() { + if (read == null) 
{ + if (in.hasNext()) { + read = in.next(); + } + } + } + + private void consumed() { + read = null; + } + + private Iterator<WindowedValue<KV<K, V>>> createIteratorForKey(byte[] keyPart, K key) { + + return new AbstractIterator<WindowedValue<KV<K, V>>>() { + @Override + protected WindowedValue<KV<K, V>> computeNext() { + readNext(); + if (read != null) { + byte[] value = read._1().getValue(); + byte[] prefix = Arrays.copyOfRange(value, 0, value.length - 8); + if (Arrays.equals(prefix, keyPart)) { + WindowedValue<V> wv = CoderHelpers.fromByteArray(read._2(), wvCoder); + consumed(); + return WindowedValue.of( + KV.of(key, wv.getValue()), wv.getTimestamp(), wv.getWindows(), wv.getPane()); + } + } + return endOfData(); + } + }; + } + }; } private static <T> TransformEvaluator<Read.Bounded<T>> readBounded() { diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TransformTranslatorTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TransformTranslatorTest.java new file mode 100644 index 0000000..b899665 --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TransformTranslatorTest.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.translation; + +import static org.junit.Assert.assertEquals; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.apache.beam.runners.spark.coders.CoderHelpers; +import org.apache.beam.runners.spark.util.ByteArray; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators; +import org.joda.time.Instant; +import org.junit.Test; +import org.spark_project.guava.collect.Iterables; +import scala.Tuple2; + +/** Test suite for {@link TransformTranslator}. */ +public class TransformTranslatorTest { + + @Test + public void testIteratorFlatten() { + List<Integer> first = Arrays.asList(1, 2, 3); + List<Integer> second = Arrays.asList(4, 5, 6); + List<Integer> result = new ArrayList<>(); + Iterators.addAll( + result, + TransformTranslator.flatten(Arrays.asList(first.iterator(), second.iterator()).iterator())); + assertEquals(Arrays.asList(1, 2, 3, 4, 5, 6), result); + } + + @Test + public void testSplitBySameKey() { + VarIntCoder coder = VarIntCoder.of(); + WindowedValue.WindowedValueCoder<Integer> wvCoder = + WindowedValue.FullWindowedValueCoder.of(coder, GlobalWindow.Coder.INSTANCE); + Instant now = Instant.now(); + List<GlobalWindow> window = Arrays.asList(GlobalWindow.INSTANCE); + PaneInfo paneInfo = PaneInfo.NO_FIRING; + List<Tuple2<ByteArray, byte[]>> firstKey = + Arrays.asList( + new Tuple2( + new ByteArray(CoderHelpers.toByteArrayWithTs(1, coder, now)), + CoderHelpers.toByteArray(WindowedValue.of(1, now, window, paneInfo), wvCoder)), + new Tuple2( + new 
ByteArray(CoderHelpers.toByteArrayWithTs(1, coder, now.plus(1))), + CoderHelpers.toByteArray( + WindowedValue.of(2, now.plus(1), window, paneInfo), wvCoder))); + + List<Tuple2<ByteArray, byte[]>> secondKey = + Arrays.asList( + new Tuple2( + new ByteArray(CoderHelpers.toByteArrayWithTs(2, coder, now)), + CoderHelpers.toByteArray(WindowedValue.of(3, now, window, paneInfo), wvCoder)), + new Tuple2( + new ByteArray(CoderHelpers.toByteArrayWithTs(2, coder, now.plus(2))), + CoderHelpers.toByteArray( + WindowedValue.of(4, now.plus(2), window, paneInfo), wvCoder))); + + Iterable<Tuple2<ByteArray, byte[]>> concat = Iterables.concat(firstKey, secondKey); + Iterator<Iterator<WindowedValue<KV<Integer, Integer>>>> keySplit; + keySplit = TransformTranslator.splitBySameKey(concat.iterator(), coder, wvCoder); + + for (int i = 0; i < 2; i++) { + Iterator<WindowedValue<KV<Integer, Integer>>> iter = keySplit.next(); + List<WindowedValue<KV<Integer, Integer>>> list = new ArrayList<>(); + Iterators.addAll(list, iter); + if (i == 0) { + // first key + assertEquals( + Arrays.asList( + WindowedValue.of(KV.of(1, 1), now, window, paneInfo), + WindowedValue.of(KV.of(1, 2), now.plus(1), window, paneInfo)), + list); + } else { + // second key + assertEquals( + Arrays.asList( + WindowedValue.of(KV.of(2, 3), now, window, paneInfo), + WindowedValue.of(KV.of(2, 4), now.plus(2), window, paneInfo)), + list); + } + } + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AppliedPTransform.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AppliedPTransform.java index 79b550e..da7c720 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AppliedPTransform.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AppliedPTransform.java @@ -19,6 +19,7 @@ package org.apache.beam.sdk.runners; import com.google.auto.value.AutoValue; import java.util.Map; +import java.util.stream.Collectors; import org.apache.beam.sdk.Pipeline; import 
org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.transforms.PTransform; @@ -70,4 +71,12 @@ public abstract class AppliedPTransform< public abstract TransformT getTransform(); public abstract Pipeline getPipeline(); + + /** @return map of {@link TupleTag TupleTags} which are not side inputs. */ + public Map<TupleTag<?>, PValue> getMainInputs() { + Map<TupleTag<?>, PValue> sideInputs = getTransform().getAdditionalInputs(); + return getInputs().entrySet().stream() + .filter(e -> !sideInputs.containsKey(e.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesRequiresTimeSortedInput.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesRequiresTimeSortedInput.java new file mode 100644 index 0000000..2b18391 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesRequiresTimeSortedInput.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.testing; + +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; + +/** + * Category tag for validation tests which utilize {@link DoFn.RequiresTimeSortedInput} in stateful + * {@link ParDo}. + */ +public @interface UsesRequiresTimeSortedInput {} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java index 61fb365..39fbe3b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java @@ -48,6 +48,7 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.WindowingStrategy; import org.joda.time.Duration; import org.joda.time.Instant; @@ -691,6 +692,32 @@ public abstract class DoFn<InputT, OutputT> implements Serializable, HasDisplayD public @interface RequiresStableInput {} /** + * <b><i>Experimental - no backwards compatibility guarantees. The exact name or usage of this + * feature may change.</i></b> + * + * <p>Annotation that may be added to a {@link ProcessElement} method to indicate that the runner + * must ensure that the observable contents of the input {@link PCollection} are sorted by time, + * in ascending order. The time ordering is defined by the elements' timestamps; the ordering of + * elements with equal timestamps is not defined. + * + * <p>Note that this annotation makes sense only for stateful {@code ParDo}s, because the outcome + * of stateless functions cannot depend on the ordering. + * + * <p>This annotation respects the <i>allowedLateness</i> defined in the {@link + * WindowingStrategy}. All data is emitted <b>after</b> the input watermark passes the element's
+ * timestamp + allowedLateness. The output watermark is held back, so that the emitted data is not + * emitted as late data. + * + * <p>The ordering requirement implies that all data that arrives later than the allowed lateness + * will have to be dropped. This might change in the future with the introduction of retractions. + */ + @Documented + @Experimental + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.METHOD) + public @interface RequiresTimeSortedInput {} + + /** * Annotation for the method to use to finish processing a batch of elements. The method annotated * with this must satisfy the following constraints: * diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java index 1ea3547..58d1670 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java @@ -762,6 +762,12 @@ public abstract class DoFnSignature { */ public abstract boolean requiresStableInput(); + /** + * Whether this method requires time sorted input, expressed via {@link + * org.apache.beam.sdk.transforms.DoFn.RequiresTimeSortedInput}. + */ + public abstract boolean requiresTimeSortedInput(); + /** Concrete type of the {@link RestrictionTracker} parameter, if present. */ @Nullable public abstract TypeDescriptor<?> trackerT(); @@ -778,6 +784,7 @@ public abstract class DoFnSignature { Method targetMethod, List<Parameter> extraParameters, boolean requiresStableInput, + boolean requiresTimeSortedInput, TypeDescriptor<?> trackerT, @Nullable TypeDescriptor<?
extends BoundedWindow> windowT, boolean hasReturnValue) { @@ -785,6 +792,7 @@ public abstract class DoFnSignature { targetMethod, Collections.unmodifiableList(extraParameters), requiresStableInput, + requiresTimeSortedInput, trackerT, windowT, hasReturnValue); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java index 3c5faf4..590ffcb 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java @@ -976,6 +976,7 @@ public class DoFnSignatures { MethodAnalysisContext methodContext = MethodAnalysisContext.create(); boolean requiresStableInput = m.isAnnotationPresent(DoFn.RequiresStableInput.class); + boolean requiresTimeSortedInput = m.isAnnotationPresent(DoFn.RequiresTimeSortedInput.class); Type[] params = m.getGenericParameterTypes(); @@ -1025,6 +1026,7 @@ public class DoFnSignatures { m, methodContext.getExtraParameters(), requiresStableInput, + requiresTimeSortedInput, trackerT, windowT, DoFn.ProcessContinuation.class.equals(m.getReturnType())); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index 455761f..28f4fa4 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -17,7 +17,6 @@ */ package org.apache.beam.sdk.transforms; -import static junit.framework.TestCase.assertTrue; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasKey; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasType; @@ -37,6 +36,7 @@ import static 
org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import com.fasterxml.jackson.annotation.JsonCreator; @@ -90,6 +90,7 @@ import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.testing.UsesMapState; +import org.apache.beam.sdk.testing.UsesRequiresTimeSortedInput; import org.apache.beam.sdk.testing.UsesSetState; import org.apache.beam.sdk.testing.UsesSideInputs; import org.apache.beam.sdk.testing.UsesSideInputsWithDifferentCoders; @@ -124,6 +125,7 @@ import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; @@ -2367,6 +2369,193 @@ public class ParDoTest implements Serializable { return getClass().hashCode(); } } + + @Test + @Category({ + ValidatesRunner.class, + UsesStatefulParDo.class, + UsesRequiresTimeSortedInput.class, + UsesStrictTimerOrdering.class + }) + public void testRequiresTimeSortedInput() { + // generate list long enough to rule out random shuffle in sorted order + int numElements = 1000; + List<Long> eventStamps = + LongStream.range(0, numElements) + .mapToObj(i -> numElements - i) + .collect(Collectors.toList()); + testTimeSortedInput(numElements, pipeline.apply(Create.of(eventStamps))); + } + + @Test + @Category({ + ValidatesRunner.class, + UsesStatefulParDo.class, + UsesRequiresTimeSortedInput.class, + 
UsesStrictTimerOrdering.class, + UsesTestStream.class + }) + public void testRequiresTimeSortedInputWithTestStream() { + // generate list long enough to rule out random shuffle in sorted order + int numElements = 1000; + List<Long> eventStamps = + LongStream.range(0, numElements) + .mapToObj(i -> numElements - i) + .collect(Collectors.toList()); + TestStream.Builder<Long> stream = TestStream.create(VarLongCoder.of()); + for (Long stamp : eventStamps) { + stream = stream.addElements(stamp); + } + testTimeSortedInput(numElements, pipeline.apply(stream.advanceWatermarkToInfinity())); + } + + @Test + @Category({ + ValidatesRunner.class, + UsesStatefulParDo.class, + UsesRequiresTimeSortedInput.class, + UsesStrictTimerOrdering.class, + UsesTestStream.class + }) + public void testRequiresTimeSortedInputWithLateData() { + // generate list long enough to rule out random shuffle in sorted order + int numElements = 1000; + List<Long> eventStamps = + LongStream.range(0, numElements) + .mapToObj(i -> numElements - i) + .collect(Collectors.toList()); + TestStream.Builder<Long> input = TestStream.create(VarLongCoder.of()); + for (Long stamp : eventStamps) { + input = input.addElements(TimestampedValue.of(stamp, Instant.ofEpochMilli(stamp))); + if (stamp == 100) { + // advance watermark when we have 100 remaining elements + // all the rest are going to be late elements + input = input.advanceWatermarkTo(Instant.ofEpochMilli(stamp)); + } + } + testTimeSortedInput( + numElements - 100, + numElements - 1, + pipeline.apply(input.advanceWatermarkToInfinity()), + // cannot validate exactly which data gets dropped, because that is runner dependent + false); + } + + @Test + @Category({ + ValidatesRunner.class, + UsesStatefulParDo.class, + UsesRequiresTimeSortedInput.class, + UsesStrictTimerOrdering.class, + UsesTestStream.class + }) + public void testTwoRequiresTimeSortedInputWithLateData() { + // generate list long enough to rule out random shuffle in sorted order + int numElements = 
1000; + List<Long> eventStamps = + LongStream.range(0, numElements) + .mapToObj(i -> numElements - i) + .collect(Collectors.toList()); + TestStream.Builder<Long> input = TestStream.create(VarLongCoder.of()); + for (Long stamp : eventStamps) { + input = input.addElements(TimestampedValue.of(stamp, Instant.ofEpochMilli(stamp))); + if (stamp == 100) { + // advance watermark when we have 100 remaining elements + // all the rest are going to be late elements + input = input.advanceWatermarkTo(Instant.ofEpochMilli(stamp)); + } + } + // apply the sorted function for the first time + PCollection<Long> first = + pipeline + .apply(input.advanceWatermarkToInfinity()) + .apply(WithTimestamps.of(e -> Instant.ofEpochMilli(e))) + .apply( + "first.MapElements", + MapElements.into( + TypeDescriptors.kvs(TypeDescriptors.strings(), TypeDescriptors.longs())) + .via(e -> KV.of("", e))) + .apply("first.ParDo", ParDo.of(timeSortedDoFn())) + .apply(MapElements.into(TypeDescriptors.longs()).via(e -> (long) e)); + // apply the test to the already sorted outcome so that we test that we don't lose any + // more data + testTimeSortedInputAlreadyHavingStamps( + numElements - 100, + numElements - 1, + first, + // cannot validate exactly which data gets dropped, because that is runner dependent + false); + } + + private static void testTimeSortedInput(int exactNumExpectedElements, PCollection<Long> input) { + testTimeSortedInput(exactNumExpectedElements, exactNumExpectedElements, input, true); + } + + private static void testTimeSortedInput( + int minNumExpectedElements, + int maxNumExpectedElements, + PCollection<Long> input, + boolean validateContents) { + testTimeSortedInputAlreadyHavingStamps( + minNumExpectedElements, + maxNumExpectedElements, + input.apply(WithTimestamps.of(e -> Instant.ofEpochMilli(e))), + validateContents); + } + + private static void testTimeSortedInputAlreadyHavingStamps( + int minNumExpectedElements, + int maxNumExpectedElements, + PCollection<Long> input, + boolean
validateContents) { + + PCollection<Integer> output = + input + .apply( + "sorted.MapElements", + MapElements.into( + TypeDescriptors.kvs(TypeDescriptors.strings(), TypeDescriptors.longs())) + .via(e -> KV.of("", e))) + .apply("sorted.ParDo", ParDo.of(timeSortedDoFn())); + PAssert.that(output) + .satisfies( + values -> { + // check the output count is within bounds and, optionally, that all outputs are ones + long numElements = StreamSupport.stream(values.spliterator(), false).count(); + assertTrue( + "Expected at least " + minNumExpectedElements + ", got " + numElements, + minNumExpectedElements <= numElements); + assertTrue( + "Expected at most " + maxNumExpectedElements + ", got " + numElements, + maxNumExpectedElements >= numElements); + if (validateContents) { + assertFalse( + "Expected all ones in " + values, + StreamSupport.stream(values.spliterator(), false).anyMatch(e -> e != 1)); + } + return null; + }); + input.getPipeline().run(); + } + + private static DoFn<KV<String, Long>, Integer> timeSortedDoFn() { + return new DoFn<KV<String, Long>, Integer>() { + + @StateId("last") + private final StateSpec<ValueState<Long>> lastSpec = StateSpecs.value(); + + @RequiresTimeSortedInput + @ProcessElement + public void process( + @Element KV<String, Long> element, + @StateId("last") ValueState<Long> last, + OutputReceiver<Integer> output) { + long lastVal = MoreObjects.firstNonNull(last.read(), element.getValue() - 1); + last.write(element.getValue()); + output.output((int) (element.getValue() - lastVal)); + } + }; + } } /** Tests for state coder inference behaviors. */