This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 8fbad485688 Change FnApiDoFnRunner to skip trySplit checkpoint 
requests if not draining and nothing has yet been claimed by the tracker. 
(#32044)
8fbad485688 is described below

commit 8fbad48568833d60a5244d00f4b4b943d82bac0b
Author: Sam Whittle <[email protected]>
AuthorDate: Wed Aug 14 15:26:12 2024 +0200

    Change FnApiDoFnRunner to skip trySplit checkpoint requests if not draining 
and nothing has yet been claimed by the tracker. (#32044)
---
 .../apache/beam/fn/harness/FnApiDoFnRunner.java    |  57 ++-
 .../beam/fn/harness/FnApiDoFnRunnerTest.java       | 465 +++++++++++++++++++--
 2 files changed, 485 insertions(+), 37 deletions(-)

diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index f85622ab89f..c39722c90d8 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -34,6 +34,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.NavigableSet;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.BiConsumer;
 import java.util.function.BiFunction;
 import java.util.function.Consumer;
@@ -118,6 +119,7 @@ import org.apache.beam.sdk.values.WindowingStrategy;
 import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.util.Durations;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
@@ -327,6 +329,11 @@ public class FnApiDoFnRunner<InputT, RestrictionT, 
PositionT, WatermarkEstimator
    * otherwise.
    */
   private RestrictionTracker<RestrictionT, PositionT> currentTracker;
+  /**
+   * If non-null, set to true after a tryClaim on currentTracker has 
succeeded. Used to ignore
+   * checkpoint split requests if no progress was made.
+   */
+  private @Nullable AtomicBoolean currentTrackerClaimed;
 
   /**
    * Only valid during {@link #processTimer} and {@link 
#processOnWindowExpiration}, null otherwise.
@@ -877,12 +884,18 @@ public class FnApiDoFnRunner<InputT, RestrictionT, 
PositionT, WatermarkEstimator
     currentElement = elem.withValue(elem.getValue().getKey());
     currentRestriction = elem.getValue().getValue().getKey();
     currentWatermarkEstimatorState = elem.getValue().getValue().getValue();
+    currentTrackerClaimed = new AtomicBoolean(false);
     currentTracker =
         RestrictionTrackers.observe(
             doFnInvoker.invokeNewTracker(processContext),
             new ClaimObserver<PositionT>() {
+              private final AtomicBoolean claimed =
+                  Preconditions.checkNotNull(currentTrackerClaimed);
+
               @Override
-              public void onClaimed(PositionT position) {}
+              public void onClaimed(PositionT position) {
+                claimed.lazySet(true);
+              }
 
               @Override
               public void onClaimFailed(PositionT position) {}
@@ -894,6 +907,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, 
PositionT, WatermarkEstimator
       currentRestriction = null;
       currentWatermarkEstimatorState = null;
       currentTracker = null;
+      currentTrackerClaimed = null;
     }
 
     this.stateAccessor.finalizeState();
@@ -909,12 +923,18 @@ public class FnApiDoFnRunner<InputT, RestrictionT, 
PositionT, WatermarkEstimator
           (Iterator<BoundedWindow>) elem.getWindows().iterator();
       while (windowIterator.hasNext()) {
         currentWindow = windowIterator.next();
+        currentTrackerClaimed = new AtomicBoolean(false);
         currentTracker =
             RestrictionTrackers.observe(
                 doFnInvoker.invokeNewTracker(processContext),
                 new ClaimObserver<PositionT>() {
+                  private final AtomicBoolean claimed =
+                      Preconditions.checkNotNull(currentTrackerClaimed);
+
                   @Override
-                  public void onClaimed(PositionT position) {}
+                  public void onClaimed(PositionT position) {
+                    claimed.lazySet(true);
+                  }
 
                   @Override
                   public void onClaimFailed(PositionT position) {}
@@ -927,6 +947,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, 
PositionT, WatermarkEstimator
       currentWatermarkEstimatorState = null;
       currentWindow = null;
       currentTracker = null;
+      currentTrackerClaimed = null;
     }
 
     this.stateAccessor.finalizeState();
@@ -937,6 +958,8 @@ public class FnApiDoFnRunner<InputT, RestrictionT, 
PositionT, WatermarkEstimator
     currentElement = elem.withValue(elem.getValue().getKey().getKey());
     currentRestriction = elem.getValue().getKey().getValue().getKey();
     currentWatermarkEstimatorState = 
elem.getValue().getKey().getValue().getValue();
+    // For truncation, we don't set currentTrackerClaimed so that we enable 
checkpointing even if no
+    // progress is made.
     currentTracker =
         RestrictionTrackers.observe(
             doFnInvoker.invokeNewTracker(processContext),
@@ -989,6 +1012,8 @@ public class FnApiDoFnRunner<InputT, RestrictionT, 
PositionT, WatermarkEstimator
         currentRestriction = elem.getValue().getKey().getValue().getKey();
         currentWatermarkEstimatorState = 
elem.getValue().getKey().getValue().getValue();
         currentWindow = currentWindows.get(windowCurrentIndex);
+        // We leave currentTrackerClaimed unset as we want to split regardless 
of whether tryClaim is
+        // called.
         currentTracker =
             RestrictionTrackers.observe(
                 doFnInvoker.invokeNewTracker(processContext),
@@ -1081,12 +1106,18 @@ public class FnApiDoFnRunner<InputT, RestrictionT, 
PositionT, WatermarkEstimator
         currentRestriction = elem.getValue().getKey().getValue().getKey();
         currentWatermarkEstimatorState = 
elem.getValue().getKey().getValue().getValue();
         currentWindow = currentWindows.get(windowCurrentIndex);
+        currentTrackerClaimed = new AtomicBoolean(false);
         currentTracker =
             RestrictionTrackers.observe(
                 doFnInvoker.invokeNewTracker(processContext),
                 new ClaimObserver<PositionT>() {
+                  private final AtomicBoolean claimed =
+                      Preconditions.checkNotNull(currentTrackerClaimed);
+
                   @Override
-                  public void onClaimed(PositionT position) {}
+                  public void onClaimed(PositionT position) {
+                    claimed.lazySet(true);
+                  }
 
                   @Override
                   public void onClaimFailed(PositionT position) {}
@@ -1107,7 +1138,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, 
PositionT, WatermarkEstimator
 
       // Attempt to checkpoint the current restriction.
       HandlesSplits.SplitResult splitResult =
-          trySplitForElementAndRestriction(0, continuation.resumeDelay());
+          trySplitForElementAndRestriction(0, continuation.resumeDelay(), 
false);
 
       /**
        * After the user has chosen to resume processing later, either the 
restriction is already
@@ -1132,7 +1163,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, 
PositionT, WatermarkEstimator
       implements HandlesSplits, FnDataReceiver<WindowedValue> {
     @Override
     public HandlesSplits.SplitResult trySplit(double fractionOfRemainder) {
-      return trySplitForElementAndRestriction(fractionOfRemainder, 
Duration.ZERO);
+      return trySplitForElementAndRestriction(fractionOfRemainder, 
Duration.ZERO, true);
     }
 
     @Override
@@ -1278,6 +1309,13 @@ public class FnApiDoFnRunner<InputT, RestrictionT, 
PositionT, WatermarkEstimator
       if (currentWindow == null) {
         return null;
       }
+      // We are requesting a checkpoint but have not yet progressed on the 
restriction; skip the
+      // request.
+      if (fractionOfRemainder == 0
+          && currentTrackerClaimed != null
+          && !currentTrackerClaimed.get()) {
+        return null;
+      }
 
       SplitResultsWithStopIndex splitResult =
           computeSplitForProcessOrTruncate(
@@ -1620,7 +1658,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, 
PositionT, WatermarkEstimator
   }
 
   private HandlesSplits.SplitResult trySplitForElementAndRestriction(
-      double fractionOfRemainder, Duration resumeDelay) {
+      double fractionOfRemainder, Duration resumeDelay, boolean 
requireClaimForCheckpoint) {
     KV<Instant, WatermarkEstimatorStateT> watermarkAndState;
     WindowedSplitResult windowedSplitResult = null;
     synchronized (splitLock) {
@@ -1628,6 +1666,13 @@ public class FnApiDoFnRunner<InputT, RestrictionT, 
PositionT, WatermarkEstimator
       if (currentTracker == null) {
         return null;
       }
+      // The tracker has not yet been claimed, meaning that a checkpoint won't 
meaningfully advance.
+      if (fractionOfRemainder == 0
+          && requireClaimForCheckpoint
+          && currentTrackerClaimed != null
+          && !currentTrackerClaimed.get()) {
+        return null;
+      }
       // Make sure to get the output watermark before we split to ensure that 
the lower bound
       // applies to the residual.
       watermarkAndState = currentWatermarkEstimator.getWatermarkAndState();
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index 11f25ab0116..f4d555dabcc 100644
--- 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -24,6 +24,7 @@ import static 
org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
 import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.allOf;
+import static org.hamcrest.Matchers.anEmptyMap;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.containsString;
@@ -53,6 +54,7 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Supplier;
 import org.apache.beam.fn.harness.FnApiDoFnRunner.SplitResultsWithStopIndex;
 import org.apache.beam.fn.harness.FnApiDoFnRunner.WindowedSplitResult;
@@ -151,6 +153,7 @@ import org.hamcrest.collection.IsMapContaining;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 import org.joda.time.format.PeriodFormat;
+import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Ignore;
 import org.junit.Rule;
@@ -1370,41 +1373,83 @@ public class FnApiDoFnRunnerTest implements 
Serializable {
      *   <li>splitting thread: {@link
      *       
NonWindowObservingTestSplittableDoFn#waitForSplitElementToBeProcessed()}
      *   <li>process element thread: {@link
-     *       
NonWindowObservingTestSplittableDoFn#enableAndWaitForTrySplitToHappen()}
+     *       NonWindowObservingTestSplittableDoFn#splitElementProcessed()}
      *   <li>splitting thread: perform try split
-     *   <li>splitting thread: {@link
-     *       
NonWindowObservingTestSplittableDoFn#releaseWaitingProcessElementThread()}
+     *   <li>splitting thread: {@link 
NonWindowObservingTestSplittableDoFn#trySplitPerformed()}
+     *   <li>process element thread: {@link
+     *       NonWindowObservingTestSplittableDoFn#waitForTrySplitPerformed()}
      * </ul>
      */
     static class NonWindowObservingTestSplittableDoFn extends DoFn<String, 
String> {
-      private static final ConcurrentMap<String, KV<CountDownLatch, 
CountDownLatch>>
-          DOFN_INSTANCE_TO_LOCK = new ConcurrentHashMap<>();
+      private static final ConcurrentMap<String, Latches> 
DOFN_INSTANCE_TO_LATCHES =
+          new ConcurrentHashMap<>();
       private static final long SPLIT_ELEMENT = 3;
       private static final long CHECKPOINT_UPPER_BOUND = 8;
 
-      private KV<CountDownLatch, CountDownLatch> getLatches() {
-        return DOFN_INSTANCE_TO_LOCK.computeIfAbsent(
-            this.uuid, (uuid) -> KV.of(new CountDownLatch(1), new 
CountDownLatch(1)));
+      static class Latches {
+        public Latches() {}
+
+        CountDownLatch blockProcessLatch = new CountDownLatch(0);
+        CountDownLatch processEnteredLatch = new CountDownLatch(1);
+        CountDownLatch splitElementProcessedLatch = new CountDownLatch(1);
+        CountDownLatch trySplitPerformedLatch = new CountDownLatch(1);
+        AtomicBoolean abortProcessing = new AtomicBoolean();
+      }
+
+      private Latches getLatches() {
+        return DOFN_INSTANCE_TO_LATCHES.computeIfAbsent(this.uuid, (uuid) -> 
new Latches());
+      }
+
+      public void splitElementProcessed() {
+        getLatches().splitElementProcessedLatch.countDown();
       }
 
-      public void enableAndWaitForTrySplitToHappen() throws Exception {
-        KV<CountDownLatch, CountDownLatch> latches = getLatches();
-        latches.getKey().countDown();
-        if (!latches.getValue().await(30, TimeUnit.SECONDS)) {
+      public void waitForSplitElementToBeProcessed() throws 
InterruptedException {
+        if (!getLatches().splitElementProcessedLatch.await(30, 
TimeUnit.SECONDS)) {
           fail("Failed to wait for trySplit to occur.");
         }
       }
 
-      public void waitForSplitElementToBeProcessed() throws Exception {
-        KV<CountDownLatch, CountDownLatch> latches = getLatches();
-        if (!latches.getKey().await(30, TimeUnit.SECONDS)) {
-          fail("Failed to wait for split element to be processed.");
+      public void trySplitPerformed() {
+        getLatches().trySplitPerformedLatch.countDown();
+      }
+
+      public void waitForTrySplitPerformed() throws InterruptedException {
+        if (!getLatches().trySplitPerformedLatch.await(30, TimeUnit.SECONDS)) {
+          fail("Failed to wait for trySplit to occur.");
         }
       }
 
-      public void releaseWaitingProcessElementThread() {
-        KV<CountDownLatch, CountDownLatch> latches = getLatches();
-        latches.getValue().countDown();
+      // Must be called before process is invoked. Will prevent process from 
doing anything until
+      // unblockProcess is
+      // called.
+      public void setupBlockProcess() {
+        getLatches().blockProcessLatch = new CountDownLatch(1);
+      }
+
+      public void enterProcessAndBlockIfEnabled() throws InterruptedException {
+        getLatches().processEnteredLatch.countDown();
+        if (!getLatches().blockProcessLatch.await(30, TimeUnit.SECONDS)) {
+          fail("Failed to wait for unblockProcess to occur.");
+        }
+      }
+
+      public void waitForProcessEntered() throws InterruptedException {
+        if (!getLatches().processEnteredLatch.await(5, TimeUnit.SECONDS)) {
+          fail("Failed to wait for process to begin.");
+        }
+      }
+
+      public void unblockProcess() throws InterruptedException {
+        getLatches().blockProcessLatch.countDown();
+      }
+
+      public void setAbortProcessing() {
+        getLatches().abortProcessing.set(true);
+      }
+
+      public boolean shouldAbortProcessing() {
+        return getLatches().abortProcessing.get();
       }
 
       private final String uuid;
@@ -1421,13 +1466,14 @@ public class FnApiDoFnRunnerTest implements 
Serializable {
           throws Exception {
         long checkpointUpperBound = CHECKPOINT_UPPER_BOUND;
         long position = tracker.currentRestriction().getFrom();
-        boolean claimStatus;
-        while (true) {
+        boolean claimStatus = true;
+        while (!shouldAbortProcessing()) {
           claimStatus = tracker.tryClaim(position);
           if (!claimStatus) {
             break;
           } else if (position == SPLIT_ELEMENT) {
-            enableAndWaitForTrySplitToHappen();
+            splitElementProcessed();
+            waitForTrySplitPerformed();
           }
           context.outputWithTimestamp(
               context.element() + ":" + position,
@@ -1511,15 +1557,17 @@ public class FnApiDoFnRunnerTest implements 
Serializable {
           RestrictionTracker<OffsetRange, Long> tracker,
           ManualWatermarkEstimator<Instant> watermarkEstimator)
           throws Exception {
+        enterProcessAndBlockIfEnabled();
         long checkpointUpperBound = 
Long.parseLong(context.sideInput(singletonSideInput));
         long position = tracker.currentRestriction().getFrom();
-        boolean claimStatus;
-        while (true) {
+        boolean claimStatus = true;
+        while (!shouldAbortProcessing()) {
           claimStatus = tracker.tryClaim(position);
           if (!claimStatus) {
             break;
           } else if (position == 
NonWindowObservingTestSplittableDoFn.SPLIT_ELEMENT) {
-            enableAndWaitForTrySplitToHappen();
+            splitElementProcessed();
+            waitForTrySplitPerformed();
           }
           context.outputWithTimestamp(
               context.element() + ":" + position,
@@ -1549,7 +1597,8 @@ public class FnApiDoFnRunnerTest implements Serializable {
           throws Exception {
         // Waiting for split when we are on the second window.
         if (splitAtTruncate && processedWindowCount == PROCESSED_WINDOW) {
-          enableAndWaitForTrySplitToHappen();
+          splitElementProcessed();
+          waitForTrySplitPerformed();
         }
         processedWindowCount += 1;
         return TruncateResult.of(new OffsetRange(range.getFrom(), 
range.getTo() / 2));
@@ -1755,7 +1804,217 @@ public class FnApiDoFnRunnerTest implements 
Serializable {
 
                     return ((HandlesSplits) mainInput).trySplit(0);
                   } finally {
-                    doFn.releaseWaitingProcessElementThread();
+                    doFn.trySplitPerformed();
+                  }
+                });
+
+        // Check that before processing an element we don't report progress
+        assertNoReportedProgress(context.getBundleProgressReporters());
+        mainInput.accept(
+            valueInGlobalWindow(
+                KV.of(
+                    KV.of("7", KV.of(new OffsetRange(0, 5), 
GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    2.0)));
+        HandlesSplits.SplitResult trySplitResult = trySplitFuture.get();
+
+        // Check that after processing an element we don't report progress
+        assertNoReportedProgress(context.getBundleProgressReporters());
+
+        // Since the SPLIT_ELEMENT is 3 we will process 0, 1, 2, 3 then be 
split.
+        // We expect that the watermark advances to MIN + 2 since the manual 
watermark estimator
+        // has yet to be invoked for the split element and that the primary 
represents [0, 4) with
+        // the original watermark while the residual represents [4, 5) with 
the new MIN + 2
+        // watermark.
+        assertThat(
+            mainOutputValues,
+            contains(
+                timestampedValueInGlobalWindow(
+                    "7:0", 
GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(0))),
+                timestampedValueInGlobalWindow(
+                    "7:1", 
GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(1))),
+                timestampedValueInGlobalWindow(
+                    "7:2", 
GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(2))),
+                timestampedValueInGlobalWindow(
+                    "7:3", 
GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(3)))));
+
+        BundleApplication primaryRoot = 
Iterables.getOnlyElement(trySplitResult.getPrimaryRoots());
+        DelayedBundleApplication residualRoot =
+            Iterables.getOnlyElement(trySplitResult.getResidualRoots());
+        assertEquals(ParDoTranslation.getMainInputName(pTransform), 
primaryRoot.getInputId());
+        assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
+        assertEquals(
+            ParDoTranslation.getMainInputName(pTransform),
+            residualRoot.getApplication().getInputId());
+        assertEquals(TEST_TRANSFORM_ID, 
residualRoot.getApplication().getTransformId());
+        assertEquals(
+            valueInGlobalWindow(
+                KV.of(
+                    KV.of("7", KV.of(new OffsetRange(0, 4), 
GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    4.0)),
+            inputCoder.decode(primaryRoot.getElement().newInput()));
+        assertEquals(
+            valueInGlobalWindow(
+                KV.of(
+                    KV.of(
+                        "7",
+                        KV.of(
+                            new OffsetRange(4, 5),
+                            
GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(2)))),
+                    1.0)),
+            
inputCoder.decode(residualRoot.getApplication().getElement().newInput()));
+        Instant expectedOutputWatermark = 
GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(2));
+        assertEquals(
+            ImmutableMap.of(
+                "output",
+                
org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.Timestamp.newBuilder()
+                    .setSeconds(expectedOutputWatermark.getMillis() / 1000)
+                    .setNanos((int) (expectedOutputWatermark.getMillis() % 
1000) * 1000000)
+                    .build()),
+            residualRoot.getApplication().getOutputWatermarksMap());
+        // We expect 0 resume delay.
+        assertEquals(
+            residualRoot.getRequestedTimeDelay().getDefaultInstanceForType(),
+            residualRoot.getRequestedTimeDelay());
+        // We don't expect the outputs to go to the SDK-initiated checkpointing 
listener.
+        assertTrue(splitListener.getPrimaryRoots().isEmpty());
+        assertTrue(splitListener.getResidualRoots().isEmpty());
+        mainOutputValues.clear();
+        executorService.shutdown();
+      }
+
+      Iterables.getOnlyElement(context.getFinishBundleFunctions()).run();
+      assertThat(mainOutputValues, empty());
+
+      Iterables.getOnlyElement(context.getTearDownFunctions()).run();
+      assertThat(mainOutputValues, empty());
+
+      // Assert that state data did not change
+      assertEquals(
+          new FakeBeamFnStateClient(StringUtf8Coder.of(), stateData).getData(),
+          fakeClient.getData());
+    }
+
+    @Test
+    public void 
testProcessElementForSizedElementAndRestrictionSplitBeforeTryClaim()
+        throws Exception {
+      Pipeline p = Pipeline.create();
+      addExperiment(p.getOptions().as(ExperimentalOptions.class), 
"beam_fn_api");
+      // TODO(BEAM-10097): Remove experiment once all portable runners support 
this view type
+      addExperiment(p.getOptions().as(ExperimentalOptions.class), 
"use_runner_v2");
+      PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+      PCollectionView<String> singletonSideInputView = 
valuePCollection.apply(View.asSingleton());
+      WindowObservingTestSplittableDoFn doFn =
+          new WindowObservingTestSplittableDoFn(singletonSideInputView);
+      valuePCollection.apply(
+          TEST_TRANSFORM_ID, 
ParDo.of(doFn).withSideInputs(singletonSideInputView));
+
+      RunnerApi.Pipeline pProto =
+          ProtoOverrides.updateTransform(
+              PTransformTranslation.PAR_DO_TRANSFORM_URN,
+              PipelineTranslation.toProto(p, 
SdkComponents.create(p.getOptions()), true),
+              SplittableParDoExpander.createSizedReplacement());
+      String expandedTransformId =
+          Iterables.find(
+                  pProto.getComponents().getTransformsMap().entrySet(),
+                  entry ->
+                      entry
+                              .getValue()
+                              .getSpec()
+                              .getUrn()
+                              .equals(
+                                  PTransformTranslation
+                                      
.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)
+                          && 
entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+              .getKey();
+      RunnerApi.PTransform pTransform =
+          pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+      String inputPCollectionId =
+          
pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+      RunnerApi.PCollection inputPCollection =
+          pProto.getComponents().getPcollectionsOrThrow(inputPCollectionId);
+      RehydratedComponents rehydratedComponents =
+          RehydratedComponents.forComponents(pProto.getComponents());
+      Coder<WindowedValue> inputCoder =
+          WindowedValue.getFullCoder(
+              CoderTranslation.fromProto(
+                  
pProto.getComponents().getCodersOrThrow(inputPCollection.getCoderId()),
+                  rehydratedComponents,
+                  TranslationContext.DEFAULT),
+              (Coder)
+                  CoderTranslation.fromProto(
+                      pProto
+                          .getComponents()
+                          .getCodersOrThrow(
+                              pProto
+                                  .getComponents()
+                                  .getWindowingStrategiesOrThrow(
+                                      
inputPCollection.getWindowingStrategyId())
+                                  .getWindowCoderId()),
+                      rehydratedComponents,
+                      TranslationContext.DEFAULT));
+      String outputPCollectionId = pTransform.getOutputsOrThrow("output");
+
+      ImmutableMap<StateKey, List<String>> stateData =
+          ImmutableMap.of(
+              iterableSideInputKey(
+                  singletonSideInputView.getTagInternal().getId(), 
ByteString.EMPTY),
+              asList("8"));
+
+      FakeBeamFnStateClient fakeClient = new 
FakeBeamFnStateClient(StringUtf8Coder.of(), stateData);
+
+      BundleSplitListener.InMemory splitListener = 
BundleSplitListener.InMemory.create();
+
+      PTransformRunnerFactoryTestContext context =
+          PTransformRunnerFactoryTestContext.builder(TEST_TRANSFORM_ID, 
pTransform)
+              .beamFnStateClient(fakeClient)
+              .processBundleInstructionId("57")
+              
.pCollections(pProto.getComponentsOrBuilder().getPcollectionsMap())
+              .coders(pProto.getComponents().getCodersMap())
+              
.windowingStrategies(pProto.getComponents().getWindowingStrategiesMap())
+              .splitListener(splitListener)
+              .build();
+      List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+      context.addPCollectionConsumer(
+          outputPCollectionId,
+          (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) 
mainOutputValues::add);
+
+      new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(context);
+
+      Iterables.getOnlyElement(context.getStartBundleFunctions()).run();
+      mainOutputValues.clear();
+
+      assertThat(
+          context.getPCollectionConsumers().keySet(),
+          containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          context.getPCollectionConsumer(inputPCollectionId);
+      assertThat(mainInput, instanceOf(HandlesSplits.class));
+
+      doFn.setupBlockProcess();
+      {
+        // Setup and launch the trySplit thread.
+        ExecutorService executorService = Executors.newSingleThreadExecutor();
+        Future<HandlesSplits.SplitResult> trySplitFuture =
+            executorService.submit(
+                () -> {
+                  try {
+                    // Verify that a split before anything is claimed is 
ignored.
+                    doFn.waitForProcessEntered();
+                    Assert.assertNull(((HandlesSplits) mainInput).trySplit(0));
+                    doFn.unblockProcess();
+
+                    doFn.waitForSplitElementToBeProcessed();
+                    // Currently processing "3" out of range [0, 5) elements.
+                    assertEquals(0.6, ((HandlesSplits) 
mainInput).getProgress(), 0.01);
+
+                    // Check that while processing an element we report 
progress
+                    assertReportedProgressEquals(
+                        context.getShortIdMap(), 
context.getBundleProgressReporters(), 3.0, 2.0);
+
+                    return ((HandlesSplits) mainInput).trySplit(0);
+                  } finally {
+                    doFn.trySplitPerformed();
                   }
                 });
 
@@ -1845,6 +2104,149 @@ public class FnApiDoFnRunnerTest implements Serializable {
           fakeClient.getData());
     }
 
+    @Test
+    public void testProcessElementForSizedElementAndRestrictionNoTryClaim() throws Exception {
+      Pipeline p = Pipeline.create();
+      addExperiment(p.getOptions().as(ExperimentalOptions.class), "beam_fn_api");
+      // TODO(BEAM-10097): Remove experiment once all portable runners support this view type
+      addExperiment(p.getOptions().as(ExperimentalOptions.class), "use_runner_v2");
+      PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+      PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
+      WindowObservingTestSplittableDoFn doFn =
+          new WindowObservingTestSplittableDoFn(singletonSideInputView);
+      doFn.setAbortProcessing();
+      valuePCollection.apply(
+          TEST_TRANSFORM_ID, ParDo.of(doFn).withSideInputs(singletonSideInputView));
+
+      RunnerApi.Pipeline pProto =
+          ProtoOverrides.updateTransform(
+              PTransformTranslation.PAR_DO_TRANSFORM_URN,
+              PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
+              SplittableParDoExpander.createSizedReplacement());
+      String expandedTransformId =
+          Iterables.find(
+                  pProto.getComponents().getTransformsMap().entrySet(),
+                  entry ->
+                      entry
+                              .getValue()
+                              .getSpec()
+                              .getUrn()
+                              .equals(
+                                  PTransformTranslation
+                                      .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)
+                          && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+              .getKey();
+      RunnerApi.PTransform pTransform =
+          pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+      String inputPCollectionId =
+          pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+      RunnerApi.PCollection inputPCollection =
+          pProto.getComponents().getPcollectionsOrThrow(inputPCollectionId);
+      RehydratedComponents rehydratedComponents =
+          RehydratedComponents.forComponents(pProto.getComponents());
+      Coder<WindowedValue> inputCoder =
+          WindowedValue.getFullCoder(
+              CoderTranslation.fromProto(
+                  pProto.getComponents().getCodersOrThrow(inputPCollection.getCoderId()),
+                  rehydratedComponents,
+                  TranslationContext.DEFAULT),
+              (Coder)
+                  CoderTranslation.fromProto(
+                      pProto
+                          .getComponents()
+                          .getCodersOrThrow(
+                              pProto
+                                  .getComponents()
+                                  .getWindowingStrategiesOrThrow(
+                                      inputPCollection.getWindowingStrategyId())
+                                  .getWindowCoderId()),
+                      rehydratedComponents,
+                      TranslationContext.DEFAULT));
+      String outputPCollectionId = pTransform.getOutputsOrThrow("output");
+
+      ImmutableMap<StateKey, List<String>> stateData =
+          ImmutableMap.of(
+              iterableSideInputKey(
+                  singletonSideInputView.getTagInternal().getId(), ByteString.EMPTY),
+              asList("8"));
+
+      FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(StringUtf8Coder.of(), stateData);
+
+      BundleSplitListener.InMemory splitListener = BundleSplitListener.InMemory.create();
+
+      PTransformRunnerFactoryTestContext context =
+          PTransformRunnerFactoryTestContext.builder(TEST_TRANSFORM_ID, pTransform)
+              .beamFnStateClient(fakeClient)
+              .processBundleInstructionId("57")
+              .pCollections(pProto.getComponentsOrBuilder().getPcollectionsMap())
+              .coders(pProto.getComponents().getCodersMap())
+              .windowingStrategies(pProto.getComponents().getWindowingStrategiesMap())
+              .splitListener(splitListener)
+              .build();
+      List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+      context.addPCollectionConsumer(
+          outputPCollectionId,
+          (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
+
+      new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(context);
+
+      Iterables.getOnlyElement(context.getStartBundleFunctions()).run();
+      mainOutputValues.clear();
+
+      assertThat(
+          context.getPCollectionConsumers().keySet(),
+          containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          context.getPCollectionConsumer(inputPCollectionId);
+      assertThat(mainInput, instanceOf(HandlesSplits.class));
+
+      {
+        // Check that before processing an element we don't report progress
+        assertNoReportedProgress(context.getBundleProgressReporters());
+        mainInput.accept(
+            valueInGlobalWindow(
+                KV.of(
+                    KV.of("5", KV.of(new OffsetRange(5, 10), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    5.0)));
+        // Check that after processing an element we don't report progress
+        assertNoReportedProgress(context.getBundleProgressReporters());
+
+        // Abort processing was requested above, so the DoFn resumes without claiming anything. We
+        // expect an empty (size 0) primary root and the whole input restriction as the residual
+        // root, together with the requested resume delay.
+        BundleApplication primaryRoot = Iterables.getOnlyElement(splitListener.getPrimaryRoots());
+        DelayedBundleApplication residualRoot =
+            Iterables.getOnlyElement(splitListener.getResidualRoots());
+        assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
+        assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
+        assertEquals(
+            ParDoTranslation.getMainInputName(pTransform),
+            residualRoot.getApplication().getInputId());
+        assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
+        assertEquals(
+            valueInGlobalWindow(
+                KV.of(
+                    KV.of("5", KV.of(new OffsetRange(5, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    0.0)),
+            inputCoder.decode(primaryRoot.getElement().newInput()));
+        assertEquals(
+            valueInGlobalWindow(
+                KV.of(
+                    KV.of("5", KV.of(new OffsetRange(5, 10), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    5.0)),
+            inputCoder.decode(residualRoot.getApplication().getElement().newInput()));
+        assertThat(residualRoot.getApplication().getOutputWatermarksMap(), anEmptyMap());
+        assertEquals(
+            org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.Duration.newBuilder()
+                .setSeconds(54)
+                .setNanos(321000000)
+                .build(),
+            residualRoot.getRequestedTimeDelay());
+        splitListener.clear();
+      }
+    }
+
     private static final MonitoringInfo WORK_COMPLETED_MI =
         MonitoringInfo.newBuilder()
             .setUrn(MonitoringInfoConstants.Urns.WORK_COMPLETED)
@@ -2187,7 +2589,7 @@ public class FnApiDoFnRunnerTest implements Serializable {
 
                     return ((HandlesSplits) mainInput).trySplit(0);
                   } finally {
-                    doFn.releaseWaitingProcessElementThread();
+                    doFn.trySplitPerformed();
                   }
                 });
 
@@ -3143,10 +3545,11 @@ public class FnApiDoFnRunnerTest implements Serializable {
               () -> {
                 try {
                   doFn.waitForSplitElementToBeProcessed();
-
-                  return ((HandlesSplits) mainInput).trySplit(0);
+                  HandlesSplits.SplitResult result = ((HandlesSplits) 
mainInput).trySplit(0);
+                  Assert.assertNotNull(result);
+                  return result;
                 } finally {
-                  doFn.releaseWaitingProcessElementThread();
+                  doFn.trySplitPerformed();
                 }
               });
 


Reply via email to