This is an automated email from the ASF dual-hosted git repository.

chamikaramj pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 1c0c02472c7 Adds backlog reporting support for non-FnAPI-based SDFs. (#38346)
1c0c02472c7 is described below

commit 1c0c02472c784a69249d328c486c6eb8628185e7
Author: Andrew Crites <[email protected]>
AuthorDate: Thu May 21 13:40:16 2026 -0700

    Adds backlog reporting support for non-FnAPI-based SDFs. (#38346)
---
 ..._ValidatesRunner_Dataflow_Streaming_Engine.json |   2 +-
 ...TimeBoundedSplittableProcessElementInvoker.java | 247 ++++++++++++---------
 .../core/SplittableParDoViaKeyedWorkItems.java     |   9 +
 .../core/SplittableProcessElementInvoker.java      |  25 ++-
 .../org/apache/beam/runners/core/StepContext.java  |   6 +
 ...BoundedSplittableProcessElementInvokerTest.java |  47 +++-
 .../runners/core/SplittableParDoProcessFnTest.java | 124 ++++++++++-
 .../worker/SplittableProcessFnFactory.java         |   1 +
 .../worker/StreamingModeExecutionContext.java      |  16 ++
 .../worker/StreamingModeExecutionContextTest.java  |  25 +++
 .../SpannerChangeStreamErrorTest.java              |  41 +++-
 11 files changed, 415 insertions(+), 128 deletions(-)

diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json
index 50d17c108f2..e623d3373a9 100644
--- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json
+++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json
@@ -1,4 +1,4 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to 
run!",
-  "modification":  2,
+  "modification":  1,
 }
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java
index ebd88442b21..dbbcfe8ee31 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java
@@ -130,114 +130,114 @@ public class OutputAndTimeBoundedSplittableProcessElementInvoker<
       final Map<String, PCollectionView<?>> sideInputMapping) {
     final ProcessContext processContext = new ProcessContext(element, tracker, 
watermarkEstimator);
 
-    DoFn.ProcessContinuation cont =
-        invoker.invokeProcessElement(
-            new DoFnInvoker.BaseArgumentProvider<InputT, OutputT>() {
-              @Override
-              public String getErrorContext() {
-                return 
OutputAndTimeBoundedSplittableProcessElementInvoker.class.getSimpleName();
-              }
-
-              @Override
-              public DoFn<InputT, OutputT>.ProcessContext processContext(
-                  DoFn<InputT, OutputT> doFn) {
-                return processContext;
-              }
-
-              @Override
-              public Object sideInput(String tagId) {
-                PCollectionView<?> view = sideInputMapping.get(tagId);
-                if (view == null) {
-                  throw new IllegalArgumentException("calling getSideInput() 
with unknown view");
-                }
-                return processContext.sideInput(view);
-              }
-
-              @Override
-              public Object restriction() {
-                return tracker.currentRestriction();
-              }
-
-              @Override
-              public InputT element(DoFn<InputT, OutputT> doFn) {
-                return processContext.element();
-              }
-
-              @Override
-              public Instant timestamp(DoFn<InputT, OutputT> doFn) {
-                return processContext.timestamp();
-              }
-
-              @Override
-              public String timerId(DoFn<InputT, OutputT> doFn) {
-                throw new UnsupportedOperationException(
-                    "Cannot access timerId as parameter outside of @OnTimer 
method.");
-              }
-
-              @Override
-              public TimeDomain timeDomain(DoFn<InputT, OutputT> doFn) {
-                throw new UnsupportedOperationException(
-                    "Access to time domain not supported in ProcessElement");
-              }
-
-              @Override
-              public OutputReceiver<OutputT> outputReceiver(DoFn<InputT, 
OutputT> doFn) {
-                return DoFnOutputReceivers.windowedReceiver(
-                    processContext, 
OutputBuilderSuppliers.supplierForElement(element), null);
-              }
-
-              @Override
-              public OutputReceiver<Row> outputRowReceiver(DoFn<InputT, 
OutputT> doFn) {
-                throw new UnsupportedOperationException("Not supported in 
SplittableDoFn");
-              }
-
-              @Override
-              public MultiOutputReceiver taggedOutputReceiver(DoFn<InputT, 
OutputT> doFn) {
-                return DoFnOutputReceivers.windowedMultiReceiver(
-                    processContext, 
OutputBuilderSuppliers.supplierForElement(element));
-              }
-
-              @Override
-              public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
-                return processContext.causedByDrain();
-              }
-
-              @Override
-              public RestrictionTracker<?, ?> restrictionTracker() {
-                return processContext.tracker;
-              }
-
-              @Override
-              public WatermarkEstimator<?> watermarkEstimator() {
-                return processContext.watermarkEstimator;
-              }
-
-              @Override
-              public PipelineOptions pipelineOptions() {
-                return pipelineOptions;
-              }
-
-              @Override
-              public BundleFinalizer bundleFinalizer() {
-                return bundleFinalizer.get();
-              }
-
-              // Unsupported methods below.
-
-              @Override
-              public StartBundleContext startBundleContext(DoFn<InputT, 
OutputT> doFn) {
-                throw new IllegalStateException(
-                    "Should not access startBundleContext() from @"
-                        + DoFn.ProcessElement.class.getSimpleName());
-              }
-
-              @Override
-              public FinishBundleContext finishBundleContext(DoFn<InputT, 
OutputT> doFn) {
-                throw new IllegalStateException(
-                    "Should not access finishBundleContext() from @"
-                        + DoFn.ProcessElement.class.getSimpleName());
-              }
-            });
+    DoFnInvoker.BaseArgumentProvider<InputT, OutputT> invokerArgumentProvider =
+        new DoFnInvoker.BaseArgumentProvider<InputT, OutputT>() {
+          @Override
+          public String getErrorContext() {
+            return 
OutputAndTimeBoundedSplittableProcessElementInvoker.class.getSimpleName();
+          }
+
+          @Override
+          public DoFn<InputT, OutputT>.ProcessContext 
processContext(DoFn<InputT, OutputT> doFn) {
+            return processContext;
+          }
+
+          @Override
+          public Object sideInput(String tagId) {
+            PCollectionView<?> view = sideInputMapping.get(tagId);
+            if (view == null) {
+              throw new IllegalArgumentException("calling getSideInput() with 
unknown view");
+            }
+            return processContext.sideInput(view);
+          }
+
+          @Override
+          public Object restriction() {
+            return tracker.currentRestriction();
+          }
+
+          @Override
+          public InputT element(DoFn<InputT, OutputT> doFn) {
+            return processContext.element();
+          }
+
+          @Override
+          public Instant timestamp(DoFn<InputT, OutputT> doFn) {
+            return processContext.timestamp();
+          }
+
+          @Override
+          public String timerId(DoFn<InputT, OutputT> doFn) {
+            throw new UnsupportedOperationException(
+                "Cannot access timerId as parameter outside of @OnTimer 
method.");
+          }
+
+          @Override
+          public TimeDomain timeDomain(DoFn<InputT, OutputT> doFn) {
+            throw new UnsupportedOperationException(
+                "Access to time domain not supported in ProcessElement");
+          }
+
+          @Override
+          public OutputReceiver<OutputT> outputReceiver(DoFn<InputT, OutputT> 
doFn) {
+            return DoFnOutputReceivers.windowedReceiver(
+                processContext, 
OutputBuilderSuppliers.supplierForElement(element), null);
+          }
+
+          @Override
+          public OutputReceiver<Row> outputRowReceiver(DoFn<InputT, OutputT> 
doFn) {
+            throw new UnsupportedOperationException("Not supported in 
SplittableDoFn");
+          }
+
+          @Override
+          public MultiOutputReceiver taggedOutputReceiver(DoFn<InputT, 
OutputT> doFn) {
+            return DoFnOutputReceivers.windowedMultiReceiver(
+                processContext, 
OutputBuilderSuppliers.supplierForElement(element));
+          }
+
+          @Override
+          public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
+            return processContext.causedByDrain();
+          }
+
+          @Override
+          public RestrictionTracker<?, ?> restrictionTracker() {
+            return processContext.tracker;
+          }
+
+          @Override
+          public WatermarkEstimator<?> watermarkEstimator() {
+            return processContext.watermarkEstimator;
+          }
+
+          @Override
+          public PipelineOptions pipelineOptions() {
+            return pipelineOptions;
+          }
+
+          @Override
+          public BundleFinalizer bundleFinalizer() {
+            return bundleFinalizer.get();
+          }
+
+          // Unsupported methods below.
+
+          @Override
+          public StartBundleContext startBundleContext(DoFn<InputT, OutputT> 
doFn) {
+            throw new IllegalStateException(
+                "Should not access startBundleContext() from @"
+                    + DoFn.ProcessElement.class.getSimpleName());
+          }
+
+          @Override
+          public FinishBundleContext finishBundleContext(DoFn<InputT, OutputT> 
doFn) {
+            throw new IllegalStateException(
+                "Should not access finishBundleContext() from @"
+                    + DoFn.ProcessElement.class.getSimpleName());
+          }
+        };
+
+    DoFn.ProcessContinuation cont = 
invoker.invokeProcessElement(invokerArgumentProvider);
     processContext.cancelScheduledCheckpoint();
     @Nullable
     KV<RestrictionT, KV<Instant, WatermarkEstimatorStateT>> residual =
@@ -278,8 +278,37 @@ public class OutputAndTimeBoundedSplittableProcessElementInvoker<
     if (residual == null) {
       return new Result(null, cont, null, null);
     }
+    final KV<RestrictionT, KV<Instant, WatermarkEstimatorStateT>> 
residualForGetSize = residual;
+    // For a list of all DoFnInvoker arguments, see DoFn.java.
+    double backlogBytes =
+        invoker.invokeGetSize(
+            new DoFnInvoker.DelegatingArgumentProvider<InputT, OutputT>(
+                invokerArgumentProvider, 
invokerArgumentProvider.getErrorContext() + "/GetSize") {
+              @Override
+              public Object restriction() {
+                return residualForGetSize.getKey();
+              }
+
+              @Override
+              public RestrictionTracker<?, ?> restrictionTracker() {
+                return invoker.invokeNewTracker(
+                    new DoFnInvoker.DelegatingArgumentProvider<InputT, 
OutputT>(
+                        invokerArgumentProvider,
+                        invokerArgumentProvider.getErrorContext() + 
"/NewTracker") {
+
+                      @Override
+                      public Object restriction() {
+                        return residualForGetSize.getKey();
+                      }
+                    });
+              }
+            });
     return new Result(
-        residual.getKey(), cont, residual.getValue().getKey(), 
residual.getValue().getValue());
+        residual.getKey(),
+        cont,
+        residual.getValue().getKey(),
+        residual.getValue().getValue(),
+        backlogBytes);
   }
 
   private class ProcessContext extends DoFn<InputT, OutputT>.ProcessContext
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java
index 424ea567115..a750b01963f 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java
@@ -22,6 +22,7 @@ import static org.apache.beam.sdk.util.construction.SplittableParDo.SPLITTABLE_P
 import com.google.auto.service.AutoService;
 import java.util.List;
 import java.util.Map;
+import java.util.function.Consumer;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.sdk.coders.ByteArrayCoder;
 import org.apache.beam.sdk.coders.Coder;
@@ -281,6 +282,7 @@ public class SplittableParDoViaKeyedWorkItems {
         processElementInvoker;
 
     private transient @Nullable DoFnInvoker<InputT, OutputT> invoker;
+    private transient @Nullable Consumer<Double> backlogBytesCallback;
 
     public ProcessFn(
         DoFn<InputT, OutputT> fn,
@@ -323,6 +325,10 @@ public class SplittableParDoViaKeyedWorkItems {
       this.processElementInvoker = invoker;
     }
 
+    public void setBacklogBytesCallback(Consumer<Double> backlogBytesCallback) 
{
+      this.backlogBytesCallback = backlogBytesCallback;
+    }
+
     public DoFn<InputT, OutputT> getFn() {
       return fn;
     }
@@ -624,6 +630,9 @@ public class SplittableParDoViaKeyedWorkItems {
       } else {
         holdState.clear();
       }
+      if (backlogBytesCallback != null && result.getBacklogBytes() >= 0) {
+        backlogBytesCallback.accept(result.getBacklogBytes());
+      }
     }
 
     private DoFnInvoker.ArgumentProvider<InputT, OutputT> wrapOptionsAsSetup(
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java
index 1ff66d6e517..d311806e0a2 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java
@@ -42,6 +42,10 @@ public abstract class SplittableProcessElementInvoker<
     private final DoFn.ProcessContinuation continuation;
     private final @Nullable Instant futureOutputWatermark;
     private final @Nullable WatermarkEstimatorStateT 
futureWatermarkEstimatorState;
+    private final double backlogBytes;
+
+    /* Constant representing an unknown amount of backlog. */
+    public static final double BACKLOG_UNKNOWN = -1.0;
 
     @SuppressFBWarnings(
         value = "NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE",
@@ -50,12 +54,27 @@ public abstract class SplittableProcessElementInvoker<
         @Nullable RestrictionT residualRestriction,
         DoFn.ProcessContinuation continuation,
         @Nullable Instant futureOutputWatermark,
-        @Nullable WatermarkEstimatorStateT futureWatermarkEstimatorState) {
+        @Nullable WatermarkEstimatorStateT futureWatermarkEstimatorState,
+        double backlogBytes) {
       checkArgument(continuation != null, "continuation must not be null");
       this.continuation = continuation;
       this.residualRestriction = residualRestriction;
       this.futureOutputWatermark = futureOutputWatermark;
       this.futureWatermarkEstimatorState = futureWatermarkEstimatorState;
+      this.backlogBytes = backlogBytes;
+    }
+
+    public Result(
+        @Nullable RestrictionT residualRestriction,
+        DoFn.ProcessContinuation continuation,
+        @Nullable Instant futureOutputWatermark,
+        @Nullable WatermarkEstimatorStateT futureWatermarkEstimatorState) {
+      this(
+          residualRestriction,
+          continuation,
+          futureOutputWatermark,
+          futureWatermarkEstimatorState,
+          BACKLOG_UNKNOWN);
     }
 
     /**
@@ -76,6 +95,10 @@ public abstract class SplittableProcessElementInvoker<
     public @Nullable WatermarkEstimatorStateT 
getFutureWatermarkEstimatorState() {
       return futureWatermarkEstimatorState;
     }
+
+    public double getBacklogBytes() {
+      return backlogBytes;
+    }
   }
 
   /**
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/StepContext.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/StepContext.java
index d2a03ff6ab3..e07d1a10586 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/StepContext.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/StepContext.java
@@ -36,4 +36,10 @@ public interface StepContext {
   default BundleFinalizer bundleFinalizer() {
     throw new UnsupportedOperationException("BundleFinalizer is unsupported.");
   }
+
+  /**
+   * Set the current backlog bytes for this step. This is mainly used by 
splittable DoFn to report
+   * the size of the residual restriction.
+   */
+  default void setBacklogBytes(double backlogBytes) {}
 }
diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java
index 1750cceffa0..52ac6b1a819 100644
--- a/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java
+++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java
@@ -95,6 +95,30 @@ public class OutputAndTimeBoundedSplittableProcessElementInvokerTest {
     }
   }
 
+  private static class GetSizeFn extends DoFn<Void, String> {
+    @ProcessElement
+    public ProcessContinuation process(
+        ProcessContext c, RestrictionTracker<OffsetRange, Long> tracker) {
+      for (long i = tracker.currentRestriction().getFrom(); 
tracker.tryClaim(i); ++i) {
+        c.output(String.valueOf(i));
+        if (i == 2) {
+          return resume();
+        }
+      }
+      return stop();
+    }
+
+    @GetInitialRestriction
+    public OffsetRange getInitialRestriction() {
+      return new OffsetRange(0, 10);
+    }
+
+    @GetSize
+    public double getSize(@Restriction OffsetRange range) {
+      return range.getTo() - range.getFrom();
+    }
+  }
+
   private SplittableProcessElementInvoker<Void, String, OffsetRange, Long, 
Void>.Result runTest(
       int totalNumOutputs,
       Duration sleepBeforeFirstClaim,
@@ -103,11 +127,12 @@ public class OutputAndTimeBoundedSplittableProcessElementInvokerTest {
       throws Exception {
     SomeFn fn = new SomeFn(sleepBeforeFirstClaim, numOutputsPerProcessCall, 
sleepBeforeEachOutput);
     OffsetRange initialRestriction = new OffsetRange(0, totalNumOutputs);
-    return runTest(fn, initialRestriction);
+    return runTest(fn, initialRestriction, Duration.standardSeconds(3));
   }
 
   private SplittableProcessElementInvoker<Void, String, OffsetRange, Long, 
Void>.Result runTest(
-      DoFn<Void, String> fn, OffsetRange initialRestriction) throws Exception {
+      DoFn<Void, String> fn, OffsetRange initialRestriction, Duration 
checkpointDuration)
+      throws Exception {
     SplittableProcessElementInvoker<Void, String, OffsetRange, Long, Void> 
invoker =
         new OutputAndTimeBoundedSplittableProcessElementInvoker<>(
             fn,
@@ -122,7 +147,7 @@ public class OutputAndTimeBoundedSplittableProcessElementInvokerTest {
             NullSideInputReader.empty(),
             Executors.newSingleThreadScheduledExecutor(),
             1000,
-            Duration.standardSeconds(3),
+            checkpointDuration,
             () -> {
               throw new UnsupportedOperationException("BundleFinalizer not 
configured for test.");
             });
@@ -215,7 +240,7 @@ public class OutputAndTimeBoundedSplittableProcessElementInvokerTest {
           }
         };
     e.expectMessage("Output is not allowed before tryClaim()");
-    runTest(brokenFn, new OffsetRange(0, 5));
+    runTest(brokenFn, new OffsetRange(0, 5), Duration.standardSeconds(3));
   }
 
   @Test
@@ -235,6 +260,18 @@ public class OutputAndTimeBoundedSplittableProcessElementInvokerTest {
           }
         };
     e.expectMessage("Output is not allowed after a failed tryClaim()");
-    runTest(brokenFn, new OffsetRange(0, 5));
+    runTest(brokenFn, new OffsetRange(0, 5), Duration.standardSeconds(3));
+  }
+
+  @Test
+  public void testBacklogBytes() throws Exception {
+    GetSizeFn fn = new GetSizeFn();
+    OffsetRange initialRestriction = new OffsetRange(0, 10);
+    // Set a high checkpoint duration to prevent flakiness caused by early 
checkpointing.
+    SplittableProcessElementInvoker<Void, String, OffsetRange, Long, 
Void>.Result res =
+        runTest(fn, initialRestriction, Duration.standardMinutes(3));
+    // GetSizeFn claims 3 elements and then takes a checkpoint.
+    assertEquals(7.0, res.getBacklogBytes(), 0.001);
+    assertEquals(new OffsetRange(3, 10), res.getResidualRestriction());
   }
 }
diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java
index ef1f201ca1e..381e41c9870 100644
--- a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java
+++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java
@@ -140,6 +140,8 @@ public class SplittableParDoProcessFnTest {
     private InMemoryTimerInternals timerInternals;
     private TestInMemoryStateInternals<String> stateInternals;
     private InMemoryBundleFinalizer bundleFinalizer;
+    private final ProcessFn<InputT, OutputT, RestrictionT, PositionT, 
WatermarkEstimatorStateT>
+        processFn;
 
     ProcessFnTester(
         Instant currentProcessingTime,
@@ -154,15 +156,14 @@ public class SplittableParDoProcessFnTest {
       // encode IntervalWindow's because that's what all tests here use.
       WindowingStrategy<InputT, BoundedWindow> windowingStrategy =
           (WindowingStrategy) 
WindowingStrategy.of(FixedWindows.of(Duration.standardSeconds(1)));
-      final ProcessFn<InputT, OutputT, RestrictionT, PositionT, 
WatermarkEstimatorStateT>
-          processFn =
-              new ProcessFn<>(
-                  fn,
-                  inputCoder,
-                  restrictionCoder,
-                  watermarkEstimatorStateCoder,
-                  windowingStrategy,
-                  Collections.emptyMap());
+      this.processFn =
+          new ProcessFn<>(
+              fn,
+              inputCoder,
+              restrictionCoder,
+              watermarkEstimatorStateCoder,
+              windowingStrategy,
+              Collections.emptyMap());
       this.tester = DoFnTester.of(processFn);
       this.timerInternals = new InMemoryTimerInternals();
       this.stateInternals = new TestInMemoryStateInternals<>("dummy");
@@ -386,6 +387,61 @@ public class SplittableParDoProcessFnTest {
     }
   }
 
+  private static class GetSizeFn extends DoFn<Integer, String> {
+    @ProcessElement
+    public ProcessContinuation process(
+        ProcessContext c, RestrictionTracker<OffsetRange, Long> tracker) {
+      for (long i = tracker.currentRestriction().getFrom(); 
tracker.tryClaim(i); ++i) {
+        c.output(String.valueOf(i));
+        if (i == 2) {
+          return resume();
+        }
+      }
+      return stop();
+    }
+
+    @GetInitialRestriction
+    public OffsetRange getInitialRestriction() {
+      return new OffsetRange(0, 10);
+    }
+
+    @NewTracker
+    public OffsetRangeTracker newTracker(@Restriction OffsetRange range) {
+      return new OffsetRangeTracker(range);
+    }
+
+    @GetSize
+    public double getSize(@Restriction OffsetRange range) {
+      return range.getTo() - range.getFrom();
+    }
+  }
+
+  // Used to check that backlog can be computed from the restriction tracker 
if GetSize is not
+  // defined.
+  private static class SdfWithoutGetSize extends DoFn<Integer, String> {
+    @ProcessElement
+    public ProcessContinuation process(
+        ProcessContext c, RestrictionTracker<OffsetRange, Long> tracker) {
+      for (long i = tracker.currentRestriction().getFrom(); 
tracker.tryClaim(i); ++i) {
+        c.output(String.valueOf(i));
+        if (i == 2) {
+          return resume();
+        }
+      }
+      return stop();
+    }
+
+    @GetInitialRestriction
+    public OffsetRange getInitialRestriction() {
+      return new OffsetRange(0, 10);
+    }
+
+    @NewTracker
+    public OffsetRangeTracker newTracker(@Restriction OffsetRange range) {
+      return new OffsetRangeTracker(range);
+    }
+  }
+
   @Test
   public void testDrains() throws Exception {
     DoFn<Instant, String> fn = new WatermarkUpdateFn();
@@ -684,4 +740,54 @@ public class SplittableParDoProcessFnTest {
       tester.startElement(42, new SomeRestriction());
     }
   }
+
+  @Test
+  public void testReportsBacklog() throws Exception {
+    DoFn<Integer, String> fn = new GetSizeFn();
+    Instant base = Instant.now();
+    final List<Double> backlogs = new ArrayList<>();
+
+    try (ProcessFnTester<Integer, String, OffsetRange, Long, Void> tester =
+        new ProcessFnTester<>(
+            base,
+            fn,
+            BigEndianIntegerCoder.of(),
+            SerializableCoder.of(OffsetRange.class),
+            VoidCoder.of(),
+            MAX_OUTPUTS_PER_BUNDLE,
+            MAX_BUNDLE_DURATION)) {
+      tester.processFn.setBacklogBytesCallback(backlogs::add);
+
+      tester.startElement(42, new OffsetRange(0, 10));
+      // First call outputs 0, 1, and 2, and then resumes.
+      // The residual range should be [3, 10), so size is 7.
+      assertEquals(1, backlogs.size());
+      assertEquals(7.0, backlogs.get(0), 0.001);
+    }
+  }
+
+  @Test
+  public void testReportsBacklogWithoutGetSize() throws Exception {
+    DoFn<Integer, String> fn = new SdfWithoutGetSize();
+    Instant base = Instant.now();
+    final List<Double> backlogs = new ArrayList<>();
+
+    try (ProcessFnTester<Integer, String, OffsetRange, Long, Void> tester =
+        new ProcessFnTester<>(
+            base,
+            fn,
+            BigEndianIntegerCoder.of(),
+            SerializableCoder.of(OffsetRange.class),
+            VoidCoder.of(),
+            MAX_OUTPUTS_PER_BUNDLE,
+            MAX_BUNDLE_DURATION)) {
+      tester.processFn.setBacklogBytesCallback(backlogs::add);
+
+      tester.startElement(42, new OffsetRange(0, 10));
+      // First call outputs 0, 1, and 2, and then resumes.
+      // The residual range should be [3, 10), so size is 7.
+      assertEquals(1, backlogs.size());
+      assertEquals(7.0, backlogs.get(0), 0.001);
+    }
+  }
 }
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SplittableProcessFnFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SplittableProcessFnFactory.java
index 93c288fea9e..3ad443ee2a2 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SplittableProcessFnFactory.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SplittableProcessFnFactory.java
@@ -157,6 +157,7 @@ class SplittableProcessFnFactory {
               10000,
               Duration.standardSeconds(10),
               stepContext::bundleFinalizer));
+      processFn.setBacklogBytesCallback(userStepContext::setBacklogBytes);
       DoFnRunner<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> 
simpleRunner =
           new SimpleDoFnRunner<>(
               options,
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
index f75d452b211..e1f1b21e135 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
@@ -254,6 +254,7 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext<Step
             : WindmillTagEncodingV1.instance();
     this.outputBuilder = outputBuilder;
     this.sideInputCache.clear();
+    this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
     clearSinkFullHint();
 
     Instant processingTime = 
computeProcessingTime(work.getWorkItem().getTimers().getTimersList());
@@ -528,6 +529,11 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext<Step
           getWorkItem().getWorkToken(),
           activeReader);
       activeReader = null;
+    } else if (backlogBytes != UnboundedReader.BACKLOG_UNKNOWN && backlogBytes 
!= 1L) {
+      // If activeReader is null, we might still have backlogBytes from an 
SDF. We ignore a reported
+      // backlogBytes of 1 since older versions of the Java SDK use this value 
as a default when
+      // RestrictionTracker.getProgress() or GetSize() are not defined.
+      outputBuilder.setSourceBacklogBytes(backlogBytes);
     }
     return callbacks;
   }
@@ -726,6 +732,11 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext<Step
     public BundleFinalizer bundleFinalizer() {
       return wrapped.bundleFinalizer();
     }
+
+    @Override
+    public void setBacklogBytes(double backlogBytes) {
+      wrapped.setBacklogBytes(backlogBytes);
+    }
   }
 
   /** A {@link SideInputReader} that fetches side inputs from the streaming 
worker's cache. */
@@ -856,6 +867,11 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext<Step
       userTimerInternals.persistTo(outputBuilder);
     }
 
+    @Override
+    public void setBacklogBytes(double backlogBytes) {
+      StreamingModeExecutionContext.this.backlogBytes = (long) backlogBytes;
+    }
+
     @Override
     public <W extends BoundedWindow> TimerData getNextFiredTimer(Coder<W> 
windowCoder) {
       if (cachedFiredSystemTimers == null) {
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
index 4bfa6efc888..a1c7609e5af 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
@@ -431,4 +431,29 @@ public class StreamingModeExecutionContextTest {
       assertEquals(expectedEncoding, 
executionContext.getWindmillTagEncoding().getClass());
     }
   }
+
+  /** A backlog set on the step context must be written to the commit request on flush. */
+  @Test
+  public void testSetBacklogBytes() {
+    NameContext nameContext = NameContextsForTests.nameContextForTest();
+    StreamingModeExecutionContext.StepContext stepContext =
+        executionContext.getStepContext(executionContext.createOperationContext(nameContext));
+
+    Windmill.WorkItem workItem =
+        Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build();
+    Windmill.WorkItemCommitRequest.Builder commitBuilder =
+        Windmill.WorkItemCommitRequest.newBuilder();
+    executionContext.start(
+        "key",
+        createMockWork(
+            workItem, Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()),
+        stateReader,
+        sideInputStateFetcher,
+        commitBuilder);
+
+    // 1234 is neither BACKLOG_UNKNOWN nor the ignored legacy default of 1, so it is committed.
+    stepContext.setBacklogBytes(1234.0);
+    executionContext.flushState();
+
+    assertEquals(1234L, commitBuilder.getSourceBacklogBytes());
+  }
 }
diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangeStreamErrorTest.java
 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangeStreamErrorTest.java
index f7f2eca60bb..9f09ab18c62 100644
--- 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangeStreamErrorTest.java
+++ 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangeStreamErrorTest.java
@@ -334,9 +334,10 @@ public class SpannerChangeStreamErrorTest implements 
Serializable {
     mockChangeStreamOptions();
     mockTableExists();
     mockGetWatermark(startTimestamp);
+    mockGetPartitionSize(endTimestamp, 0);
     ResultSet getPartitionResultSet = mockGetParentPartition(startTimestamp, 
endTimestamp);
-    mockchangePartitionState(startTimestamp, endTimestamp, "CREATED");
-    mockchangePartitionState(startTimestamp, endTimestamp, "SCHEDULED");
+    mockChangePartitionState(startTimestamp, endTimestamp, "CREATED");
+    mockChangePartitionState(startTimestamp, endTimestamp, "SCHEDULED");
     mockGetPartitionsAfter(
         Timestamp.ofTimeSecondsAndNanos(startTimestamp.getSeconds(), 
startTimestamp.getNanos() - 1),
         getPartitionResultSet);
@@ -497,6 +498,40 @@ public class SpannerChangeStreamErrorTest implements 
Serializable {
         StatementResult.query(getPartitionsAfterStatement, 
getPartitionResultSet));
   }
 
+  /**
+   * Registers a mock result for the metadata-table partition count query ({@code SELECT COUNT(*)
+   * as count FROM my-metadata-table WHERE CreatedAt > @timestamp}). The mocked query returns a
+   * single row whose {@code count} column holds {@code partitionSize}.
+   *
+   * <p>The SQL text and bound parameter must stay in sync with the query issued by the code under
+   * test. Otherwise the mock server will not match the statement.
+   *
+   * @param timestamp value bound to {@code @timestamp}
+   * @param partitionSize value returned in the {@code count} column
+   */
+  private void mockGetPartitionSize(Timestamp timestamp, long partitionSize) {
+    Statement partitionSizeStatement =
+        Statement.newBuilder(
+                "SELECT COUNT(*) as count FROM my-metadata-table WHERE CreatedAt > @timestamp")
+            .bind("timestamp")
+            .to(Timestamp.ofTimeSecondsAndNanos(timestamp.getSeconds(), timestamp.getNanos()))
+            .build();
+    ResultSetMetadata metadata =
+        ResultSetMetadata.newBuilder()
+            .setRowType(
+                StructType.newBuilder()
+                    .addFields(
+                        Field.newBuilder()
+                            .setName("count")
+                            .setType(
+                                com.google.spanner.v1.Type.newBuilder()
+                                    .setCode(TypeCode.INT64)
+                                    .build())
+                            .build())
+                    .build())
+            .build();
+    ResultSet countResultSet =
+        ResultSet.newBuilder()
+            .addRows(
+                ListValue.newBuilder()
+                    .addValues(
+                        // INT64 column values travel over the Spanner wire protocol as decimal
+                        // strings, hence setStringValue rather than a numeric value.
+                        Value.newBuilder().setStringValue(String.valueOf(partitionSize)).build())
+                    .build())
+            .setMetadata(metadata)
+            .build();
+    mockSpannerService.putStatementResult(
+        StatementResult.query(partitionSizeStatement, countResultSet));
+  }
+
   private void mockGetWatermark(Timestamp watermark) {
     final String minWatermark = "min_watermark";
     // The query needs to sync with getUnfinishedMinWatermark() in 
PartitionMetadataDao file.
@@ -591,7 +626,7 @@ public class SpannerChangeStreamErrorTest implements 
Serializable {
         StatementResult.query(tableExistsStatement, tableExistsResultSet));
   }
 
-  private ResultSet mockchangePartitionState(
+  private ResultSet mockChangePartitionState(
       Timestamp startTimestamp, Timestamp after3Seconds, String state) {
     List<String> composedPartitionTokens = new ArrayList<>();
     composedPartitionTokens.add("Parent0");


Reply via email to