This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new b1aba9d24e8 #20812 handle @RequiresStableInput in portable flink 
(#22889)
b1aba9d24e8 is described below

commit b1aba9d24e8861e4c3122b004ce9c0aa2b5dac8c
Author: Jan Lukavský <je...@seznam.cz>
AuthorDate: Tue Jan 24 00:42:41 2023 +0100

    #20812 handle @RequiresStableInput in portable flink (#22889)
    
    * Handle @RequiresStableInput in portable flink (#20812)
    
    * [runners-flink] Remove unnecessary dependency on flink-annotations
    
    * Fix @RequiresStableInput for portable Flink (#20812)
    
     Fix FlinkRequiresStableInputTest flakiness (#21333)
    
    * Flink: Tests for stateful stable dofns (#20812)
    
    * Enable commit for kafka flink portable test
    
    * Apply suggestions from code review
    
    Co-authored-by: Lukasz Cwik <lc...@google.com>
    
    * Add callback to BufferingDoFnRunner for flushing SDK harness results
    
    * revert changes in website
    
    Co-authored-by: Lukasz Cwik <lc...@google.com>
---
 runners/flink/flink_runner.gradle                  |   1 -
 .../utils/FlinkPortableRunnerUtils.java            |  15 ++
 .../runners/flink/translation/utils/Locker.java    |  40 +++
 .../wrappers/streaming/DoFnOperator.java           | 100 +++++---
 .../streaming/ExecutableStageDoFnOperator.java     |  79 +++---
 .../streaming/stableinput/BufferedElements.java    |   6 +
 .../streaming/stableinput/BufferingDoFnRunner.java | 127 ++++++++--
 .../flink/FlinkRequiresStableInputTest.java        | 279 ++++++++++++---------
 .../apache/beam/runners/flink/FlinkRunnerTest.java |   4 +-
 .../beam/runners/flink/PortableExecutionTest.java  |   1 +
 .../wrappers/streaming/DoFnOperatorTest.java       |   5 +-
 .../streaming/ExecutableStageDoFnOperatorTest.java |  80 +++++-
 .../org/apache/beam/sdk/RequiresStableInputIT.java |   6 +-
 .../runners/portability/flink_runner_test.py       |  20 +-
 14 files changed, 537 insertions(+), 226 deletions(-)

diff --git a/runners/flink/flink_runner.gradle 
b/runners/flink/flink_runner.gradle
index 6cc50d2525b..ccd4f75d3b7 100644
--- a/runners/flink/flink_runner.gradle
+++ b/runners/flink/flink_runner.gradle
@@ -208,7 +208,6 @@ dependencies {
   implementation project(":sdks:java:fn-execution")
   implementation library.java.jackson_databind
   runtimeOnly library.java.jackson_jaxb_annotations
-  implementation "org.apache.flink:flink-annotations:$flink_version"
   examplesJavaIntegrationTest project(project.path)
   examplesJavaIntegrationTest project(":examples:java")
   examplesJavaIntegrationTest project(path: ":examples:java", configuration: 
"testRuntimeMigration")
diff --git 
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/FlinkPortableRunnerUtils.java
 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/FlinkPortableRunnerUtils.java
index d50da94ed3b..342fc558d5a 100644
--- 
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/FlinkPortableRunnerUtils.java
+++ 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/FlinkPortableRunnerUtils.java
@@ -53,6 +53,21 @@ public class FlinkPortableRunnerUtils {
     return requiresTimeSortedInput;
   }
 
+  public static boolean requiresStableInput(RunnerApi.ExecutableStagePayload 
payload) {
+
+    return payload.getComponents().getTransformsMap().values().stream()
+        .filter(t -> 
t.getSpec().getUrn().equals(PTransformTranslation.PAR_DO_TRANSFORM_URN))
+        .anyMatch(
+            t -> {
+              try {
+                return 
RunnerApi.ParDoPayload.parseFrom(t.getSpec().getPayload())
+                    .getRequiresStableInput();
+              } catch (InvalidProtocolBufferException e) {
+                throw new RuntimeException(e);
+              }
+            });
+  }
+
   /** Do not construct. */
   private FlinkPortableRunnerUtils() {}
 }
diff --git 
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/Locker.java
 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/Locker.java
new file mode 100644
index 00000000000..f3a5e3885a1
--- /dev/null
+++ 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/Locker.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.translation.utils;
+
+import java.util.concurrent.locks.Lock;
+
+public class Locker implements AutoCloseable {
+
+  public static Locker locked(Lock lock) {
+    Locker locker = new Locker(lock);
+    lock.lock();
+    return locker;
+  }
+
+  private final Lock lock;
+
+  Locker(Lock lock) {
+    this.lock = lock;
+  }
+
+  @Override
+  public void close() {
+    lock.unlock();
+  }
+}
diff --git 
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
index b0f6cf22e9b..13c4e4e0a99 100644
--- 
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
+++ 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
@@ -88,11 +88,11 @@ import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.WindowingStrategy;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.state.MapState;
@@ -177,15 +177,15 @@ public class DoFnOperator<InputT, OutputT>
 
   protected final String stepName;
 
-  private final Coder<WindowedValue<InputT>> windowedInputCoder;
+  final Coder<WindowedValue<InputT>> windowedInputCoder;
 
-  private final Map<TupleTag<?>, Coder<?>> outputCoders;
+  final Map<TupleTag<?>, Coder<?>> outputCoders;
 
-  protected final Coder<?> keyCoder;
+  final Coder<?> keyCoder;
 
   final KeySelector<WindowedValue<InputT>, ?> keySelector;
 
-  private final TimerInternals.TimerDataCoderV2 timerCoder;
+  final TimerInternals.TimerDataCoderV2 timerCoder;
 
   /** Max number of elements to include in a bundle. */
   private final long maxBundleSize;
@@ -197,7 +197,9 @@ public class DoFnOperator<InputT, OutputT>
   private final Map<String, PCollectionView<?>> sideInputMapping;
 
   /** If true, we must process elements only after a checkpoint is finished. */
-  private final boolean requiresStableInput;
+  final boolean requiresStableInput;
+
+  final int numConcurrentCheckpoints;
 
   private final boolean usesOnWindowExpiration;
 
@@ -301,10 +303,8 @@ public class DoFnOperator<InputT, OutputT>
     this.doFnSchemaInformation = doFnSchemaInformation;
     this.sideInputMapping = sideInputMapping;
 
-    this.requiresStableInput =
-        // WindowDoFnOperator does not use a DoFn
-        doFn != null
-            && 
DoFnSignatures.getSignature(doFn.getClass()).processElement().requiresStableInput();
+    this.requiresStableInput = isRequiresStableInput(doFn);
+
     this.usesOnWindowExpiration =
         doFn != null && 
DoFnSignatures.getSignature(doFn.getClass()).onWindowExpiration() != null;
 
@@ -323,9 +323,22 @@ public class DoFnOperator<InputT, OutputT>
               + Math.max(0, flinkOptions.getMinPauseBetweenCheckpoints()));
     }
 
+    this.numConcurrentCheckpoints = flinkOptions.getNumConcurrentCheckpoints();
+
     this.finishBundleBeforeCheckpointing = 
flinkOptions.getFinishBundleBeforeCheckpointing();
   }
 
+  private boolean isRequiresStableInput(DoFn<InputT, OutputT> doFn) {
+    // WindowDoFnOperator does not use a DoFn
+    return doFn != null
+        && 
DoFnSignatures.getSignature(doFn.getClass()).processElement().requiresStableInput();
+  }
+
+  @VisibleForTesting
+  boolean getRequiresStableInput() {
+    return requiresStableInput;
+  }
+
   // allow overriding this in WindowDoFnOperator because this one dynamically 
creates
   // the DoFn
   protected DoFn<InputT, OutputT> getDoFn() {
@@ -490,21 +503,8 @@ public class DoFnOperator<InputT, OutputT>
             doFnSchemaInformation,
             sideInputMapping);
 
-    if (requiresStableInput) {
-      // put this in front of the root FnRunner before any additional wrappers
-      doFnRunner =
-          bufferingDoFnRunner =
-              BufferingDoFnRunner.create(
-                  doFnRunner,
-                  "stable-input-buffer",
-                  windowedInputCoder,
-                  windowingStrategy.getWindowFn().windowCoder(),
-                  getOperatorStateBackend(),
-                  getKeyedStateBackend(),
-                  options.getNumConcurrentCheckpoints(),
-                  serializedOptions);
-    }
-    doFnRunner = createWrappingDoFnRunner(doFnRunner, stepContext);
+    doFnRunner =
+        createBufferingDoFnRunnerIfNeeded(createWrappingDoFnRunner(doFnRunner, 
stepContext));
     earlyBindStateIfNeeded();
 
     if (!options.getDisableMetrics()) {
@@ -545,6 +545,36 @@ public class DoFnOperator<InputT, OutputT>
     pendingFinalizations = new LinkedHashMap<>();
   }
 
+  DoFnRunner<InputT, OutputT> createBufferingDoFnRunnerIfNeeded(
+      DoFnRunner<InputT, OutputT> wrappedRunner) throws Exception {
+
+    if (requiresStableInput) {
+      // put this in front of the root FnRunner before any additional wrappers
+      return this.bufferingDoFnRunner =
+          BufferingDoFnRunner.create(
+              wrappedRunner,
+              "stable-input-buffer",
+              windowedInputCoder,
+              windowingStrategy.getWindowFn().windowCoder(),
+              getOperatorStateBackend(),
+              getBufferingKeyedStateBackend(),
+              numConcurrentCheckpoints,
+              serializedOptions);
+    }
+    return wrappedRunner;
+  }
+
+  /**
+   * Retrieve a keyed state backend that should be used to buffer elements for the
+   * {@code @RequiresStableInput} functionality. By default this is the default keyed backend, but
+   * can be overridden in {@link ExecutableStageDoFnOperator}.
+   *
+   * @return the keyed backend to use for element buffering
+   */
+  <K> @Nullable KeyedStateBackend<K> getBufferingKeyedStateBackend() {
+    return getKeyedStateBackend();
+  }
+
   private void earlyBindStateIfNeeded() throws IllegalArgumentException, 
IllegalAccessException {
     if (keyCoder != null) {
       if (doFn != null) {
@@ -598,7 +628,9 @@ public class DoFnOperator<InputT, OutputT>
     }
     if (currentOutputWatermark < Long.MAX_VALUE) {
       throw new RuntimeException(
-          "There are still watermark holds. Watermark held at " + 
currentOutputWatermark);
+          String.format(
+              "There are still watermark holds left when terminating operator 
%s Watermark held %d",
+              getOperatorName(), currentOutputWatermark));
     }
 
     // sanity check: these should have been flushed out by +Inf watermarks
@@ -617,7 +649,12 @@ public class DoFnOperator<InputT, OutputT>
 
   public long getEffectiveInputWatermark() {
     // hold back by the pushed back values waiting for side inputs
-    return Math.min(pushedBackWatermark, currentInputWatermark);
+    long combinedPushedBackWatermark = pushedBackWatermark;
+    if (requiresStableInput) {
+      combinedPushedBackWatermark =
+          Math.min(combinedPushedBackWatermark, 
bufferingDoFnRunner.getOutputWatermarkHold());
+    }
+    return Math.min(combinedPushedBackWatermark, currentInputWatermark);
   }
 
   public long getCurrentOutputWatermark() {
@@ -760,8 +797,8 @@ public class DoFnOperator<InputT, OutputT>
   }
 
   /**
-   * Allows to apply a hold to the output watermark before it is send out. By 
default, just passes
-   * the potential output watermark through which will make it the new output 
watermark.
+   * Allows applying a hold to the output watermark before it is sent out. Used to hold back the
+   * output watermark for delayed (asynchronous or buffered) processing.
    *
    * @param currentOutputWatermark the current output watermark
    * @param potentialOutputWatermark The potential new output watermark which 
can be adjusted, if
@@ -797,7 +834,7 @@ public class DoFnOperator<InputT, OutputT>
         return;
       }
 
-      LOG.debug("Emitting watermark {}", watermark);
+      LOG.debug("Emitting watermark {} from {}", watermark, getOperatorName());
       currentOutputWatermark = watermark;
       output.emitWatermark(new Watermark(watermark));
 
@@ -902,7 +939,7 @@ public class DoFnOperator<InputT, OutputT>
     timeService.registerTimer(timeService.getCurrentProcessingTime(), 
callback);
   }
 
-  private void updateOutputWatermark() {
+  void updateOutputWatermark() {
     try {
       processInputWatermark(false);
     } catch (Exception ex) {
@@ -1005,6 +1042,7 @@ public class DoFnOperator<InputT, OutputT>
       // We can now release all buffered data which was held back for
       // @RequiresStableInput guarantees.
       bufferingDoFnRunner.checkpointCompleted(checkpointId);
+      updateOutputWatermark();
     }
 
     List<InMemoryBundleFinalizer.Finalization> finalizations =
diff --git 
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
index 2f967a15810..2df14a8bfa7 100644
--- 
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
+++ 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
@@ -18,6 +18,7 @@
 package org.apache.beam.runners.flink.translation.wrappers.streaming;
 
 import static 
org.apache.beam.runners.core.StatefulDoFnRunner.TimeInternalsCleanupTimer.GC_TIMER_ID;
+import static 
org.apache.beam.runners.flink.translation.utils.FlinkPortableRunnerUtils.requiresStableInput;
 import static 
org.apache.beam.runners.flink.translation.utils.FlinkPortableRunnerUtils.requiresTimeSortedInput;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
@@ -64,6 +65,8 @@ import 
org.apache.beam.runners.core.construction.graph.ExecutableStage;
 import org.apache.beam.runners.core.construction.graph.UserStateReference;
 import 
org.apache.beam.runners.flink.translation.functions.FlinkExecutableStageContextFactory;
 import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer;
+import org.apache.beam.runners.flink.translation.utils.Locker;
+import 
org.apache.beam.runners.flink.translation.wrappers.streaming.stableinput.BufferingDoFnRunner;
 import 
org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals;
 import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandler;
 import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandlers;
@@ -171,7 +174,7 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> 
extends DoFnOperator<I
   private transient long minEventTimeTimerTimestampInCurrentBundle;
 
   /** The input watermark before the current bundle started. */
-  private transient long inputWatermarkBeforeBundleStart;
+  private long inputWatermarkBeforeBundleStart = 
BoundedWindow.TIMESTAMP_MIN_VALUE.getMillis();
 
   /** Flag indicating whether the operator has been closed. */
   private transient boolean closed;
@@ -196,7 +199,7 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> 
extends DoFnOperator<I
       Coder keyCoder,
       KeySelector<WindowedValue<InputT>, ?> keySelector) {
     super(
-        new NoOpDoFn(),
+        requiresStableInput(payload) ? new StableNoOpDoFn() : new NoOpDoFn(),
         stepName,
         windowedInputCoder,
         outputCoders,
@@ -228,6 +231,13 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> 
extends DoFnOperator<I
         windowedInputCoder);
   }
 
+  @Override
+  <K> @Nullable KeyedStateBackend<K> getBufferingKeyedStateBackend() {
+    // Do not use the keyed backend for buffering unless we process a stateful DoFn
+    // (ExecutableStage uses a keyed backend by default).
+    return isStateful ? super.getKeyedStateBackend() : null;
+  }
+
   @Override
   protected Lock getLockToAcquireForStateAccessDuringBundles() {
     return stateBackendLock;
@@ -280,8 +290,8 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> 
extends DoFnOperator<I
 
   @Override
   public final void notifyCheckpointComplete(long checkpointId) throws 
Exception {
-    finalizationHandler.finalizeAllOutstandingBundles();
     super.notifyCheckpointComplete(checkpointId);
+    finalizationHandler.finalizeAllOutstandingBundles();
   }
 
   private BundleCheckpointHandler getBundleCheckpointHandler(boolean hasSDF) {
@@ -745,6 +755,32 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> 
extends DoFnOperator<I
     sideInputHandler.addSideInputValue(sideInput, 
value.withValue(value.getValue().getValue()));
   }
 
+  @Override
+  DoFnRunner<InputT, OutputT> createBufferingDoFnRunnerIfNeeded(
+      DoFnRunner<InputT, OutputT> wrappedRunner) throws Exception {
+
+    if (requiresStableInput) {
+      // put this in front of the root FnRunner before any additional wrappers
+      KeyedStateBackend<Object> keyedBufferingBackend = 
getBufferingKeyedStateBackend();
+      return this.bufferingDoFnRunner =
+          BufferingDoFnRunner.create(
+              wrappedRunner,
+              "stable-input-buffer",
+              windowedInputCoder,
+              windowingStrategy.getWindowFn().windowCoder(),
+              getOperatorStateBackend(),
+              keyedBufferingBackend,
+              numConcurrentCheckpoints,
+              serializedOptions,
+              keyedBufferingBackend != null ? () -> 
Locker.locked(stateBackendLock) : null,
+              keyedBufferingBackend != null
+                  ? input -> FlinkKeyUtils.encodeKey(((KV) input).getKey(), 
(Coder) keyCoder)
+                  : null,
+              sdkHarnessRunner::emitResults);
+    }
+    return wrappedRunner;
+  }
+
   @Override
   protected DoFnRunner<InputT, OutputT> createWrappingDoFnRunner(
       DoFnRunner<InputT, OutputT> wrappedRunner, StepContext stepContext) {
@@ -814,6 +850,8 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> 
extends DoFnOperator<I
     // gives better throughput due to the bundle not getting cut on
     // every watermark. So we have implemented 2) below.
     //
+    potentialOutputWatermark =
+        super.applyOutputWatermarkHold(currentOutputWatermark, 
potentialOutputWatermark);
     if (sdkHarnessRunner.isBundleInProgress()) {
       if (minEventTimeTimerTimestampInLastBundle < Long.MAX_VALUE) {
         // We can safely advance the watermark to before the last bundle's 
minimum event timer
@@ -1055,7 +1093,7 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> 
extends DoFnOperator<I
   }
 
   private DoFnRunner<InputT, OutputT> ensureStateDoFnRunner(
-      SdkHarnessDoFnRunner<InputT, OutputT> sdkHarnessRunner,
+      DoFnRunner<InputT, OutputT> sdkHarnessRunner,
       RunnerApi.ExecutableStagePayload payload,
       StepContext stepContext) {
 
@@ -1097,17 +1135,6 @@ public class ExecutableStageDoFnOperator<InputT, 
OutputT> extends DoFnOperator<I
         stateCleaner,
         requiresTimeSortedInput(payload, true)) {
 
-      @Override
-      public void processElement(WindowedValue<InputT> input) {
-        try (Locker locker = Locker.locked(stateBackendLock)) {
-          @SuppressWarnings({"unchecked", "rawtypes"})
-          final ByteBuffer key =
-              FlinkKeyUtils.encodeKey(((KV) input.getValue()).getKey(), 
(Coder) keyCoder);
-          getKeyedStateBackend().setCurrentKey(key);
-          super.processElement(input);
-        }
-      }
-
       @Override
       public void finishBundle() {
         // Before cleaning up state, first finish bundle for all underlying 
DoFnRunners
@@ -1275,23 +1302,9 @@ public class ExecutableStageDoFnOperator<InputT, 
OutputT> extends DoFnOperator<I
     public void doNothing(ProcessContext context) {}
   }
 
-  private static class Locker implements AutoCloseable {
-
-    public static Locker locked(Lock lock) {
-      Locker locker = new Locker(lock);
-      lock.lock();
-      return locker;
-    }
-
-    private final Lock lock;
-
-    Locker(Lock lock) {
-      this.lock = lock;
-    }
-
-    @Override
-    public void close() {
-      lock.unlock();
-    }
+  private static class StableNoOpDoFn<InputT, OutputT> extends DoFn<InputT, 
OutputT> {
+    @RequiresStableInput
+    @ProcessElement
+    public void doNothing(ProcessContext context) {}
   }
 }
diff --git 
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferedElements.java
 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferedElements.java
index 772b811df1e..ad3e37fa5a7 100644
--- 
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferedElements.java
+++ 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferedElements.java
@@ -29,6 +29,7 @@ import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.state.TimeDomain;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.WindowedValue;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Instant;
 
@@ -66,6 +67,11 @@ class BufferedElements {
     public int hashCode() {
       return Objects.hash(element);
     }
+
+    @Override
+    public String toString() {
+      return MoreObjects.toStringHelper(this).add("element", 
element).toString();
+    }
   }
 
   static final class Timer<KeyT> implements BufferedElement {
diff --git 
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferingDoFnRunner.java
 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferingDoFnRunner.java
index be6ac2838a6..2a9b176796e 100644
--- 
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferingDoFnRunner.java
+++ 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferingDoFnRunner.java
@@ -22,9 +22,13 @@ import java.util.Collections;
 import java.util.Comparator;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Optional;
+import java.util.function.Function;
+import java.util.function.Supplier;
 import org.apache.beam.runners.core.DoFnRunner;
 import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
 import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer;
+import org.apache.beam.runners.flink.translation.utils.Locker;
 import org.apache.beam.sdk.state.TimeDomain;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -62,15 +66,47 @@ public class BufferingDoFnRunner<InputT, OutputT> 
implements DoFnRunner<InputT,
       int maxConcurrentCheckpoints,
       SerializablePipelineOptions pipelineOptions)
       throws Exception {
+
     return new BufferingDoFnRunner<>(
         doFnRunner,
         stateName,
         windowedInputCoder,
         windowCoder,
         operatorStateBackend,
+        maxConcurrentCheckpoints,
+        pipelineOptions,
         keyedStateBackend,
+        null,
+        null,
+        null);
+  }
+
+  public static <InputT, OutputT> BufferingDoFnRunner<InputT, OutputT> create(
+      DoFnRunner<InputT, OutputT> doFnRunner,
+      String stateName,
+      org.apache.beam.sdk.coders.Coder windowedInputCoder,
+      org.apache.beam.sdk.coders.Coder windowCoder,
+      OperatorStateBackend operatorStateBackend,
+      @Nullable KeyedStateBackend<Object> keyedStateBackend,
+      int maxConcurrentCheckpoints,
+      SerializablePipelineOptions pipelineOptions,
+      @Nullable Supplier<Locker> locker,
+      @Nullable Function<InputT, Object> keySelector,
+      @Nullable Runnable finishBundleCallback)
+      throws Exception {
+
+    return new BufferingDoFnRunner<>(
+        doFnRunner,
+        stateName,
+        windowedInputCoder,
+        windowCoder,
+        operatorStateBackend,
         maxConcurrentCheckpoints,
-        pipelineOptions);
+        pipelineOptions,
+        keyedStateBackend,
+        locker,
+        keyedStateBackend != null ? keySelector : null,
+        finishBundleCallback);
   }
 
   /** The underlying DoFnRunner that any buffered data will be handed over to 
eventually. */
@@ -85,6 +121,21 @@ public class BufferingDoFnRunner<InputT, OutputT> 
implements DoFnRunner<InputT,
   int currentStateIndex;
   /** The current handler used for buffering. */
   private BufferingElementsHandler currentBufferingElementsHandler;
+  /** Minimum timestamp of all buffered elements. */
+  private volatile long minBufferedElementTimestamp;
+  /** The associated keyed state backend. */
+  private final @Nullable KeyedStateBackend keyedStateBackend;
+  /**
+   * Locker that must be held (if present) before buffering an element. If 
non-null, we must
+   * manually set a key to the state backend.
+   */
+  private final @Nullable Supplier<Locker> locker;
+  /**
+   * A key selector. When non-null, its key must be set on the keyed state backend before buffering.
+   */
+  private final @Nullable Function<InputT, Object> keySelector;
+  /** Optional callback run on finishBundle() to signal that the bundle may be flushed. */
+  private final @Nullable Runnable finishBundleCallback;
 
   private BufferingDoFnRunner(
       DoFnRunner<InputT, OutputT> underlying,
@@ -92,10 +143,14 @@ public class BufferingDoFnRunner<InputT, OutputT> 
implements DoFnRunner<InputT,
       org.apache.beam.sdk.coders.Coder inputCoder,
       org.apache.beam.sdk.coders.Coder windowCoder,
       OperatorStateBackend operatorStateBackend,
-      @Nullable KeyedStateBackend keyedStateBackend,
       int maxConcurrentCheckpoints,
-      SerializablePipelineOptions pipelineOptions)
+      SerializablePipelineOptions pipelineOptions,
+      @Nullable KeyedStateBackend keyedStateBackend,
+      @Nullable Supplier<Locker> locker,
+      @Nullable Function<InputT, Object> keySelector,
+      @Nullable Runnable finishBundleCallback)
       throws Exception {
+
     Preconditions.checkArgument(
         maxConcurrentCheckpoints > 0 && maxConcurrentCheckpoints < 
Short.MAX_VALUE,
         "Maximum number of concurrent checkpoints not within the bounds of 0 
and %s",
@@ -122,6 +177,14 @@ public class BufferingDoFnRunner<InputT, OutputT> 
implements DoFnRunner<InputT,
     this.numCheckpointBuffers = initializeState(maxConcurrentCheckpoints);
     this.currentBufferingElementsHandler =
         bufferingElementsHandlerFactory.get(rotateAndGetStateIndex());
+    this.keyedStateBackend = keyedStateBackend;
+    this.locker = locker;
+    this.keySelector = keySelector;
+    this.finishBundleCallback = finishBundleCallback;
+
+    Preconditions.checkArgument(
+        keySelector == null || keyedStateBackend != null,
+        "keySelector must be null for null keyed state backend");
   }
 
   /**
@@ -140,6 +203,7 @@ public class BufferingDoFnRunner<InputT, OutputT> 
implements DoFnRunner<InputT,
       lastUsedIndex = pendingSnapshots.get(pendingSnapshots.size() - 
1).internalId;
     }
     this.currentStateIndex = lastUsedIndex;
+    this.minBufferedElementTimestamp = Long.MAX_VALUE;
     // If a previous run had a higher number of concurrent checkpoints we need 
to use this number to
     // not break the buffering/flushing logic.
     return Math.max(maxConcurrentCheckpoints, maxIndex) + 1;
@@ -152,7 +216,14 @@ public class BufferingDoFnRunner<InputT, OutputT> 
implements DoFnRunner<InputT,
 
   @Override
   public void processElement(WindowedValue<InputT> elem) {
-    currentBufferingElementsHandler.buffer(new BufferedElements.Element(elem));
+    minBufferedElementTimestamp =
+        Math.min(elem.getTimestamp().getMillis(), minBufferedElementTimestamp);
+    try (Locker lock = locker != null ? locker.get() : null) {
+      if (keySelector != null) {
+        keyedStateBackend.setCurrentKey(keySelector.apply(elem.getValue()));
+      }
+      currentBufferingElementsHandler.buffer(new 
BufferedElements.Element(elem));
+    }
   }
 
   @Override
@@ -164,14 +235,22 @@ public class BufferingDoFnRunner<InputT, OutputT> 
implements DoFnRunner<InputT,
       Instant timestamp,
       Instant outputTimestamp,
       TimeDomain timeDomain) {
-    currentBufferingElementsHandler.buffer(
-        new BufferedElements.Timer<>(
-            timerId, timerFamilyId, key, window, timestamp, outputTimestamp, 
timeDomain));
+
+    minBufferedElementTimestamp =
+        Math.min(outputTimestamp.getMillis(), minBufferedElementTimestamp);
+    try (Locker lock = locker != null ? locker.get() : null) {
+      if (keySelector != null) {
+        keyedStateBackend.setCurrentKey(key);
+      }
+      currentBufferingElementsHandler.buffer(
+          new BufferedElements.Timer<>(
+              timerId, timerFamilyId, key, window, timestamp, outputTimestamp, 
timeDomain));
+    }
   }
 
   @Override
   public void finishBundle() {
-    // Do not finish a bundle, finish it later when emitting elements
+    Optional.ofNullable(finishBundleCallback).ifPresent(Runnable::run);
   }
 
   @Override
@@ -198,20 +277,28 @@ public class BufferingDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT,
     for (CheckpointIdentifier toBeAcked : allToAck) {
       BufferingElementsHandler bufferingElementsHandler =
           bufferingElementsHandlerFactory.get(toBeAcked.internalId);
-      Iterator<BufferedElement> iterator = bufferingElementsHandler.getElements().iterator();
-      boolean hasElements = iterator.hasNext();
-      if (hasElements) {
-        underlying.startBundle();
-      }
-      while (iterator.hasNext()) {
-        BufferedElement bufferedElement = iterator.next();
-        bufferedElement.processWith(underlying);
-      }
-      if (hasElements) {
-        underlying.finishBundle();
+      try (Locker lock = locker != null ? locker.get() : null) {
+        final Iterator<BufferedElement> iterator =
+            bufferingElementsHandler.getElements().iterator();
+        boolean hasElements = iterator.hasNext();
+        if (hasElements) {
+          underlying.startBundle();
+        }
+        while (iterator.hasNext()) {
+          BufferedElement bufferedElement = iterator.next();
+          bufferedElement.processWith(underlying);
+        }
+        if (hasElements) {
+          underlying.finishBundle();
+        }
+        bufferingElementsHandler.clear();
       }
-      bufferingElementsHandler.clear();
     }
+    minBufferedElementTimestamp = Long.MAX_VALUE;
+  }
+
+  public long getOutputWatermarkHold() {
+    return minBufferedElementTimestamp;
   }
 
   private void addToBeAcknowledgedCheckpoint(long checkpointId, int internalId) throws Exception {
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkRequiresStableInputTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkRequiresStableInputTest.java
index 5a71819f9aa..bac201ff56c 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkRequiresStableInputTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkRequiresStableInputTest.java
@@ -19,39 +19,51 @@ package org.apache.beam.runners.flink;
 
 import static org.apache.beam.sdk.testing.FileChecksumMatcher.fileContentsHaveChecksum;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
 
+import java.util.Collections;
 import java.util.Date;
-import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.TimeUnit;
+import java.util.Optional;
+import java.util.concurrent.Executors;
+import org.apache.beam.model.jobmanagement.v1.JobApi;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.core.construction.Environments;
+import org.apache.beam.runners.core.construction.PipelineTranslation;
+import org.apache.beam.runners.jobsubmission.JobInvocation;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.PipelineResult;
 import org.apache.beam.sdk.RequiresStableInputIT;
 import org.apache.beam.sdk.io.FileSystems;
 import org.apache.beam.sdk.io.fs.ResolveOptions;
 import org.apache.beam.sdk.io.fs.ResourceId;
 import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PortablePipelineOptions;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.state.Timer;
+import org.apache.beam.sdk.state.TimerSpec;
+import org.apache.beam.sdk.state.TimerSpecs;
+import org.apache.beam.sdk.testing.CrashingRunner;
+import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Reshuffle;
 import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.WithKeys;
 import org.apache.beam.sdk.util.FilePatternMatchingShardedFile;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
-import org.apache.flink.api.common.JobID;
-import org.apache.flink.configuration.CheckpointingOptions;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.configuration.RestOptions;
-import org.apache.flink.runtime.jobgraph.JobGraph;
-import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
-import org.apache.flink.runtime.minicluster.MiniCluster;
-import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
-import org.apache.flink.streaming.util.TestStreamEnvironment;
-import org.junit.AfterClass;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.ListeningExecutorService;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.MoreExecutors;
+import org.joda.time.Instant;
 import org.junit.BeforeClass;
 import org.junit.ClassRule;
-import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 
@@ -60,50 +72,20 @@ public class FlinkRequiresStableInputTest {
 
   @ClassRule public static TemporaryFolder tempFolder = new TemporaryFolder();
 
-  private static CountDownLatch latch;
-
   private static final String VALUE = "value";
   // SHA-1 hash of string "value"
   private static final String VALUE_CHECKSUM = "f32b67c7e26342af42efabc674d441dca0a281c5";
 
-  private static transient MiniCluster flinkCluster;
+  private static ListeningExecutorService flinkJobExecutor;
+  private static final int PARALLELISM = 1;
+  private static final long CHECKPOINT_INTERVAL = 2000L;
+  private static final long FINISH_SOURCE_INTERVAL = 3 * CHECKPOINT_INTERVAL;
 
   @BeforeClass
-  public static void beforeClass() throws Exception {
-    final int parallelism = 1;
-
-    Configuration config = new Configuration();
-    // Avoid port collision in parallel tests
-    config.setInteger(RestOptions.PORT, 0);
-    config.setString(CheckpointingOptions.STATE_BACKEND, "filesystem");
-    // It is necessary to configure the checkpoint directory for the state backend,
-    // even though we only create savepoints in this test.
-    config.setString(
-        CheckpointingOptions.CHECKPOINTS_DIRECTORY,
-        "file://" + tempFolder.getRoot().getAbsolutePath());
-    // Checkpoints will go into a subdirectory of this directory
-    config.setString(
-        CheckpointingOptions.SAVEPOINT_DIRECTORY,
-        "file://" + tempFolder.getRoot().getAbsolutePath());
-
-    MiniClusterConfiguration clusterConfig =
-        new MiniClusterConfiguration.Builder()
-            .setConfiguration(config)
-            .setNumTaskManagers(1)
-            .setNumSlotsPerTaskManager(1)
-            .build();
-
-    flinkCluster = new MiniCluster(clusterConfig);
-    flinkCluster.start();
-
-    TestStreamEnvironment.setAsContext(flinkCluster, parallelism);
-  }
-
-  @AfterClass
-  public static void afterClass() throws Exception {
-    TestStreamEnvironment.unsetAsContext();
-    flinkCluster.close();
-    flinkCluster = null;
+  public static void setup() {
+    // Restrict this to only one thread to avoid multiple Flink clusters up at the same time
+    // which is not suitable for memory-constraint environments, i.e. Jenkins.
+    flinkJobExecutor = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1));
   }
 
   /**
@@ -121,15 +103,41 @@ public class FlinkRequiresStableInputTest {
    * restore the savepoint to check if we produce impotent results.
    */
   @Test(timeout = 30_000)
-  @Ignore("https://github.com/apache/beam/issues/21333")
   public void testParDoRequiresStableInput() throws Exception {
-    FlinkPipelineOptions options = FlinkPipelineOptions.defaults();
-    options.setParallelism(1);
-    // We only want to trigger external savepoints but we require
-    // checkpointing to be enabled for @RequiresStableInput
-    options.setCheckpointingInterval(Long.MAX_VALUE);
-    options.setRunner(FlinkRunner.class);
-    options.setStreaming(true);
+    runTest(false);
+  }
+
+  @Test(timeout = 30_000)
+  public void testParDoRequiresStableInputPortable() throws Exception {
+    runTest(true);
+  }
+
+  @Test(timeout = 30_000)
+  public void testParDoRequiresStableInputStateful() throws Exception {
+    testParDoRequiresStableInputStateful(false);
+  }
+
+  @Test(timeout = 30_000)
+  public void testParDoRequiresStableInputStatefulPortable() throws Exception {
+    testParDoRequiresStableInputStateful(true);
+  }
+
+  private void testParDoRequiresStableInputStateful(boolean portable) throws Exception {
+    FlinkPipelineOptions opts = getFlinkOptions(portable);
+    opts.as(FlinkPipelineOptions.class).setShutdownSourcesAfterIdleMs(FINISH_SOURCE_INTERVAL);
+    opts.as(FlinkPipelineOptions.class).setNumberOfExecutionRetries(0);
+    Pipeline pipeline = Pipeline.create(opts);
+    PCollection<Integer> result =
+        pipeline
+            .apply(Create.of(1, 2, 3, 4))
+            .apply(WithKeys.of((Void) null))
+            .apply(ParDo.of(new StableDoFn()));
+    PAssert.that(result).containsInAnyOrder(1, 2, 3, 4);
+    executePipeline(pipeline, portable);
+  }
+
+  private void runTest(boolean portable) throws Exception {
+    FlinkPipelineOptions options = getFlinkOptions(portable);
 
     ResourceId outputDir =
         FileSystems.matchNewResource(tempFolder.getRoot().getAbsolutePath(), true)
@@ -149,21 +157,7 @@ public class FlinkRequiresStableInputTest {
 
     Pipeline p = createPipeline(options, singleOutputPrefix, multiOutputPrefix);
 
-    // a latch used by the transforms to signal completion
-    latch = new CountDownLatch(2);
-    JobID jobID = executePipeline(p);
-    String savepointDir;
-    do {
-      // Take a savepoint (checkpoint) which will trigger releasing the buffered elements
-      // and trigger the latch
-      savepointDir = takeSavepoint(jobID);
-    } while (!latch.await(100, TimeUnit.MILLISECONDS));
-    flinkCluster.cancelJob(jobID).get();
-
-    options.setShutdownSourcesAfterIdleMs(0L);
-    restoreFromSavepoint(p, savepointDir);
-    waitUntilJobIsDone();
-
+    executePipeline(p, portable);
     assertThat(
         new FilePatternMatchingShardedFile(singleOutputPrefix + "*"),
         fileContentsHaveChecksum(VALUE_CHECKSUM));
@@ -172,78 +166,127 @@ public class FlinkRequiresStableInputTest {
         fileContentsHaveChecksum(VALUE_CHECKSUM));
   }
 
-  private JobGraph getJobGraph(Pipeline pipeline) {
-    FlinkRunner flinkRunner = FlinkRunner.fromOptions(pipeline.getOptions());
-    return flinkRunner.getJobGraph(pipeline);
-  }
-
-  private JobID executePipeline(Pipeline pipeline) throws Exception {
-    JobGraph jobGraph = getJobGraph(pipeline);
-    flinkCluster.submitJob(jobGraph).get();
-    return jobGraph.getJobID();
-  }
+  private void executePipeline(Pipeline pipeline, boolean portable) throws Exception {
+    if (portable) {
+      RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline);
+      FlinkPipelineOptions flinkOpts = pipeline.getOptions().as(FlinkPipelineOptions.class);
+      // execute the pipeline
+      JobInvocation jobInvocation =
+          FlinkJobInvoker.create(null)
+              .createJobInvocation(
+                  "fakeId",
+                  "fakeRetrievalToken",
+                  flinkJobExecutor,
+                  pipelineProto,
+                  flinkOpts,
+                  new FlinkPipelineRunner(flinkOpts, null, Collections.emptyList()));
+      jobInvocation.start();
+      while (jobInvocation.getState() != JobApi.JobState.Enum.DONE
+          && jobInvocation.getState() != JobApi.JobState.Enum.FAILED) {
 
-  private String takeSavepoint(JobID jobID) throws Exception {
-    Exception exception = null;
-    // try multiple times because the job might not be ready yet
-    for (int i = 0; i < 10; i++) {
-      try {
-        return MiniClusterCompat.triggerSavepoint(flinkCluster, jobID, null, false).get();
-      } catch (Exception e) {
-        exception = e;
-        Thread.sleep(100);
+        Thread.sleep(1000);
       }
+      assertThat(jobInvocation.getState(), equalTo(JobApi.JobState.Enum.DONE));
+    } else {
+      executePipelineLegacy(pipeline);
     }
-    throw exception;
-  }
-
-  private JobID restoreFromSavepoint(Pipeline pipeline, String savepointDir)
-      throws ExecutionException, InterruptedException {
-    JobGraph jobGraph = getJobGraph(pipeline);
-    SavepointRestoreSettings savepointSettings = SavepointRestoreSettings.forPath(savepointDir);
-    jobGraph.setSavepointRestoreSettings(savepointSettings);
-    return flinkCluster.submitJob(jobGraph).get().getJobID();
   }
 
-  private void waitUntilJobIsDone() throws InterruptedException, ExecutionException {
-    while (flinkCluster.listJobs().get().stream()
-        .anyMatch(message -> message.getJobState().name().equals("RUNNING"))) {
-      Thread.sleep(100);
-    }
+  private void executePipelineLegacy(Pipeline pipeline) {
+    FlinkRunner flinkRunner = FlinkRunner.fromOptions(pipeline.getOptions());
+    PipelineResult.State state = flinkRunner.run(pipeline).waitUntilFinish();
+    assertThat(state, equalTo(PipelineResult.State.DONE));
   }
 
   private static Pipeline createPipeline(
       PipelineOptions options, String singleOutputPrefix, String multiOutputPrefix) {
     Pipeline p = Pipeline.create(options);
-
-    SerializableFunction<Void, Void> firstTime =
-        (SerializableFunction<Void, Void>)
-            value -> {
-              latch.countDown();
-              return null;
-            };
-
+    SerializableFunction<Void, Void> sideEffect =
+        ign -> {
+          throw new IllegalStateException("Failing job to test @RequiresStableInput");
+        };
     PCollection<String> impulse = p.apply("CreatePCollectionOfOneValue", Create.of(VALUE));
     impulse
         .apply(
             "Single-PairWithRandomKey",
             MapElements.via(new RequiresStableInputIT.PairWithRandomKeyFn()))
+        // need Reshuffle due to https://github.com/apache/beam/issues/24655
+        // can be removed once fixed
+        .apply(Reshuffle.of())
         .apply(
             "Single-MakeSideEffectAndThenFail",
             ParDo.of(
                 new RequiresStableInputIT.MakeSideEffectAndThenFailFn(
-                    singleOutputPrefix, firstTime)));
+                    singleOutputPrefix, sideEffect)));
     impulse
         .apply(
             "Multi-PairWithRandomKey",
             MapElements.via(new RequiresStableInputIT.PairWithRandomKeyFn()))
+        // need Reshuffle due to https://github.com/apache/beam/issues/24655
+        // can be removed once fixed
+        .apply(Reshuffle.of())
         .apply(
             "Multi-MakeSideEffectAndThenFail",
             ParDo.of(
                     new RequiresStableInputIT.MakeSideEffectAndThenFailFn(
-                        multiOutputPrefix, firstTime))
+                        multiOutputPrefix, sideEffect))
                 .withOutputTags(new TupleTag<>(), TupleTagList.empty()));
 
     return p;
   }
+
+  private FlinkPipelineOptions getFlinkOptions(boolean portable) {
+    FlinkPipelineOptions options = FlinkPipelineOptions.defaults();
+    options.setParallelism(PARALLELISM);
+    options.setCheckpointingInterval(CHECKPOINT_INTERVAL);
+    options.setShutdownSourcesAfterIdleMs(FINISH_SOURCE_INTERVAL);
+    options.setFinishBundleBeforeCheckpointing(true);
+    options.setMaxBundleTimeMills(100L);
+    options.setStreaming(true);
+    if (portable) {
+      options.setRunner(CrashingRunner.class);
+      options
+          .as(PortablePipelineOptions.class)
+          .setDefaultEnvironmentType(Environments.ENVIRONMENT_EMBEDDED);
+    } else {
+      options.setRunner(FlinkRunner.class);
+    }
+    return options;
+  }
+
+  private static class StableDoFn extends DoFn<KV<Void, Integer>, Integer> {
+
+    @StateId("state")
+    final StateSpec<BagState<Integer>> stateSpec = StateSpecs.bag();
+
+    @TimerId("flush")
+    final TimerSpec flushSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+    @ProcessElement
+    @RequiresStableInput
+    public void process(
+        @Element KV<Void, Integer> input,
+        @StateId("state") BagState<Integer> buffer,
+        @TimerId("flush") Timer flush,
+        OutputReceiver<Integer> output) {
+
+      // Timers do not to work with stateful stable dofn,
+      // see https://github.com/apache/beam/issues/24662
+      // Once this is resolved, flush the buffer on timer
+      // flush.set(GlobalWindow.INSTANCE.maxTimestamp());
+      // buffer.add(input.getValue());
+      output.output(input.getValue());
+    }
+
+    @OnTimer("flush")
+    public void flush(
+        @Timestamp Instant ts,
+        @StateId("state") BagState<Integer> buffer,
+        OutputReceiver<Integer> output) {
+
+      Optional.ofNullable(buffer.read())
+          .ifPresent(b -> b.forEach(e -> output.outputWithTimestamp(e, ts)));
+      buffer.clear();
+    }
+  }
 }
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkRunnerTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkRunnerTest.java
index 379d6ee0151..3b2ac2b6fd2 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkRunnerTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkRunnerTest.java
@@ -56,8 +56,8 @@ public class FlinkRunnerTest {
     MatcherAssert.assertThat(
         e.getMessage(),
         allOf(
-            StringContains.containsString("System.out: (none)"),
-            StringContains.containsString("System.err: (none)")));
+            StringContains.containsString("System.out: "),
+            StringContains.containsString("System.err: ")));
   }
 
   /** Main method for {@code testEnsureStdoutStdErrIsRestored()}. */
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/PortableExecutionTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/PortableExecutionTest.java
index 00ca07c1b6f..9a9c4125090 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/PortableExecutionTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/PortableExecutionTest.java
@@ -98,6 +98,7 @@ public class PortableExecutionTest implements Serializable {
     options.as(FlinkPipelineOptions.class).setFlinkMaster("[local]");
     options.as(FlinkPipelineOptions.class).setStreaming(isStreaming);
     options.as(FlinkPipelineOptions.class).setParallelism(2);
+    options.as(FlinkPipelineOptions.class).setNumberOfExecutionRetries(0);
     options
         .as(PortablePipelineOptions.class)
         .setDefaultEnvironmentType(Environments.ENVIRONMENT_EMBEDDED);
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java
index 5ddef5935b2..c3412206c1b 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java
@@ -2096,7 +2096,10 @@ public class DoFnOperatorTest {
     testHarness.processElement(
         new StreamRecord<>(WindowedValue.valueInGlobalWindow(KV.of("key2", "d"))));
 
-    assertThat(Iterables.size(testHarness.getOutput()), is(0));
+    assertThat(
+        testHarness.getOutput() + " should be empty",
+        Iterables.size(testHarness.getOutput()),
+        is(0));
 
     OperatorSubtaskState backup = testHarness.snapshot(0, 0);
     doFnOperator.notifyCheckpointComplete(0L);
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
index e0a4c0155aa..87f52baab28 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
@@ -182,6 +182,30 @@ public class ExecutableStageDoFnOperatorTest {
                   .build())
           .build();
 
+  private final ExecutableStagePayload stagePayloadWithStableInput =
+      stagePayload
+          .toBuilder()
+          .setComponents(
+              stagePayload
+                  .getComponents()
+                  .toBuilder()
+                  .putTransforms(
+                      "transform",
+                      RunnerApi.PTransform.newBuilder()
+                          .setSpec(
+                              RunnerApi.FunctionSpec.newBuilder()
+                                  .setUrn(PAR_DO_TRANSFORM_URN)
+                                  .setPayload(
+                                      RunnerApi.ParDoPayload.newBuilder()
+                                          .setRequiresStableInput(true)
+                                          .build()
+                                          .toByteString())
+                                  .build())
+                          .putInputs("input", "input")
+                          .build())
+                  .build())
+          .build();
+
   private final JobInfo jobInfo =
       JobInfo.create("job-id", "job-name", "retrieval-token", Struct.getDefaultInstance());
 
@@ -1132,16 +1156,38 @@ public class ExecutableStageDoFnOperatorTest {
     assertNotEquals(operator, clone);
   }
 
+  @Test
+  public void testStableInputApplied() {
+    TupleTag<Integer> mainOutput = new TupleTag<>("main-output");
+    FlinkPipelineOptions options = FlinkPipelineOptions.defaults();
+    options.setCheckpointingInterval(100L);
+    DoFnOperator.MultiOutputOutputManagerFactory<Integer> outputManagerFactory =
+        new DoFnOperator.MultiOutputOutputManagerFactory(
+            mainOutput, VoidCoder.of(), new SerializablePipelineOptions(options));
+    ExecutableStageDoFnOperator<Integer, Integer> operator =
+        getOperator(
+            mainOutput,
+            Collections.emptyList(),
+            outputManagerFactory,
+            WindowingStrategy.globalDefault(),
+            null,
+            WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE),
+            stagePayloadWithStableInput,
+            options);
+
+    assertThat(operator.getRequiresStableInput(), is(true));
+  }
+
   /**
    * Creates a {@link ExecutableStageDoFnOperator}. Sets the runtime context to {@link
    * #runtimeContext}. The context factory is mocked to return {@link #stageContext} every time. The
    * behavior of the stage context itself is unchanged.
    */
-  @SuppressWarnings("rawtypes")
   private ExecutableStageDoFnOperator getOperator(
       TupleTag<Integer> mainOutput,
       List<TupleTag<?>> additionalOutputs,
       DoFnOperator.MultiOutputOutputManagerFactory<Integer> outputManagerFactory) {
+
     return getOperator(
         mainOutput,
         additionalOutputs,
@@ -1151,7 +1197,6 @@ public class ExecutableStageDoFnOperatorTest {
         WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE));
   }
 
-  @SuppressWarnings("rawtypes")
   private ExecutableStageDoFnOperator getOperator(
       TupleTag<Integer> mainOutput,
       List<TupleTag<?>> additionalOutputs,
@@ -1160,16 +1205,37 @@ public class ExecutableStageDoFnOperatorTest {
       @Nullable Coder keyCoder,
       Coder windowedInputCoder) {
 
-    FlinkExecutableStageContextFactory contextFactory =
-        Mockito.mock(FlinkExecutableStageContextFactory.class);
-    when(contextFactory.get(any())).thenReturn(stageContext);
-
     final ExecutableStagePayload stagePayload;
     if (keyCoder != null) {
       stagePayload = this.stagePayloadWithUserState;
     } else {
       stagePayload = this.stagePayload;
     }
+    return getOperator(
+        mainOutput,
+        additionalOutputs,
+        outputManagerFactory,
+        windowingStrategy,
+        keyCoder,
+        windowedInputCoder,
+        stagePayload,
+        FlinkPipelineOptions.defaults());
+  }
+
+  @SuppressWarnings("rawtypes")
+  private ExecutableStageDoFnOperator getOperator(
+      TupleTag<Integer> mainOutput,
+      List<TupleTag<?>> additionalOutputs,
+      DoFnOperator.MultiOutputOutputManagerFactory<Integer> outputManagerFactory,
+      WindowingStrategy windowingStrategy,
+      @Nullable Coder keyCoder,
+      Coder windowedInputCoder,
+      ExecutableStagePayload stagePayload,
+      FlinkPipelineOptions options) {
+
+    FlinkExecutableStageContextFactory contextFactory =
+        Mockito.mock(FlinkExecutableStageContextFactory.class);
+    when(contextFactory.get(any())).thenReturn(stageContext);
 
     ExecutableStageDoFnOperator<Integer, Integer> operator =
         new ExecutableStageDoFnOperator<>(
@@ -1182,7 +1248,7 @@ public class ExecutableStageDoFnOperatorTest {
             Collections.emptyMap() /* sideInputTagMapping */,
             Collections.emptyList() /* sideInputs */,
             Collections.emptyMap() /* sideInputId mapping */,
-            FlinkPipelineOptions.defaults(),
+            options,
             stagePayload,
             jobInfo,
             contextFactory,
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/RequiresStableInputIT.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/RequiresStableInputIT.java
index 6d06d3350a0..dd24b6bef8f 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/RequiresStableInputIT.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/RequiresStableInputIT.java
@@ -95,9 +95,9 @@ public class RequiresStableInputIT {
 
     public static void writeTextToFileSideEffect(String text, String filename) throws IOException {
       ResourceId rid = FileSystems.matchNewResource(filename, false);
-      WritableByteChannel chan = FileSystems.create(rid, "text/plain");
-      chan.write(ByteBuffer.wrap(text.getBytes(StandardCharsets.UTF_8)));
-      chan.close();
+      try (WritableByteChannel chan = FileSystems.create(rid, "text/plain")) {
+        chan.write(ByteBuffer.wrap(text.getBytes(StandardCharsets.UTF_8)));
+      }
     }
   }
 
diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
index 27e4ca4973e..66c5be544e7 100644
--- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -69,6 +69,10 @@ class FlinkRunnerTest(portable_runner_test.PortableRunnerTest):
     super().__init__(*args, **kwargs)
     self.environment_type = None
     self.environment_config = None
+    self.enable_commit = False
+
+  def setUp(self):
+    self.enable_commit = False
 
   @pytest.fixture(autouse=True)
   def parse_options(self, request):
@@ -197,6 +201,11 @@ class FlinkRunnerTest(portable_runner_test.PortableRunnerTest):
     options.view_as(PortableOptions).environment_type = self.environment_type
     options.view_as(
         PortableOptions).environment_options = self.environment_options
+    if self.enable_commit:
+      options.view_as(StandardOptions).streaming = True
+      options._all_options['checkpointing_interval'] = 3000
+      options._all_options['shutdown_sources_after_idle_ms'] = 60000
+      options._all_options['number_of_execution_retries'] = 1
 
     return options
 
@@ -224,6 +233,7 @@ class FlinkRunnerTest(portable_runner_test.PortableRunnerTest):
     # Nevertheless, we check that the transform is expanded by the
     # ExpansionService and that the pipeline fails during execution.
     with self.assertRaises(Exception) as ctx:
+      self.enable_commit = True
       with self.create_pipeline() as p:
         # pylint: disable=expression-not-assigned
         (
@@ -338,19 +348,9 @@ class FlinkRunnerTestOptimized(FlinkRunnerTest):
 
 
 class FlinkRunnerTestStreaming(FlinkRunnerTest):
-  def __init__(self, *args, **kwargs):
-    super().__init__(*args, **kwargs)
-    self.enable_commit = False
-
-  def setUp(self):
-    self.enable_commit = False
-
   def create_options(self):
     options = super().create_options()
     options.view_as(StandardOptions).streaming = True
-    if self.enable_commit:
-      options._all_options['checkpointing_interval'] = 3000
-      options._all_options['shutdown_sources_after_idle_ms'] = 60000
     return options
 
   def test_callbacks_with_exception(self):

Reply via email to