This is an automated email from the ASF dual-hosted git repository.
scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 8fbad485688 Change FnApiDoFnRunner to skip trySplit checkpoint
requests if not draining and nothing has yet been claimed by the tracker.
(#32044)
8fbad485688 is described below
commit 8fbad48568833d60a5244d00f4b4b943d82bac0b
Author: Sam Whittle <[email protected]>
AuthorDate: Wed Aug 14 15:26:12 2024 +0200
Change FnApiDoFnRunner to skip trySplit checkpoint requests if not draining
and nothing has yet been claimed by the tracker. (#32044)
---
.../apache/beam/fn/harness/FnApiDoFnRunner.java | 57 ++-
.../beam/fn/harness/FnApiDoFnRunnerTest.java | 465 +++++++++++++++++++--
2 files changed, 485 insertions(+), 37 deletions(-)
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index f85622ab89f..c39722c90d8 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -34,6 +34,7 @@ import java.util.List;
import java.util.Map;
import java.util.NavigableSet;
import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Consumer;
@@ -118,6 +119,7 @@ import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.util.Durations;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
@@ -327,6 +329,11 @@ public class FnApiDoFnRunner<InputT, RestrictionT,
PositionT, WatermarkEstimator
* otherwise.
*/
private RestrictionTracker<RestrictionT, PositionT> currentTracker;
+ /**
+ * If non-null, set to true after currentTracker has had a successful tryClaim
on it. Used to ignore
+ * checkpoint split requests if no progress was made.
+ */
+ private @Nullable AtomicBoolean currentTrackerClaimed;
/**
* Only valid during {@link #processTimer} and {@link
#processOnWindowExpiration}, null otherwise.
@@ -877,12 +884,18 @@ public class FnApiDoFnRunner<InputT, RestrictionT,
PositionT, WatermarkEstimator
currentElement = elem.withValue(elem.getValue().getKey());
currentRestriction = elem.getValue().getValue().getKey();
currentWatermarkEstimatorState = elem.getValue().getValue().getValue();
+ currentTrackerClaimed = new AtomicBoolean(false);
currentTracker =
RestrictionTrackers.observe(
doFnInvoker.invokeNewTracker(processContext),
new ClaimObserver<PositionT>() {
+ private final AtomicBoolean claimed =
+ Preconditions.checkNotNull(currentTrackerClaimed);
+
@Override
- public void onClaimed(PositionT position) {}
+ public void onClaimed(PositionT position) {
+ claimed.lazySet(true);
+ }
@Override
public void onClaimFailed(PositionT position) {}
@@ -894,6 +907,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT,
PositionT, WatermarkEstimator
currentRestriction = null;
currentWatermarkEstimatorState = null;
currentTracker = null;
+ currentTrackerClaimed = null;
}
this.stateAccessor.finalizeState();
@@ -909,12 +923,18 @@ public class FnApiDoFnRunner<InputT, RestrictionT,
PositionT, WatermarkEstimator
(Iterator<BoundedWindow>) elem.getWindows().iterator();
while (windowIterator.hasNext()) {
currentWindow = windowIterator.next();
+ currentTrackerClaimed = new AtomicBoolean(false);
currentTracker =
RestrictionTrackers.observe(
doFnInvoker.invokeNewTracker(processContext),
new ClaimObserver<PositionT>() {
+ private final AtomicBoolean claimed =
+ Preconditions.checkNotNull(currentTrackerClaimed);
+
@Override
- public void onClaimed(PositionT position) {}
+ public void onClaimed(PositionT position) {
+ claimed.lazySet(true);
+ }
@Override
public void onClaimFailed(PositionT position) {}
@@ -927,6 +947,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT,
PositionT, WatermarkEstimator
currentWatermarkEstimatorState = null;
currentWindow = null;
currentTracker = null;
+ currentTrackerClaimed = null;
}
this.stateAccessor.finalizeState();
@@ -937,6 +958,8 @@ public class FnApiDoFnRunner<InputT, RestrictionT,
PositionT, WatermarkEstimator
currentElement = elem.withValue(elem.getValue().getKey().getKey());
currentRestriction = elem.getValue().getKey().getValue().getKey();
currentWatermarkEstimatorState =
elem.getValue().getKey().getValue().getValue();
+ // For truncation, we don't set currentTrackerClaimed so that we enable
checkpointing even if no
+ // progress is made.
currentTracker =
RestrictionTrackers.observe(
doFnInvoker.invokeNewTracker(processContext),
@@ -989,6 +1012,8 @@ public class FnApiDoFnRunner<InputT, RestrictionT,
PositionT, WatermarkEstimator
currentRestriction = elem.getValue().getKey().getValue().getKey();
currentWatermarkEstimatorState =
elem.getValue().getKey().getValue().getValue();
currentWindow = currentWindows.get(windowCurrentIndex);
+ // We leave currentTrackerClaimed unset as we want to split regardless
of whether tryClaim is
+ // called.
currentTracker =
RestrictionTrackers.observe(
doFnInvoker.invokeNewTracker(processContext),
@@ -1081,12 +1106,18 @@ public class FnApiDoFnRunner<InputT, RestrictionT,
PositionT, WatermarkEstimator
currentRestriction = elem.getValue().getKey().getValue().getKey();
currentWatermarkEstimatorState =
elem.getValue().getKey().getValue().getValue();
currentWindow = currentWindows.get(windowCurrentIndex);
+ currentTrackerClaimed = new AtomicBoolean(false);
currentTracker =
RestrictionTrackers.observe(
doFnInvoker.invokeNewTracker(processContext),
new ClaimObserver<PositionT>() {
+ private final AtomicBoolean claimed =
+ Preconditions.checkNotNull(currentTrackerClaimed);
+
@Override
- public void onClaimed(PositionT position) {}
+ public void onClaimed(PositionT position) {
+ claimed.lazySet(true);
+ }
@Override
public void onClaimFailed(PositionT position) {}
@@ -1107,7 +1138,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT,
PositionT, WatermarkEstimator
// Attempt to checkpoint the current restriction.
HandlesSplits.SplitResult splitResult =
- trySplitForElementAndRestriction(0, continuation.resumeDelay());
+ trySplitForElementAndRestriction(0, continuation.resumeDelay(),
false);
/**
* After the user has chosen to resume processing later, either the
restriction is already
@@ -1132,7 +1163,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT,
PositionT, WatermarkEstimator
implements HandlesSplits, FnDataReceiver<WindowedValue> {
@Override
public HandlesSplits.SplitResult trySplit(double fractionOfRemainder) {
- return trySplitForElementAndRestriction(fractionOfRemainder,
Duration.ZERO);
+ return trySplitForElementAndRestriction(fractionOfRemainder,
Duration.ZERO, true);
}
@Override
@@ -1278,6 +1309,13 @@ public class FnApiDoFnRunner<InputT, RestrictionT,
PositionT, WatermarkEstimator
if (currentWindow == null) {
return null;
}
+ // We are requesting a checkpoint but have not yet progressed on the
restriction, skip
+ // request.
+ if (fractionOfRemainder == 0
+ && currentTrackerClaimed != null
+ && !currentTrackerClaimed.get()) {
+ return null;
+ }
SplitResultsWithStopIndex splitResult =
computeSplitForProcessOrTruncate(
@@ -1620,7 +1658,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT,
PositionT, WatermarkEstimator
}
private HandlesSplits.SplitResult trySplitForElementAndRestriction(
- double fractionOfRemainder, Duration resumeDelay) {
+ double fractionOfRemainder, Duration resumeDelay, boolean
requireClaimForCheckpoint) {
KV<Instant, WatermarkEstimatorStateT> watermarkAndState;
WindowedSplitResult windowedSplitResult = null;
synchronized (splitLock) {
@@ -1628,6 +1666,13 @@ public class FnApiDoFnRunner<InputT, RestrictionT,
PositionT, WatermarkEstimator
if (currentTracker == null) {
return null;
}
+ // The tracker has not yet been claimed meaning that a checkpoint won't
meaningfully advance.
+ if (fractionOfRemainder == 0
+ && requireClaimForCheckpoint
+ && currentTrackerClaimed != null
+ && !currentTrackerClaimed.get()) {
+ return null;
+ }
// Make sure to get the output watermark before we split to ensure that
the lower bound
// applies to the residual.
watermarkAndState = currentWatermarkEstimator.getWatermarkAndState();
diff --git
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index 11f25ab0116..f4d555dabcc 100644
---
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -24,6 +24,7 @@ import static
org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.allOf;
+import static org.hamcrest.Matchers.anEmptyMap;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.containsString;
@@ -53,6 +54,7 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import org.apache.beam.fn.harness.FnApiDoFnRunner.SplitResultsWithStopIndex;
import org.apache.beam.fn.harness.FnApiDoFnRunner.WindowedSplitResult;
@@ -151,6 +153,7 @@ import org.hamcrest.collection.IsMapContaining;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.joda.time.format.PeriodFormat;
+import org.junit.Assert;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Rule;
@@ -1370,41 +1373,83 @@ public class FnApiDoFnRunnerTest implements
Serializable {
* <li>splitting thread: {@link
*
NonWindowObservingTestSplittableDoFn#waitForSplitElementToBeProcessed()}
* <li>process element thread: {@link
- *
NonWindowObservingTestSplittableDoFn#enableAndWaitForTrySplitToHappen()}
+ * NonWindowObservingTestSplittableDoFn#splitElementProcessed()}
* <li>splitting thread: perform try split
- * <li>splitting thread: {@link
- *
NonWindowObservingTestSplittableDoFn#releaseWaitingProcessElementThread()}
+ * <li>splitting thread: {@link
NonWindowObservingTestSplittableDoFn#trySplitPerformed()}
+ * <li>process element thread: {@link
+ * NonWindowObservingTestSplittableDoFn#waitForTrySplitPerformed()}
* </ul>
*/
static class NonWindowObservingTestSplittableDoFn extends DoFn<String,
String> {
- private static final ConcurrentMap<String, KV<CountDownLatch,
CountDownLatch>>
- DOFN_INSTANCE_TO_LOCK = new ConcurrentHashMap<>();
+ private static final ConcurrentMap<String, Latches>
DOFN_INSTANCE_TO_LATCHES =
+ new ConcurrentHashMap<>();
private static final long SPLIT_ELEMENT = 3;
private static final long CHECKPOINT_UPPER_BOUND = 8;
- private KV<CountDownLatch, CountDownLatch> getLatches() {
- return DOFN_INSTANCE_TO_LOCK.computeIfAbsent(
- this.uuid, (uuid) -> KV.of(new CountDownLatch(1), new
CountDownLatch(1)));
+ static class Latches {
+ public Latches() {}
+
+ CountDownLatch blockProcessLatch = new CountDownLatch(0);
+ CountDownLatch processEnteredLatch = new CountDownLatch(1);
+ CountDownLatch splitElementProcessedLatch = new CountDownLatch(1);
+ CountDownLatch trySplitPerformedLatch = new CountDownLatch(1);
+ AtomicBoolean abortProcessing = new AtomicBoolean();
+ }
+
+ private Latches getLatches() {
+ return DOFN_INSTANCE_TO_LATCHES.computeIfAbsent(this.uuid, (uuid) ->
new Latches());
+ }
+
+ public void splitElementProcessed() {
+ getLatches().splitElementProcessedLatch.countDown();
}
- public void enableAndWaitForTrySplitToHappen() throws Exception {
- KV<CountDownLatch, CountDownLatch> latches = getLatches();
- latches.getKey().countDown();
- if (!latches.getValue().await(30, TimeUnit.SECONDS)) {
+ public void waitForSplitElementToBeProcessed() throws
InterruptedException {
+ if (!getLatches().splitElementProcessedLatch.await(30,
TimeUnit.SECONDS)) {
fail("Failed to wait for trySplit to occur.");
}
}
- public void waitForSplitElementToBeProcessed() throws Exception {
- KV<CountDownLatch, CountDownLatch> latches = getLatches();
- if (!latches.getKey().await(30, TimeUnit.SECONDS)) {
- fail("Failed to wait for split element to be processed.");
+ public void trySplitPerformed() {
+ getLatches().trySplitPerformedLatch.countDown();
+ }
+
+ public void waitForTrySplitPerformed() throws InterruptedException {
+ if (!getLatches().trySplitPerformedLatch.await(30, TimeUnit.SECONDS)) {
+ fail("Failed to wait for trySplit to occur.");
}
}
- public void releaseWaitingProcessElementThread() {
- KV<CountDownLatch, CountDownLatch> latches = getLatches();
- latches.getValue().countDown();
+ // Must be called before process is invoked. Will prevent process from
doing anything until
+ // unblockProcess is
+ // called.
+ public void setupBlockProcess() {
+ getLatches().blockProcessLatch = new CountDownLatch(1);
+ }
+
+ public void enterProcessAndBlockIfEnabled() throws InterruptedException {
+ getLatches().processEnteredLatch.countDown();
+ if (!getLatches().blockProcessLatch.await(30, TimeUnit.SECONDS)) {
+ fail("Failed to wait for unblockProcess to occur.");
+ }
+ }
+
+ public void waitForProcessEntered() throws InterruptedException {
+ if (!getLatches().processEnteredLatch.await(5, TimeUnit.SECONDS)) {
+ fail("Failed to wait for process to begin.");
+ }
+ }
+
+ public void unblockProcess() throws InterruptedException {
+ getLatches().blockProcessLatch.countDown();
+ }
+
+ public void setAbortProcessing() {
+ getLatches().abortProcessing.set(true);
+ }
+
+ public boolean shouldAbortProcessing() {
+ return getLatches().abortProcessing.get();
}
private final String uuid;
@@ -1421,13 +1466,14 @@ public class FnApiDoFnRunnerTest implements
Serializable {
throws Exception {
long checkpointUpperBound = CHECKPOINT_UPPER_BOUND;
long position = tracker.currentRestriction().getFrom();
- boolean claimStatus;
- while (true) {
+ boolean claimStatus = true;
+ while (!shouldAbortProcessing()) {
claimStatus = tracker.tryClaim(position);
if (!claimStatus) {
break;
} else if (position == SPLIT_ELEMENT) {
- enableAndWaitForTrySplitToHappen();
+ splitElementProcessed();
+ waitForTrySplitPerformed();
}
context.outputWithTimestamp(
context.element() + ":" + position,
@@ -1511,15 +1557,17 @@ public class FnApiDoFnRunnerTest implements
Serializable {
RestrictionTracker<OffsetRange, Long> tracker,
ManualWatermarkEstimator<Instant> watermarkEstimator)
throws Exception {
+ enterProcessAndBlockIfEnabled();
long checkpointUpperBound =
Long.parseLong(context.sideInput(singletonSideInput));
long position = tracker.currentRestriction().getFrom();
- boolean claimStatus;
- while (true) {
+ boolean claimStatus = true;
+ while (!shouldAbortProcessing()) {
claimStatus = tracker.tryClaim(position);
if (!claimStatus) {
break;
} else if (position ==
NonWindowObservingTestSplittableDoFn.SPLIT_ELEMENT) {
- enableAndWaitForTrySplitToHappen();
+ splitElementProcessed();
+ waitForTrySplitPerformed();
}
context.outputWithTimestamp(
context.element() + ":" + position,
@@ -1549,7 +1597,8 @@ public class FnApiDoFnRunnerTest implements Serializable {
throws Exception {
// Waiting for split when we are on the second window.
if (splitAtTruncate && processedWindowCount == PROCESSED_WINDOW) {
- enableAndWaitForTrySplitToHappen();
+ splitElementProcessed();
+ waitForTrySplitPerformed();
}
processedWindowCount += 1;
return TruncateResult.of(new OffsetRange(range.getFrom(),
range.getTo() / 2));
@@ -1755,7 +1804,217 @@ public class FnApiDoFnRunnerTest implements
Serializable {
return ((HandlesSplits) mainInput).trySplit(0);
} finally {
- doFn.releaseWaitingProcessElementThread();
+ doFn.trySplitPerformed();
+ }
+ });
+
+ // Check that before processing an element we don't report progress
+ assertNoReportedProgress(context.getBundleProgressReporters());
+ mainInput.accept(
+ valueInGlobalWindow(
+ KV.of(
+ KV.of("7", KV.of(new OffsetRange(0, 5),
GlobalWindow.TIMESTAMP_MIN_VALUE)),
+ 2.0)));
+ HandlesSplits.SplitResult trySplitResult = trySplitFuture.get();
+
+ // Check that after processing an element we don't report progress
+ assertNoReportedProgress(context.getBundleProgressReporters());
+
+ // Since the SPLIT_ELEMENT is 3 we will process 0, 1, 2, 3 then be
split.
+ // We expect that the watermark advances to MIN + 2 since the manual
watermark estimator
+ // has yet to be invoked for the split element and that the primary
represents [0, 4) with
+ // the original watermark while the residual represents [4, 5) with
the new MIN + 2
+ // watermark.
+ assertThat(
+ mainOutputValues,
+ contains(
+ timestampedValueInGlobalWindow(
+ "7:0",
GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(0))),
+ timestampedValueInGlobalWindow(
+ "7:1",
GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(1))),
+ timestampedValueInGlobalWindow(
+ "7:2",
GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(2))),
+ timestampedValueInGlobalWindow(
+ "7:3",
GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(3)))));
+
+ BundleApplication primaryRoot =
Iterables.getOnlyElement(trySplitResult.getPrimaryRoots());
+ DelayedBundleApplication residualRoot =
+ Iterables.getOnlyElement(trySplitResult.getResidualRoots());
+ assertEquals(ParDoTranslation.getMainInputName(pTransform),
primaryRoot.getInputId());
+ assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
+ assertEquals(
+ ParDoTranslation.getMainInputName(pTransform),
+ residualRoot.getApplication().getInputId());
+ assertEquals(TEST_TRANSFORM_ID,
residualRoot.getApplication().getTransformId());
+ assertEquals(
+ valueInGlobalWindow(
+ KV.of(
+ KV.of("7", KV.of(new OffsetRange(0, 4),
GlobalWindow.TIMESTAMP_MIN_VALUE)),
+ 4.0)),
+ inputCoder.decode(primaryRoot.getElement().newInput()));
+ assertEquals(
+ valueInGlobalWindow(
+ KV.of(
+ KV.of(
+ "7",
+ KV.of(
+ new OffsetRange(4, 5),
+
GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(2)))),
+ 1.0)),
+
inputCoder.decode(residualRoot.getApplication().getElement().newInput()));
+ Instant expectedOutputWatermark =
GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(2));
+ assertEquals(
+ ImmutableMap.of(
+ "output",
+
org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.Timestamp.newBuilder()
+ .setSeconds(expectedOutputWatermark.getMillis() / 1000)
+ .setNanos((int) (expectedOutputWatermark.getMillis() %
1000) * 1000000)
+ .build()),
+ residualRoot.getApplication().getOutputWatermarksMap());
+ // We expect 0 resume delay.
+ assertEquals(
+ residualRoot.getRequestedTimeDelay().getDefaultInstanceForType(),
+ residualRoot.getRequestedTimeDelay());
+ // We don't expect the outputs to go to the SDK-initiated checkpointing
listener.
+ assertTrue(splitListener.getPrimaryRoots().isEmpty());
+ assertTrue(splitListener.getResidualRoots().isEmpty());
+ mainOutputValues.clear();
+ executorService.shutdown();
+ }
+
+ Iterables.getOnlyElement(context.getFinishBundleFunctions()).run();
+ assertThat(mainOutputValues, empty());
+
+ Iterables.getOnlyElement(context.getTearDownFunctions()).run();
+ assertThat(mainOutputValues, empty());
+
+ // Assert that state data did not change
+ assertEquals(
+ new FakeBeamFnStateClient(StringUtf8Coder.of(), stateData).getData(),
+ fakeClient.getData());
+ }
+
+ @Test
+ public void
testProcessElementForSizedElementAndRestrictionSplitBeforeTryClaim()
+ throws Exception {
+ Pipeline p = Pipeline.create();
+ addExperiment(p.getOptions().as(ExperimentalOptions.class),
"beam_fn_api");
+ // TODO(BEAM-10097): Remove experiment once all portable runners support
this view type
+ addExperiment(p.getOptions().as(ExperimentalOptions.class),
"use_runner_v2");
+ PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+ PCollectionView<String> singletonSideInputView =
valuePCollection.apply(View.asSingleton());
+ WindowObservingTestSplittableDoFn doFn =
+ new WindowObservingTestSplittableDoFn(singletonSideInputView);
+ valuePCollection.apply(
+ TEST_TRANSFORM_ID,
ParDo.of(doFn).withSideInputs(singletonSideInputView));
+
+ RunnerApi.Pipeline pProto =
+ ProtoOverrides.updateTransform(
+ PTransformTranslation.PAR_DO_TRANSFORM_URN,
+ PipelineTranslation.toProto(p,
SdkComponents.create(p.getOptions()), true),
+ SplittableParDoExpander.createSizedReplacement());
+ String expandedTransformId =
+ Iterables.find(
+ pProto.getComponents().getTransformsMap().entrySet(),
+ entry ->
+ entry
+ .getValue()
+ .getSpec()
+ .getUrn()
+ .equals(
+ PTransformTranslation
+
.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)
+ &&
entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+ .getKey();
+ RunnerApi.PTransform pTransform =
+ pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+ String inputPCollectionId =
+
pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+ RunnerApi.PCollection inputPCollection =
+ pProto.getComponents().getPcollectionsOrThrow(inputPCollectionId);
+ RehydratedComponents rehydratedComponents =
+ RehydratedComponents.forComponents(pProto.getComponents());
+ Coder<WindowedValue> inputCoder =
+ WindowedValue.getFullCoder(
+ CoderTranslation.fromProto(
+
pProto.getComponents().getCodersOrThrow(inputPCollection.getCoderId()),
+ rehydratedComponents,
+ TranslationContext.DEFAULT),
+ (Coder)
+ CoderTranslation.fromProto(
+ pProto
+ .getComponents()
+ .getCodersOrThrow(
+ pProto
+ .getComponents()
+ .getWindowingStrategiesOrThrow(
+
inputPCollection.getWindowingStrategyId())
+ .getWindowCoderId()),
+ rehydratedComponents,
+ TranslationContext.DEFAULT));
+ String outputPCollectionId = pTransform.getOutputsOrThrow("output");
+
+ ImmutableMap<StateKey, List<String>> stateData =
+ ImmutableMap.of(
+ iterableSideInputKey(
+ singletonSideInputView.getTagInternal().getId(),
ByteString.EMPTY),
+ asList("8"));
+
+ FakeBeamFnStateClient fakeClient = new
FakeBeamFnStateClient(StringUtf8Coder.of(), stateData);
+
+ BundleSplitListener.InMemory splitListener =
BundleSplitListener.InMemory.create();
+
+ PTransformRunnerFactoryTestContext context =
+ PTransformRunnerFactoryTestContext.builder(TEST_TRANSFORM_ID,
pTransform)
+ .beamFnStateClient(fakeClient)
+ .processBundleInstructionId("57")
+
.pCollections(pProto.getComponentsOrBuilder().getPcollectionsMap())
+ .coders(pProto.getComponents().getCodersMap())
+
.windowingStrategies(pProto.getComponents().getWindowingStrategiesMap())
+ .splitListener(splitListener)
+ .build();
+ List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+ context.addPCollectionConsumer(
+ outputPCollectionId,
+ (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>)
mainOutputValues::add);
+
+ new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(context);
+
+ Iterables.getOnlyElement(context.getStartBundleFunctions()).run();
+ mainOutputValues.clear();
+
+ assertThat(
+ context.getPCollectionConsumers().keySet(),
+ containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+ FnDataReceiver<WindowedValue<?>> mainInput =
+ context.getPCollectionConsumer(inputPCollectionId);
+ assertThat(mainInput, instanceOf(HandlesSplits.class));
+
+ doFn.setupBlockProcess();
+ {
+ // Setup and launch the trySplit thread.
+ ExecutorService executorService = Executors.newSingleThreadExecutor();
+ Future<HandlesSplits.SplitResult> trySplitFuture =
+ executorService.submit(
+ () -> {
+ try {
+ // Verify that a split before anything is claimed is
ignored.
+ doFn.waitForProcessEntered();
+ Assert.assertNull(((HandlesSplits) mainInput).trySplit(0));
+ doFn.unblockProcess();
+
+ doFn.waitForSplitElementToBeProcessed();
+ // Currently processing "3" out of range [0, 5) elements.
+ assertEquals(0.6, ((HandlesSplits)
mainInput).getProgress(), 0.01);
+
+ // Check that during progressing of an element we report
progress
+ assertReportedProgressEquals(
+ context.getShortIdMap(),
context.getBundleProgressReporters(), 3.0, 2.0);
+
+ return ((HandlesSplits) mainInput).trySplit(0);
+ } finally {
+ doFn.trySplitPerformed();
}
});
@@ -1845,6 +2104,149 @@ public class FnApiDoFnRunnerTest implements
Serializable {
fakeClient.getData());
}
+ @Test
+ public void testProcessElementForSizedElementAndRestrictionNoTryClaim()
throws Exception {
+ Pipeline p = Pipeline.create();
+ addExperiment(p.getOptions().as(ExperimentalOptions.class),
"beam_fn_api");
+ // TODO(BEAM-10097): Remove experiment once all portable runners support
this view type
+ addExperiment(p.getOptions().as(ExperimentalOptions.class),
"use_runner_v2");
+ PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+ PCollectionView<String> singletonSideInputView =
valuePCollection.apply(View.asSingleton());
+ WindowObservingTestSplittableDoFn doFn =
+ new WindowObservingTestSplittableDoFn(singletonSideInputView);
+ doFn.setAbortProcessing();
+ valuePCollection.apply(
+ TEST_TRANSFORM_ID,
ParDo.of(doFn).withSideInputs(singletonSideInputView));
+
+ RunnerApi.Pipeline pProto =
+ ProtoOverrides.updateTransform(
+ PTransformTranslation.PAR_DO_TRANSFORM_URN,
+ PipelineTranslation.toProto(p,
SdkComponents.create(p.getOptions()), true),
+ SplittableParDoExpander.createSizedReplacement());
+ String expandedTransformId =
+ Iterables.find(
+ pProto.getComponents().getTransformsMap().entrySet(),
+ entry ->
+ entry
+ .getValue()
+ .getSpec()
+ .getUrn()
+ .equals(
+ PTransformTranslation
+
.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)
+ &&
entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+ .getKey();
+ RunnerApi.PTransform pTransform =
+ pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+ String inputPCollectionId =
+
pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+ RunnerApi.PCollection inputPCollection =
+ pProto.getComponents().getPcollectionsOrThrow(inputPCollectionId);
+ RehydratedComponents rehydratedComponents =
+ RehydratedComponents.forComponents(pProto.getComponents());
+ Coder<WindowedValue> inputCoder =
+ WindowedValue.getFullCoder(
+ CoderTranslation.fromProto(
+
pProto.getComponents().getCodersOrThrow(inputPCollection.getCoderId()),
+ rehydratedComponents,
+ TranslationContext.DEFAULT),
+ (Coder)
+ CoderTranslation.fromProto(
+ pProto
+ .getComponents()
+ .getCodersOrThrow(
+ pProto
+ .getComponents()
+ .getWindowingStrategiesOrThrow(
+
inputPCollection.getWindowingStrategyId())
+ .getWindowCoderId()),
+ rehydratedComponents,
+ TranslationContext.DEFAULT));
+ String outputPCollectionId = pTransform.getOutputsOrThrow("output");
+
+ ImmutableMap<StateKey, List<String>> stateData =
+ ImmutableMap.of(
+ iterableSideInputKey(
+ singletonSideInputView.getTagInternal().getId(),
ByteString.EMPTY),
+ asList("8"));
+
+ FakeBeamFnStateClient fakeClient = new
FakeBeamFnStateClient(StringUtf8Coder.of(), stateData);
+
+ BundleSplitListener.InMemory splitListener =
BundleSplitListener.InMemory.create();
+
+ PTransformRunnerFactoryTestContext context =
+ PTransformRunnerFactoryTestContext.builder(TEST_TRANSFORM_ID,
pTransform)
+ .beamFnStateClient(fakeClient)
+ .processBundleInstructionId("57")
+
.pCollections(pProto.getComponentsOrBuilder().getPcollectionsMap())
+ .coders(pProto.getComponents().getCodersMap())
+
.windowingStrategies(pProto.getComponents().getWindowingStrategiesMap())
+ .splitListener(splitListener)
+ .build();
+ List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+ context.addPCollectionConsumer(
+ outputPCollectionId,
+ (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>)
mainOutputValues::add);
+
+ new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(context);
+
+ Iterables.getOnlyElement(context.getStartBundleFunctions()).run();
+ mainOutputValues.clear();
+
+ assertThat(
+ context.getPCollectionConsumers().keySet(),
+ containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+ FnDataReceiver<WindowedValue<?>> mainInput =
+ context.getPCollectionConsumer(inputPCollectionId);
+ assertThat(mainInput, instanceOf(HandlesSplits.class));
+
+ {
+ // Check that before processing an element we don't report progress
+ assertNoReportedProgress(context.getBundleProgressReporters());
+ mainInput.accept(
+ valueInGlobalWindow(
+ KV.of(
+ KV.of("5", KV.of(new OffsetRange(5, 10),
GlobalWindow.TIMESTAMP_MIN_VALUE)),
+ 5.0)));
+ // Check that after processing an element we don't report progress
+ assertNoReportedProgress(context.getBundleProgressReporters());
+
+ // Since we set abort processing above, we expect the input
restriction to be output with a
+ // resume
+ // delay.
+ BundleApplication primaryRoot =
Iterables.getOnlyElement(splitListener.getPrimaryRoots());
+ DelayedBundleApplication residualRoot =
+ Iterables.getOnlyElement(splitListener.getResidualRoots());
+ assertEquals(ParDoTranslation.getMainInputName(pTransform),
primaryRoot.getInputId());
+ assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
+ assertEquals(
+ ParDoTranslation.getMainInputName(pTransform),
+ residualRoot.getApplication().getInputId());
+ assertEquals(TEST_TRANSFORM_ID,
residualRoot.getApplication().getTransformId());
+ assertEquals(
+ valueInGlobalWindow(
+ KV.of(
+ KV.of("5", KV.of(new OffsetRange(5, 5),
GlobalWindow.TIMESTAMP_MIN_VALUE)),
+ 0.0)),
+ inputCoder.decode(primaryRoot.getElement().newInput()));
+ assertEquals(
+ valueInGlobalWindow(
+ KV.of(
+ KV.of("5", KV.of(new OffsetRange(5, 10),
GlobalWindow.TIMESTAMP_MIN_VALUE)),
+ 5.0)),
+
inputCoder.decode(residualRoot.getApplication().getElement().newInput()));
+ assertThat(residualRoot.getApplication().getOutputWatermarksMap(),
anEmptyMap());
+ assertEquals(
+
org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.Duration.newBuilder()
+ .setSeconds(54)
+ .setNanos(321000000)
+ .build(),
+ residualRoot.getRequestedTimeDelay());
+ splitListener.clear();
+ }
+ }
+
private static final MonitoringInfo WORK_COMPLETED_MI =
MonitoringInfo.newBuilder()
.setUrn(MonitoringInfoConstants.Urns.WORK_COMPLETED)
@@ -2187,7 +2589,7 @@ public class FnApiDoFnRunnerTest implements Serializable {
return ((HandlesSplits) mainInput).trySplit(0);
} finally {
- doFn.releaseWaitingProcessElementThread();
+ doFn.trySplitPerformed();
}
});
@@ -3143,10 +3545,11 @@ public class FnApiDoFnRunnerTest implements
Serializable {
() -> {
try {
doFn.waitForSplitElementToBeProcessed();
-
- return ((HandlesSplits) mainInput).trySplit(0);
+ HandlesSplits.SplitResult result = ((HandlesSplits)
mainInput).trySplit(0);
+ Assert.assertNotNull(result);
+ return result;
} finally {
- doFn.releaseWaitingProcessElementThread();
+ doFn.trySplitPerformed();
}
});