This is an automated email from the ASF dual-hosted git repository.

kenn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 630f32ada54 [Drain] Expose drain to dofn processElement and onTimer 
(#37825)
630f32ada54 is described below

commit 630f32ada545fd22812c428e5543e4a89c8075ca
Author: Radosław Stankiewicz <[email protected]>
AuthorDate: Thu Mar 12 15:36:23 2026 +0100

    [Drain] Expose drain to dofn processElement and onTimer (#37825)
---
 .../apache/beam/runners/core/SimpleDoFnRunner.java | 15 ++++++++++
 .../reflect/ByteBuddyDoFnInvokerFactory.java       | 11 +++++++
 .../beam/sdk/transforms/reflect/DoFnInvoker.java   | 15 ++++++++++
 .../beam/sdk/transforms/reflect/DoFnSignature.java | 26 +++++++++++++++++
 .../sdk/transforms/reflect/DoFnSignatures.java     | 10 +++++++
 .../construction/SplittableParDoNaiveBounded.java  |  5 ++++
 .../sdk/transforms/reflect/DoFnSignaturesTest.java | 34 ++++++++++++++++++++--
 .../apache/beam/fn/harness/FnApiDoFnRunner.java    | 10 +++++++
 8 files changed, 124 insertions(+), 2 deletions(-)

diff --git 
a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java
 
b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java
index a255467fc59..74f5a4d0900 100644
--- 
a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java
+++ 
b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java
@@ -555,6 +555,11 @@ public class SimpleDoFnRunner<InputT, OutputT> implements 
DoFnRunner<InputT, Out
       return timestamp();
     }
 
+    @Override
+    public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
+      return elem.causedByDrain();
+    }
+
     @Override
     public String timerId(DoFn<InputT, OutputT> doFn) {
       throw new UnsupportedOperationException(
@@ -831,6 +836,11 @@ public class SimpleDoFnRunner<InputT, OutputT> implements 
DoFnRunner<InputT, Out
       return timestamp();
     }
 
+    @Override
+    public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
+      return causedByDrain;
+    }
+
     @Override
     public String timerId(DoFn<InputT, OutputT> doFn) {
       return timerId;
@@ -1119,6 +1129,11 @@ public class SimpleDoFnRunner<InputT, OutputT> 
implements DoFnRunner<InputT, Out
       return timestamp;
     }
 
+    @Override
+    public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
+      throw new UnsupportedOperationException("CausedByDrain parameters are 
not supported.");
+    }
+
     @Override
     public String timerId(DoFn<InputT, OutputT> doFn) {
       throw new UnsupportedOperationException("Timer parameters are not 
supported.");
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
index 780eb0075db..54d630d92fe 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
@@ -76,6 +76,7 @@ import 
org.apache.beam.sdk.transforms.DoFn.TruncateRestriction;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.OnTimerMethod;
 import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.BundleFinalizerParameter;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.Cases;
+import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.CausedByDrainParameter;
 import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ElementParameter;
 import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.FinishBundleContextParameter;
 import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.OnTimerContextParameter;
@@ -126,6 +127,7 @@ class ByteBuddyDoFnInvokerFactory implements 
DoFnInvokerFactory {
   public static final String ELEMENT_PARAMETER_METHOD = "element";
   public static final String SCHEMA_ELEMENT_PARAMETER_METHOD = "schemaElement";
   public static final String TIMESTAMP_PARAMETER_METHOD = "timestamp";
+  public static final String CAUSED_BY_DRAIN_PARAMETER_METHOD = 
"causedByDrain";
   public static final String BUNDLE_FINALIZER_PARAMETER_METHOD = 
"bundleFinalizer";
   public static final String OUTPUT_ROW_RECEIVER_METHOD = "outputRowReceiver";
   public static final String TIME_DOMAIN_PARAMETER_METHOD = "timeDomain";
@@ -1100,6 +1102,15 @@ class ByteBuddyDoFnInvokerFactory implements 
DoFnInvokerFactory {
                         TIMESTAMP_PARAMETER_METHOD, DoFn.class)));
           }
 
+          @Override
+          public StackManipulation dispatch(CausedByDrainParameter p) {
+            return new StackManipulation.Compound(
+                pushDelegate,
+                MethodInvocation.invoke(
+                    getExtraContextFactoryMethodDescription(
+                        CAUSED_BY_DRAIN_PARAMETER_METHOD, DoFn.class)));
+          }
+
           @Override
           public StackManipulation dispatch(BundleFinalizerParameter p) {
             return 
simpleExtraContextParameter(BUNDLE_FINALIZER_PARAMETER_METHOD);
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
index 0079435700c..a615761292a 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
@@ -41,6 +41,7 @@ import 
org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.Truncate
 import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.values.CausedByDrain;
 import org.apache.beam.sdk.values.Row;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.checkerframework.checker.nullness.qual.Nullable;
@@ -217,6 +218,9 @@ public interface DoFnInvoker<InputT, OutputT> {
     /** Provide a reference to the input element timestamp. */
     Instant timestamp(DoFn<InputT, OutputT> doFn);
 
+    /** Provide a reference to the caused by drain. */
+    CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn);
+
     /** Provide a reference to the time domain for a timer firing. */
     TimeDomain timeDomain(DoFn<InputT, OutputT> doFn);
 
@@ -325,6 +329,12 @@ public interface DoFnInvoker<InputT, OutputT> {
           String.format("Timestamp unsupported in %s", getErrorContext()));
     }
 
+    @Override
+    public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
+      throw new UnsupportedOperationException(
+          String.format("CausedByDrain unsupported in %s", getErrorContext()));
+    }
+
     @Override
     public String timerId(DoFn<InputT, OutputT> doFn) {
       throw new UnsupportedOperationException(
@@ -514,6 +524,11 @@ public interface DoFnInvoker<InputT, OutputT> {
       return delegate.timestamp(doFn);
     }
 
+    @Override
+    public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
+      return delegate.causedByDrain(doFn);
+    }
+
     @Override
     public TimeDomain timeDomain(DoFn<InputT, OutputT> doFn) {
       return delegate.timeDomain(doFn);
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
index 8f254642f08..af0353c902a 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
@@ -342,6 +342,8 @@ public abstract class DoFnSignature {
         return cases.dispatch((TimerIdParameter) this);
       } else if (this instanceof BundleFinalizerParameter) {
         return cases.dispatch((BundleFinalizerParameter) this);
+      } else if (this instanceof CausedByDrainParameter) {
+        return cases.dispatch((CausedByDrainParameter) this);
       } else if (this instanceof KeyParameter) {
         return cases.dispatch((KeyParameter) this);
       } else {
@@ -400,6 +402,8 @@ public abstract class DoFnSignature {
 
       ResultT dispatch(BundleFinalizerParameter p);
 
+      ResultT dispatch(CausedByDrainParameter p);
+
       ResultT dispatch(KeyParameter p);
 
       /** A base class for a visitor with a default method for cases it is not 
interested in. */
@@ -497,6 +501,11 @@ public abstract class DoFnSignature {
           return dispatchDefault(p);
         }
 
+        @Override
+        public ResultT dispatch(CausedByDrainParameter p) {
+          return dispatchDefault(p);
+        }
+
         @Override
         public ResultT dispatch(StateParameter p) {
           return dispatchDefault(p);
@@ -552,6 +561,8 @@ public abstract class DoFnSignature {
         new AutoValue_DoFnSignature_Parameter_PipelineOptionsParameter();
     private static final BundleFinalizerParameter BUNDLE_FINALIZER_PARAMETER =
         new AutoValue_DoFnSignature_Parameter_BundleFinalizerParameter();
+    private static final CausedByDrainParameter CAUSED_BY_DRAIN_PARAMETER =
+        new AutoValue_DoFnSignature_Parameter_CausedByDrainParameter();
     private static final OnWindowExpirationContextParameter 
ON_WINDOW_EXPIRATION_CONTEXT_PARAMETER =
         new 
AutoValue_DoFnSignature_Parameter_OnWindowExpirationContextParameter();
 
@@ -575,6 +586,11 @@ public abstract class DoFnSignature {
       return BUNDLE_FINALIZER_PARAMETER;
     }
 
+    /** Returns a {@link CausedByDrainParameter}. */
+    public static CausedByDrainParameter causedByDrainParameter() {
+      return CAUSED_BY_DRAIN_PARAMETER;
+    }
+
     public static ElementParameter elementParameter(TypeDescriptor<?> 
elementT) {
       return new AutoValue_DoFnSignature_Parameter_ElementParameter(elementT);
     }
@@ -727,6 +743,16 @@ public abstract class DoFnSignature {
       BundleFinalizerParameter() {}
     }
 
+    /**
+     * Descriptor for a {@link Parameter} of type {@link 
org.apache.beam.sdk.values.CausedByDrain}.
+     *
+     * <p>All such descriptors are equal.
+     */
+    @AutoValue
+    public abstract static class CausedByDrainParameter extends Parameter {
+      CausedByDrainParameter() {}
+    }
+
     /**
      * Descriptor for a {@link Parameter} of type {@link DoFn.Element}.
      *
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
index c39edccd58f..3dcf7ff1f9d 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
@@ -91,6 +91,7 @@ import 
org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.util.common.ReflectHelpers;
+import org.apache.beam.sdk.values.CausedByDrain;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.Row;
@@ -139,6 +140,7 @@ public class DoFnSignatures {
               Parameter.StateParameter.class,
               Parameter.SideInputParameter.class,
               Parameter.TimerFamilyParameter.class,
+              Parameter.CausedByDrainParameter.class,
               Parameter.BundleFinalizerParameter.class);
 
   private static final ImmutableList<Class<? extends Parameter>>
@@ -155,6 +157,7 @@ public class DoFnSignatures {
               Parameter.RestrictionTrackerParameter.class,
               Parameter.WatermarkEstimatorParameter.class,
               Parameter.SideInputParameter.class,
+              Parameter.CausedByDrainParameter.class,
               Parameter.BundleFinalizerParameter.class);
 
   private static final ImmutableList<Class<? extends Parameter>> 
ALLOWED_SETUP_PARAMETERS =
@@ -185,6 +188,7 @@ public class DoFnSignatures {
           Parameter.StateParameter.class,
           Parameter.TimerFamilyParameter.class,
           Parameter.TimerIdParameter.class,
+          Parameter.CausedByDrainParameter.class,
           Parameter.KeyParameter.class);
 
   private static final ImmutableList<Class<? extends Parameter>>
@@ -201,6 +205,7 @@ public class DoFnSignatures {
               Parameter.StateParameter.class,
               Parameter.TimerFamilyParameter.class,
               Parameter.TimerIdParameter.class,
+              Parameter.CausedByDrainParameter.class,
               Parameter.KeyParameter.class);
 
   private static final Collection<Class<? extends Parameter>>
@@ -1357,6 +1362,11 @@ public class DoFnSignatures {
       return Parameter.keyT(paramT);
     } else if (rawType.equals(TimeDomain.class)) {
       return Parameter.timeDomainParameter();
+    } else if (CausedByDrain.class.isAssignableFrom(rawType)) {
+      methodErrors.checkArgument(
+          rawType.equals(CausedByDrain.class),
+          "CausedByDrain argument must have type 
org.apache.beam.sdk.values.CausedByDrain.");
+      return Parameter.causedByDrainParameter();
     } else if (hasAnnotation(DoFn.SideInput.class, param.getAnnotations())) {
       String sideInputId = getSideInputId(param.getAnnotations());
       paramErrors.checkArgument(
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/SplittableParDoNaiveBounded.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/SplittableParDoNaiveBounded.java
index a22d3378cfd..6d058b3b6ad 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/SplittableParDoNaiveBounded.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/SplittableParDoNaiveBounded.java
@@ -543,6 +543,11 @@ public class SplittableParDoNaiveBounded {
         return outerContext.timestamp();
       }
 
+      @Override
+      public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
+        return outerContext.causedByDrain();
+      }
+
       @Override
       public String timerId(DoFn<InputT, OutputT> doFn) {
         throw new UnsupportedOperationException();
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
index de4a622e03d..3369e18519b 100644
--- 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
+++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
@@ -56,6 +56,7 @@ import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter;
 import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.BundleFinalizerParameter;
+import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.CausedByDrainParameter;
 import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ElementParameter;
 import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.FinishBundleContextParameter;
 import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.OutputReceiverParameter;
@@ -78,6 +79,7 @@ import 
org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.values.CausedByDrain;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.Row;
 import org.apache.beam.sdk.values.TypeDescriptor;
@@ -130,10 +132,11 @@ public class DoFnSignaturesTest {
                   PipelineOptions options,
                   @SideInput("tag1") String input1,
                   @SideInput("tag2") Integer input2,
-                  BundleFinalizer bundleFinalizer) {}
+                  BundleFinalizer bundleFinalizer,
+                  CausedByDrain causedByDrain) {}
             }.getClass());
 
-    assertThat(sig.processElement().extraParameters().size(), equalTo(9));
+    assertThat(sig.processElement().extraParameters().size(), equalTo(10));
     assertThat(sig.processElement().extraParameters().get(0), 
instanceOf(ElementParameter.class));
     assertThat(sig.processElement().extraParameters().get(1), 
instanceOf(TimestampParameter.class));
     assertThat(sig.processElement().extraParameters().get(2), 
instanceOf(WindowParameter.class));
@@ -146,6 +149,8 @@ public class DoFnSignaturesTest {
     assertThat(sig.processElement().extraParameters().get(7), 
instanceOf(SideInputParameter.class));
     assertThat(
         sig.processElement().extraParameters().get(8), 
instanceOf(BundleFinalizerParameter.class));
+    assertThat(
+        sig.processElement().extraParameters().get(9), 
instanceOf(CausedByDrainParameter.class));
   }
 
   @Test
@@ -585,6 +590,31 @@ public class DoFnSignaturesTest {
         instanceOf(WindowParameter.class));
   }
 
+  @Test
+  public void testCausedByDrainOnTimer() throws Exception {
+    final String timerId = "some-timer-id";
+    final String timerDeclarationId = TimerDeclaration.PREFIX + timerId;
+
+    DoFnSignature sig =
+        DoFnSignatures.getSignature(
+            new DoFn<String, String>() {
+
+              @TimerId(timerId)
+              private final TimerSpec myfield1 = 
TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+              @ProcessElement
+              public void process(ProcessContext c) {}
+
+              @OnTimer(timerId)
+              public void onTimer(CausedByDrain causedByDrain) {}
+            }.getClass());
+
+    
assertThat(sig.onTimerMethods().get(timerDeclarationId).extraParameters().size(),
 equalTo(1));
+    assertThat(
+        sig.onTimerMethods().get(timerDeclarationId).extraParameters().get(0),
+        instanceOf(CausedByDrainParameter.class));
+  }
+
   @Test
   public void testAllParamsOnTimer() throws Exception {
     final String timerId = "some-timer-id";
diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index 3893c0f405e..1dfa336e35f 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -1804,6 +1804,11 @@ public class FnApiDoFnRunner<InputT, RestrictionT, 
PositionT, WatermarkEstimator
       outputTo(consumer, WindowedValues.of(output, timestamp, windows, 
paneInfo));
     }
 
+    @Override
+    public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
+      return currentElement.causedByDrain();
+    }
+
     @Override
     public State state(String stateId, boolean alwaysFetched) {
       StateDeclaration stateDeclaration = 
doFnSignature.stateDeclarations().get(stateId);
@@ -1946,6 +1951,11 @@ public class FnApiDoFnRunner<InputT, RestrictionT, 
PositionT, WatermarkEstimator
     public CausedByDrain causedByDrain() {
       return currentElement.causedByDrain();
     }
+
+    @Override
+    public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
+      return currentElement.causedByDrain();
+    }
   }
 
   /** Provides base arguments for a {@link DoFnInvoker} for a non-window 
observing method. */

Reply via email to