This is an automated email from the ASF dual-hosted git repository.

stankiewicz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new d032f0e3ef4 Add drain support for Dataflow and Flink (#38786)
d032f0e3ef4 is described below

commit d032f0e3ef4f9409d81e231dcef54cafa3e21ba1
Author: Lalit Yadav <[email protected]>
AuthorDate: Wed Jun 17 04:04:50 2026 -0500

    Add drain support for Dataflow and Flink (#38786)
    
    * Add drain support for Dataflow and Flink
    
    * Add PipelineResult drain API
---
 .../runners/flink/FlinkDetachedRunnerResult.java   | 60 +++++++++++++-
 .../beam/runners/flink/FlinkRunnerResult.java      |  6 ++
 .../beam/runners/flink/FlinkRunnerResultTest.java  | 92 ++++++++++++++++++++++
 .../beam/runners/dataflow/DataflowPipelineJob.java | 67 ++++++++++------
 .../runners/dataflow/DataflowPipelineJobTest.java  | 46 +++++++++++
 .../java/org/apache/beam/sdk/PipelineResult.java   | 13 +++
 6 files changed, 255 insertions(+), 29 deletions(-)

diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkDetachedRunnerResult.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkDetachedRunnerResult.java
index f7d82065b65..b26e865526d 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkDetachedRunnerResult.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkDetachedRunnerResult.java
@@ -18,6 +18,7 @@
 package org.apache.beam.runners.flink;
 
 import java.io.IOException;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.runners.flink.metrics.FlinkMetricContainer;
@@ -25,6 +26,8 @@ import org.apache.beam.sdk.PipelineResult;
 import org.apache.beam.sdk.metrics.MetricResults;
 import org.apache.flink.api.common.JobStatus;
 import org.apache.flink.core.execution.JobClient;
+import org.apache.flink.core.execution.SavepointFormatType;
+import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Duration;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -39,6 +42,7 @@ public class FlinkDetachedRunnerResult implements PipelineResult {
 
   private JobClient jobClient;
   private int jobCheckIntervalInSecs;
+  private volatile @Nullable CompletableFuture<String> drainSavepointFuture;
 
   FlinkDetachedRunnerResult(JobClient jobClient, int jobCheckIntervalInSecs) {
     this.jobClient = jobClient;
@@ -47,10 +51,25 @@ public class FlinkDetachedRunnerResult implements PipelineResult {
 
   @Override
   public State getState() {
+    CompletableFuture<String> drainFuture = drainSavepointFuture;
+    if (drainFuture != null) {
+      try {
+        return getDrainState(drainFuture);
+      } catch (IOException e) {
+        LOG.warn("Failed to drain Flink job. Querying Flink job state 
instead.", e);
+      }
+    }
+    return getFlinkJobState();
+  }
+
+  private State getFlinkJobState() {
     try {
       return toBeamJobState(jobClient.getJobStatus().get());
-    } catch (InterruptedException | ExecutionException e) {
-      throw new RuntimeException("Fail to get flink job state", e);
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      throw new RuntimeException("Failed to get Flink job state", e);
+    } catch (ExecutionException e) {
+      throw new RuntimeException("Failed to get Flink job state", e);
     }
   }
 
@@ -66,7 +85,11 @@ public class FlinkDetachedRunnerResult implements PipelineResult {
               .getAccumulators()
               .get()
               .getOrDefault(FlinkMetricContainer.ACCUMULATOR_NAME, new 
MetricsContainerStepMap());
-    } catch (InterruptedException | ExecutionException e) {
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      LOG.warn("Fail to get flink job accumulators", e);
+      return new MetricsContainerStepMap();
+    } catch (ExecutionException e) {
       LOG.warn("Fail to get flink job accumulators", e);
       return new MetricsContainerStepMap();
     }
@@ -76,12 +99,40 @@ public class FlinkDetachedRunnerResult implements PipelineResult {
   public State cancel() throws IOException {
     try {
       this.jobClient.cancel().get();
-    } catch (InterruptedException | ExecutionException e) {
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      throw new RuntimeException("Fail to cancel flink job", e);
+    } catch (ExecutionException e) {
       throw new RuntimeException("Fail to cancel flink job", e);
     }
     return getState();
   }
 
+  @Override
+  public synchronized State drain() throws IOException {
+    CompletableFuture<String> drainFuture = drainSavepointFuture;
+    if (drainFuture == null || drainFuture.isCompletedExceptionally()) {
+      drainFuture = this.jobClient.stopWithSavepoint(true, null, 
SavepointFormatType.DEFAULT);
+      drainSavepointFuture = drainFuture;
+    }
+    return getDrainState(drainFuture);
+  }
+
+  private State getDrainState(CompletableFuture<String> drainFuture) throws 
IOException {
+    if (!drainFuture.isDone()) {
+      return State.RUNNING;
+    }
+    try {
+      drainFuture.get();
+      return State.DONE;
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      throw new IOException("Failed to drain Flink job", e);
+    } catch (ExecutionException e) {
+      throw new IOException("Failed to drain Flink job", e.getCause());
+    }
+  }
+
   @Override
   public State waitUntilFinish() {
     return waitUntilFinish(Duration.millis(Long.MAX_VALUE));
@@ -100,6 +151,7 @@ public class FlinkDetachedRunnerResult implements PipelineResult {
       try {
         Thread.sleep(jobCheckIntervalInSecs * 1000L);
       } catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
         throw new RuntimeException(e);
       }
     }
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java
index d892049bce4..c0cce5349f2 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java
@@ -64,6 +64,12 @@ public class FlinkRunnerResult implements PipelineResult {
     return State.DONE;
   }
 
+  @Override
+  public State drain() {
+    // We can only be called here when we are done.
+    return State.DONE;
+  }
+
   @Override
   public State waitUntilFinish() {
     return State.DONE;
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkRunnerResultTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkRunnerResultTest.java
index ba0981617fe..908d940f5ef 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkRunnerResultTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkRunnerResultTest.java
@@ -19,9 +19,20 @@ package org.apache.beam.runners.flink;
 
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.core.Is.is;
+import static org.junit.Assert.assertSame;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
+import java.io.IOException;
 import java.util.Collections;
+import java.util.concurrent.CompletableFuture;
 import org.apache.beam.sdk.PipelineResult;
+import org.apache.flink.api.common.JobStatus;
+import org.apache.flink.core.execution.JobClient;
+import org.apache.flink.core.execution.SavepointFormatType;
 import org.joda.time.Duration;
 import org.junit.Test;
 
@@ -47,4 +58,85 @@ public class FlinkRunnerResultTest {
     result.cancel();
     assertThat(result.getState(), is(PipelineResult.State.DONE));
   }
+
+  @Test
+  public void testDrainDoneResultDoesNotThrowAnException() throws Exception {
+    FlinkRunnerResult result = new FlinkRunnerResult(Collections.emptyMap(), 
100);
+    assertThat(result.drain(), is(PipelineResult.State.DONE));
+  }
+
+  @Test
+  public void testDetachedDrainReturnsRunningThenDone() throws Exception {
+    JobClient jobClient = mock(JobClient.class);
+    CompletableFuture<String> drainFuture = new CompletableFuture<>();
+    when(jobClient.stopWithSavepoint(true, null, SavepointFormatType.DEFAULT))
+        .thenReturn(drainFuture);
+    FlinkDetachedRunnerResult result = new 
FlinkDetachedRunnerResult(jobClient, 1);
+
+    assertThat(result.drain(), is(PipelineResult.State.RUNNING));
+    assertThat(result.getState(), is(PipelineResult.State.RUNNING));
+
+    drainFuture.complete("savepoint");
+    assertThat(result.getState(), is(PipelineResult.State.DONE));
+    verify(jobClient).stopWithSavepoint(true, null, 
SavepointFormatType.DEFAULT);
+  }
+
+  @Test
+  public void testDetachedDrainFailureThrowsIOException() throws Exception {
+    JobClient jobClient = mock(JobClient.class);
+    CompletableFuture<String> drainFuture = new CompletableFuture<>();
+    RuntimeException failure = new RuntimeException("savepoint failed");
+    drainFuture.completeExceptionally(failure);
+    when(jobClient.stopWithSavepoint(true, null, SavepointFormatType.DEFAULT))
+        .thenReturn(drainFuture);
+    FlinkDetachedRunnerResult result = new 
FlinkDetachedRunnerResult(jobClient, 1);
+
+    try {
+      result.drain();
+      fail("Expected IOException");
+    } catch (IOException e) {
+      assertThat(e.getMessage(), is("Failed to drain Flink job"));
+      assertSame(failure, e.getCause());
+    }
+  }
+
+  @Test
+  public void testDetachedGetStateFallsBackAfterDrainFailure() throws 
Exception {
+    JobClient jobClient = mock(JobClient.class);
+    CompletableFuture<String> drainFuture = new CompletableFuture<>();
+    drainFuture.completeExceptionally(new RuntimeException("savepoint 
failed"));
+    when(jobClient.stopWithSavepoint(true, null, SavepointFormatType.DEFAULT))
+        .thenReturn(drainFuture);
+    
when(jobClient.getJobStatus()).thenReturn(CompletableFuture.completedFuture(JobStatus.RUNNING));
+    FlinkDetachedRunnerResult result = new 
FlinkDetachedRunnerResult(jobClient, 1);
+
+    try {
+      result.drain();
+      fail("Expected IOException");
+    } catch (IOException expected) {
+      assertThat(result.getState(), is(PipelineResult.State.RUNNING));
+    }
+  }
+
+  @Test
+  public void testDetachedDrainRetriesAfterFailure() throws Exception {
+    JobClient jobClient = mock(JobClient.class);
+    CompletableFuture<String> failedDrainFuture = new CompletableFuture<>();
+    failedDrainFuture.completeExceptionally(new RuntimeException("savepoint 
failed"));
+    CompletableFuture<String> retryDrainFuture = new CompletableFuture<>();
+    when(jobClient.stopWithSavepoint(true, null, SavepointFormatType.DEFAULT))
+        .thenReturn(failedDrainFuture, retryDrainFuture);
+    FlinkDetachedRunnerResult result = new 
FlinkDetachedRunnerResult(jobClient, 1);
+
+    try {
+      result.drain();
+      fail("Expected IOException");
+    } catch (IOException expected) {
+      assertThat(result.drain(), is(PipelineResult.State.RUNNING));
+    }
+
+    retryDrainFuture.complete("savepoint");
+    assertThat(result.getState(), is(PipelineResult.State.DONE));
+    verify(jobClient, times(2)).stopWithSavepoint(true, null, 
SavepointFormatType.DEFAULT);
+  }
 }
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineJob.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineJob.java
index 400f161dee2..0d7e5eaf68d 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineJob.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineJob.java
@@ -432,55 +432,69 @@ public class DataflowPipelineJob implements PipelineResult {
   }
 
   private AtomicReference<FutureTask<State>> cancelState = new 
AtomicReference<>();
+  private AtomicReference<FutureTask<State>> drainState = new 
AtomicReference<>();
 
   @SuppressWarnings("Slf4jFormatShouldBeConst")
   @Override
   public State cancel() throws IOException {
-    // Enforce that a cancel() call on the job is done at most once - as
-    // a workaround for Dataflow service's current bugs with multiple
-    // cancellation, where it may sometimes return an error when cancelling
-    // a job that was already cancelled, but still report the job state as
-    // RUNNING.
-    // To partially work around these issues, we absorb duplicate cancel()
-    // calls. This, of course, doesn't address the case when the job terminates
-    // externally almost concurrently to calling cancel(), but at least it
-    // makes it possible to safely call cancel() multiple times and from
-    // multiple threads in one program.
-    FutureTask<State> tentativeCancelTask =
+    return requestJobState(cancelState, "JOB_STATE_CANCELLED", "cancel", 
"Cancel");
+  }
+
+  @Override
+  public State drain() throws IOException {
+    return requestJobState(drainState, "JOB_STATE_DRAINED", "drain", "Drain");
+  }
+
+  @SuppressWarnings("Slf4jFormatShouldBeConst")
+  private State requestJobState(
+      AtomicReference<FutureTask<State>> requestedState,
+      String dataflowRequestedState,
+      String action,
+      String capitalizedAction)
+      throws IOException {
+    // Enforce that a lifecycle request on the job is done at most once. This 
preserves the
+    // existing cancel() behavior and keeps duplicate drain() calls idempotent 
from one client.
+    FutureTask<State> tentativeTask =
         new FutureTask<>(
             () -> {
               Job content = new Job();
               content.setProjectId(getProjectId());
               String currentJobId = getJobId();
               content.setId(currentJobId);
-              content.setRequestedState("JOB_STATE_CANCELLED");
+              content.setRequestedState(dataflowRequestedState);
               try {
                 Job job = dataflowClient.updateJob(currentJobId, content);
                 return MonitoringUtil.toState(job.getCurrentState());
               } catch (IOException e) {
                 State state = getState();
+                String message = e.getMessage();
                 if (state.isTerminal()) {
-                  LOG.warn("Cancel failed because job is already terminated. 
State is {}", state);
+                  LOG.warn(
+                      "{} failed because job is already terminated. State is 
{}",
+                      capitalizedAction,
+                      state);
                   return state;
-                } else if (e.getMessage().contains("has terminated")) {
+                } else if (message != null && message.contains("has 
terminated")) {
                   // This handles the case where the getState() call above 
returns RUNNING but the
-                  // cancel was rejected because the job is in fact done. 
Hopefully, someday we can
+                  // request was rejected because the job is in fact done. 
Hopefully, someday we can
                   // delete this code if there is better consistency between 
the State and whether
-                  // Cancel succeeds.
+                  // lifecycle requests succeed.
                   //
                   // Example message:
                   //    Workflow modification failed. Causes: 
(7603adc9e9bff51e): Cannot perform
                   //    operation 'cancel' on Job: 
2017-04-01_22_50_59-9269855660514862348. Job has
                   //    terminated in state SUCCESS: Workflow job:
                   //    2017-04-01_22_50_59-9269855660514862348 succeeded.
-                  LOG.warn("Cancel failed because job is already terminated.", 
e);
+                  LOG.warn("{} failed because job is already terminated.", 
capitalizedAction, e);
                   return state;
                 } else {
                   String errorMsg =
                       String.format(
-                          "Failed to cancel job in state %s, "
-                              + "please go to the Developers Console to cancel 
it manually: %s",
+                          "Failed to %s job in state %s, "
+                              + "please go to the Developers Console to %s it 
manually: %s",
+                          action,
                           state,
+                          action,
                           MonitoringUtil.getJobMonitoringPageURL(
                               getProjectId(), getRegion(), getJobId()));
                   LOG.warn(errorMsg);
@@ -488,14 +502,17 @@ public class DataflowPipelineJob implements PipelineResult {
                 }
               }
             });
-    if (cancelState.compareAndSet(null, tentativeCancelTask)) {
-      // This thread should perform cancellation, while others will
-      // only wait for the result.
-      cancelState.get().run();
+    if (requestedState.compareAndSet(null, tentativeTask)) {
+      // This thread should perform the lifecycle request, while others will 
only wait for the
+      // result.
+      requestedState.get().run();
     }
     try {
-      return cancelState.get().get();
-    } catch (InterruptedException | ExecutionException e) {
+      return requestedState.get().get();
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      throw new IOException(e);
+    } catch (ExecutionException e) {
       throw new IOException(e);
     }
   }
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java
index 54ba10df9d1..4b088eb41a7 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java
@@ -406,6 +406,26 @@ public class DataflowPipelineJobTest {
     verifyNoMoreInteractions(mockJobs);
   }
 
+  @Test
+  public void testDrainUnterminatedJobThatSucceeds() throws IOException {
+    Dataflow.Projects.Locations.Jobs.Update update =
+        mock(Dataflow.Projects.Locations.Jobs.Update.class);
+    when(mockJobs.update(eq(PROJECT_ID), eq(REGION_ID), eq(JOB_ID), 
any(Job.class)))
+        .thenReturn(update);
+    when(update.execute()).thenReturn(new 
Job().setCurrentState("JOB_STATE_DRAINING"));
+
+    DataflowPipelineJob job =
+        new DataflowPipelineJob(DataflowClient.create(options), JOB_ID, 
options, null);
+
+    assertEquals(State.RUNNING, job.drain());
+    Job content = new Job();
+    content.setProjectId(PROJECT_ID);
+    content.setId(JOB_ID);
+    content.setRequestedState("JOB_STATE_DRAINED");
+    verify(mockJobs).update(eq(PROJECT_ID), eq(REGION_ID), eq(JOB_ID), 
eq(content));
+    verifyNoMoreInteractions(mockJobs);
+  }
+
   @Test
   public void testCancelUnterminatedJobThatFails() throws IOException {
     Dataflow.Projects.Locations.Jobs.Get statusRequest =
@@ -432,6 +452,32 @@ public class DataflowPipelineJobTest {
     job.cancel();
   }
 
+  @Test
+  public void testCancelUnterminatedJobWithNullFailureMessage() throws 
IOException {
+    Dataflow.Projects.Locations.Jobs.Get statusRequest =
+        mock(Dataflow.Projects.Locations.Jobs.Get.class);
+
+    Job statusResponse = new Job();
+    statusResponse.setCurrentState("JOB_STATE_RUNNING");
+    when(mockJobs.get(PROJECT_ID, REGION_ID, 
JOB_ID)).thenReturn(statusRequest);
+    when(statusRequest.execute()).thenReturn(statusResponse);
+
+    Dataflow.Projects.Locations.Jobs.Update update =
+        mock(Dataflow.Projects.Locations.Jobs.Update.class);
+    when(mockJobs.update(eq(PROJECT_ID), eq(REGION_ID), eq(JOB_ID), 
any(Job.class)))
+        .thenReturn(update);
+    when(update.execute()).thenThrow(new IOException());
+
+    DataflowPipelineJob job =
+        new DataflowPipelineJob(DataflowClient.create(options), JOB_ID, 
options, null);
+
+    thrown.expect(IOException.class);
+    thrown.expectMessage(
+        "Failed to cancel job in state RUNNING, "
+            + "please go to the Developers Console to cancel it manually:");
+    job.cancel();
+  }
+
   /**
    * Test that {@link DataflowPipelineJob#cancel} doesn't throw if the 
Dataflow service returns
    * non-terminal state even though the cancel API call failed, which can 
happen in practice.
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/PipelineResult.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/PipelineResult.java
index 91313f3924a..46cca7833e5 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/PipelineResult.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/PipelineResult.java
@@ -43,6 +43,19 @@ public interface PipelineResult {
    */
   State cancel() throws IOException;
 
+  /**
+   * Drains the pipeline execution.
+   *
+   * <p>Draining requests that the runner stop accepting new input and finish 
processing data that
+   * has already entered the pipeline.
+   *
+   * @throws IOException if there is a problem executing the drain request.
+   * @throws UnsupportedOperationException if the runner does not support 
draining.
+   */
+  default State drain() throws IOException {
+    throw new UnsupportedOperationException("Runner does not support 
draining.");
+  }
+
   /**
    * Waits until the pipeline finishes and returns the final status. It times 
out after the given
    * duration.

Reply via email to