This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new f437a783e19 Tune maximum thread count for streaming dataflow worker 
executor dynamically. (#30439)
f437a783e19 is described below

commit f437a783e191592917bd417495bc6317142e6a43
Author: Melody Shen <[email protected]>
AuthorDate: Fri Apr 5 02:12:23 2024 -0700

    Tune maximum thread count for streaming dataflow worker executor 
dynamically. (#30439)
    
    Workers will read the StreamignScalingReportResponse from worker messages 
and
    configure the executor pool size based on the specified value.
---
 .../org/apache/beam/gradle/BeamModulePlugin.groovy |   2 +-
 .../dataflow/worker/DataflowWorkUnitClient.java    |  10 +-
 .../runners/dataflow/worker/WorkUnitClient.java    |   4 +-
 .../harness/StreamingWorkerStatusReporter.java     |  57 ++++-
 .../dataflow/worker/util/BoundedQueueExecutor.java |  76 ++++--
 .../worker/DataflowWorkUnitClientTest.java         |  13 +-
 .../harness/StreamingWorkerStatusReporterTest.java | 100 ++++++++
 .../worker/util/BoundedQueueExecutorTest.java      | 268 +++++++++++++++++++++
 8 files changed, 503 insertions(+), 27 deletions(-)

diff --git 
a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy 
b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
index 8be8d73fbcb..d4a16325641 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
@@ -732,7 +732,7 @@ class BeamModulePlugin implements Plugin<Project> {
         google_api_common                           : 
"com.google.api:api-common", // google_cloud_platform_libraries_bom sets version
         google_api_services_bigquery                : 
"com.google.apis:google-api-services-bigquery:v2-rev20240124-2.0.0",  // 
[bomupgrader] sets version
         google_api_services_cloudresourcemanager    : 
"com.google.apis:google-api-services-cloudresourcemanager:v1-rev20240128-2.0.0",
  // [bomupgrader] sets version
-        google_api_services_dataflow                : 
"com.google.apis:google-api-services-dataflow:v1b3-rev20240113-$google_clients_version",
+        google_api_services_dataflow                : 
"com.google.apis:google-api-services-dataflow:v1b3-rev20240218-$google_clients_version",
         google_api_services_healthcare              : 
"com.google.apis:google-api-services-healthcare:v1-rev20240130-$google_clients_version",
         google_api_services_pubsub                  : 
"com.google.apis:google-api-services-pubsub:v1-rev20220904-$google_clients_version",
         google_api_services_storage                 : 
"com.google.apis:google-api-services-storage:v1-rev20240205-2.0.0",  // 
[bomupgrader] sets version
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClient.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClient.java
index f3caa8d0f3a..af8e7dd50c9 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClient.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClient.java
@@ -39,6 +39,7 @@ import com.google.api.services.dataflow.model.WorkItem;
 import com.google.api.services.dataflow.model.WorkItemServiceState;
 import com.google.api.services.dataflow.model.WorkItemStatus;
 import com.google.api.services.dataflow.model.WorkerMessage;
+import com.google.api.services.dataflow.model.WorkerMessageResponse;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -312,7 +313,8 @@ class DataflowWorkUnitClient implements WorkUnitClient {
    * perworkermetrics with this path.
    */
   @Override
-  public void reportWorkerMessage(List<WorkerMessage> messages) throws 
IOException {
+  public List<WorkerMessageResponse> reportWorkerMessage(List<WorkerMessage> 
messages)
+      throws IOException {
     SendWorkerMessagesRequest request =
         new SendWorkerMessagesRequest()
             .setLocation(options.getRegion())
@@ -327,6 +329,10 @@ class DataflowWorkUnitClient implements WorkUnitClient {
       logger.warn("Worker Message response is null");
       throw new IOException("Got null Worker Message response");
     }
-    // Currently no response is expected
+    if (result.getWorkerMessageResponses() == null) {
+      logger.debug("Worker Message response is empty.");
+      return Collections.emptyList();
+    }
+    return result.getWorkerMessageResponses();
   }
 }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkUnitClient.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkUnitClient.java
index d75d91d0088..26b1dc55ead 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkUnitClient.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkUnitClient.java
@@ -23,6 +23,7 @@ import com.google.api.services.dataflow.model.WorkItem;
 import com.google.api.services.dataflow.model.WorkItemServiceState;
 import com.google.api.services.dataflow.model.WorkItemStatus;
 import com.google.api.services.dataflow.model.WorkerMessage;
+import com.google.api.services.dataflow.model.WorkerMessageResponse;
 import java.io.IOException;
 import java.util.List;
 import java.util.Optional;
@@ -75,6 +76,7 @@ public interface WorkUnitClient {
    * perworkermetrics with this path.
    *
    * @param msg the WorkerMessages to report
+   * @return a list of {@link WorkerMessageResponse}
    */
-  void reportWorkerMessage(List<WorkerMessage> messages) throws IOException;
+  List<WorkerMessageResponse> reportWorkerMessage(List<WorkerMessage> 
messages) throws IOException;
 }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java
index 409f0337eeb..8e950546ae6 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java
@@ -23,8 +23,10 @@ import com.google.api.services.dataflow.model.CounterUpdate;
 import com.google.api.services.dataflow.model.PerStepNamespaceMetrics;
 import com.google.api.services.dataflow.model.PerWorkerMetrics;
 import com.google.api.services.dataflow.model.StreamingScalingReport;
+import com.google.api.services.dataflow.model.StreamingScalingReportResponse;
 import com.google.api.services.dataflow.model.WorkItemStatus;
 import com.google.api.services.dataflow.model.WorkerMessage;
+import com.google.api.services.dataflow.model.WorkerMessageResponse;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collection;
@@ -34,6 +36,7 @@ import java.util.Optional;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Function;
 import java.util.function.Supplier;
@@ -70,6 +73,8 @@ public final class StreamingWorkerStatusReporter {
   private static final String GLOBAL_WORKER_UPDATE_REPORTER_THREAD = 
"GlobalWorkerUpdates";
 
   private final boolean publishCounters;
+  private final int initialMaxThreadCount;
+  private final int initialMaxBundlesOutstanding;
   private final WorkUnitClient dataflowServiceClient;
   private final Supplier<Long> windmillQuotaThrottleTime;
   private final Supplier<Collection<StageInfo>> allStageInfo;
@@ -78,6 +83,7 @@ public final class StreamingWorkerStatusReporter {
   private final MemoryMonitor memoryMonitor;
   private final BoundedQueueExecutor workExecutor;
   private final AtomicLong previousTimeAtMaxThreads;
+  private final AtomicInteger maxThreadCountOverride;
   private final ScheduledExecutorService globalWorkerUpdateReporter;
   private final ScheduledExecutorService workerMessageReporter;
 
@@ -99,7 +105,10 @@ public final class StreamingWorkerStatusReporter {
     this.streamingCounters = streamingCounters;
     this.memoryMonitor = memoryMonitor;
     this.workExecutor = workExecutor;
+    this.initialMaxThreadCount = workExecutor.getMaximumPoolSize();
+    this.initialMaxBundlesOutstanding = 
workExecutor.maximumElementsOutstanding();
     this.previousTimeAtMaxThreads = new AtomicLong();
+    this.maxThreadCountOverride = new AtomicInteger();
     this.globalWorkerUpdateReporter = 
executorFactory.apply(GLOBAL_WORKER_UPDATE_REPORTER_THREAD);
     this.workerMessageReporter = 
executorFactory.apply(WORKER_MESSAGE_REPORTER_THREAD);
   }
@@ -299,9 +308,12 @@ public final class StreamingWorkerStatusReporter {
     }
   }
 
-  private void reportPeriodicWorkerMessage() {
+  @VisibleForTesting
+  public void reportPeriodicWorkerMessage() {
     try {
-      dataflowServiceClient.reportWorkerMessage(createWorkerMessage());
+      List<WorkerMessageResponse> workerMessageResponses =
+          dataflowServiceClient.reportWorkerMessage(createWorkerMessage());
+      
readAndSaveWorkerMessageResponseForStreamingScalingReportResponse(workerMessageResponses);
     } catch (IOException e) {
       LOG.warn("Failed to send worker messages", e);
     } catch (Exception e) {
@@ -346,6 +358,47 @@ public final class StreamingWorkerStatusReporter {
         
dataflowServiceClient.createWorkerMessageFromPerWorkerMetrics(perWorkerMetrics));
   }
 
+  private void 
readAndSaveWorkerMessageResponseForStreamingScalingReportResponse(
+      List<WorkerMessageResponse> responses) {
+    Optional<StreamingScalingReportResponse> streamingScalingReportResponse = 
Optional.empty();
+    for (WorkerMessageResponse response : responses) {
+      if (response.getStreamingScalingReportResponse() != null) {
+        streamingScalingReportResponse = 
Optional.of(response.getStreamingScalingReportResponse());
+      }
+    }
+    if (streamingScalingReportResponse.isPresent()) {
+      int oldMaximumThreadCount = getMaxThreads();
+      
maxThreadCountOverride.set(streamingScalingReportResponse.get().getMaximumThreadCount());
+      int newMaximumThreadCount = getMaxThreads();
+      if (newMaximumThreadCount != oldMaximumThreadCount) {
+        LOG.info(
+            "Setting maximum thread count to {}, old value is {}",
+            newMaximumThreadCount,
+            oldMaximumThreadCount);
+        workExecutor.setMaximumPoolSize(newMaximumThreadCount, 
getMaxBundlesOutstanding());
+      }
+    }
+  }
+
+  private int getMaxThreads() {
+    int currentMaxThreadCountOverride = maxThreadCountOverride.get();
+    if (currentMaxThreadCountOverride != 0) {
+      return currentMaxThreadCountOverride;
+    }
+    return initialMaxThreadCount;
+  }
+
+  private int getMaxBundlesOutstanding() {
+    int currentMaxThreadCountOverride = maxThreadCountOverride.get();
+    if (currentMaxThreadCountOverride != 0) {
+      return currentMaxThreadCountOverride + 100;
+    }
+    if (initialMaxBundlesOutstanding > 0) {
+      return initialMaxBundlesOutstanding;
+    }
+    return getMaxThreads() + 100;
+  }
+
   @VisibleForTesting
   public void reportPeriodicWorkerUpdates() {
     updateVMMetrics();
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java
index f7f6fd91a8c..9a481169350 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java
@@ -21,7 +21,7 @@ import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.ThreadFactory;
 import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicInteger;
+import javax.annotation.concurrent.GuardedBy;
 import org.apache.beam.runners.dataflow.worker.streaming.Work;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor.Guard;
@@ -32,15 +32,26 @@ import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurren
 })
 public class BoundedQueueExecutor {
   private final ThreadPoolExecutor executor;
-  private final int maximumElementsOutstanding;
   private final long maximumBytesOutstanding;
-  private final int maximumPoolSize;
 
+  // Used to guard elementsOutstanding and bytesOutstanding.
   private final Monitor monitor = new Monitor();
   private int elementsOutstanding = 0;
   private long bytesOutstanding = 0;
-  private final AtomicInteger activeCount = new AtomicInteger();
+
+  @GuardedBy("this")
+  private int maximumElementsOutstanding;
+
+  @GuardedBy("this")
+  private int activeCount;
+
+  @GuardedBy("this")
+  private int maximumPoolSize;
+
+  @GuardedBy("this")
   private long startTimeMaxActiveThreadsUsed;
+
+  @GuardedBy("this")
   private long totalTimeMaxActiveThreadsUsed;
 
   public BoundedQueueExecutor(
@@ -62,8 +73,8 @@ public class BoundedQueueExecutor {
           @Override
           protected void beforeExecute(Thread t, Runnable r) {
             super.beforeExecute(t, r);
-            synchronized (this) {
-              if (activeCount.getAndIncrement() >= maximumPoolSize - 1) {
+            synchronized (BoundedQueueExecutor.this) {
+              if (++activeCount >= maximumPoolSize && 
startTimeMaxActiveThreadsUsed == 0) {
                 startTimeMaxActiveThreadsUsed = System.currentTimeMillis();
               }
             }
@@ -72,8 +83,8 @@ public class BoundedQueueExecutor {
           @Override
           protected void afterExecute(Runnable r, Throwable t) {
             super.afterExecute(r, t);
-            synchronized (this) {
-              if (activeCount.getAndDecrement() == maximumPoolSize) {
+            synchronized (BoundedQueueExecutor.this) {
+              if (--activeCount < maximumPoolSize && 
startTimeMaxActiveThreadsUsed > 0) {
                 totalTimeMaxActiveThreadsUsed +=
                     (System.currentTimeMillis() - 
startTimeMaxActiveThreadsUsed);
                 startTimeMaxActiveThreadsUsed = 0;
@@ -95,16 +106,31 @@ public class BoundedQueueExecutor {
           public boolean isSatisfied() {
             return elementsOutstanding == 0
                 || (bytesAvailable() >= workBytes
-                    && elementsOutstanding < maximumElementsOutstanding);
+                    && elementsOutstanding < maximumElementsOutstanding());
           }
         });
-    executeLockHeld(work, workBytes);
+    executeMonitorHeld(work, workBytes);
   }
 
   // Forcibly add something to the queue, ignoring the length limit.
   public void forceExecute(Runnable work, long workBytes) {
     monitor.enter();
-    executeLockHeld(work, workBytes);
+    executeMonitorHeld(work, workBytes);
+  }
+
+  // Set the maximum/core pool size of the executor.
+  public synchronized void setMaximumPoolSize(int maximumPoolSize, int 
maximumElementsOutstanding) {
+    // For ThreadPoolExecutor, the maximum pool size should always greater 
than or equal to core
+    // pool size.
+    if (maximumPoolSize > executor.getCorePoolSize()) {
+      executor.setMaximumPoolSize(maximumPoolSize);
+      executor.setCorePoolSize(maximumPoolSize);
+    } else {
+      executor.setCorePoolSize(maximumPoolSize);
+      executor.setMaximumPoolSize(maximumPoolSize);
+    }
+    this.maximumPoolSize = maximumPoolSize;
+    this.maximumElementsOutstanding = maximumElementsOutstanding;
   }
 
   public void shutdown() throws InterruptedException {
@@ -118,31 +144,41 @@ public class BoundedQueueExecutor {
     return executor.getQueue().isEmpty();
   }
 
-  public long allThreadsActiveTime() {
+  public synchronized long allThreadsActiveTime() {
     return totalTimeMaxActiveThreadsUsed;
   }
 
-  public int activeCount() {
-    return activeCount.intValue();
+  public synchronized int activeCount() {
+    return activeCount;
   }
 
   public long bytesOutstanding() {
-    return bytesOutstanding;
+    monitor.enter();
+    try {
+      return bytesOutstanding;
+    } finally {
+      monitor.leave();
+    }
   }
 
   public int elementsOutstanding() {
-    return elementsOutstanding;
+    monitor.enter();
+    try {
+      return elementsOutstanding;
+    } finally {
+      monitor.leave();
+    }
   }
 
   public long maximumBytesOutstanding() {
     return maximumBytesOutstanding;
   }
 
-  public int maximumElementsOutstanding() {
+  public synchronized int maximumElementsOutstanding() {
     return maximumElementsOutstanding;
   }
 
-  public final int getMaximumPoolSize() {
+  public synchronized int getMaximumPoolSize() {
     return maximumPoolSize;
   }
 
@@ -163,7 +199,7 @@ public class BoundedQueueExecutor {
       builder.append("Work Queue Size: ");
       builder.append(elementsOutstanding);
       builder.append("/");
-      builder.append(maximumElementsOutstanding);
+      builder.append(maximumElementsOutstanding());
       builder.append("<br>/n");
 
       builder.append("Work Queue Bytes: ");
@@ -178,7 +214,7 @@ public class BoundedQueueExecutor {
     }
   }
 
-  private void executeLockHeld(Runnable work, long workBytes) {
+  private void executeMonitorHeld(Runnable work, long workBytes) {
     bytesOutstanding += workBytes;
     ++elementsOutstanding;
     monitor.leave();
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClientTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClientTest.java
index fac56890f49..85d79e6be3c 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClientTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClientTest.java
@@ -34,10 +34,13 @@ import 
com.google.api.services.dataflow.model.SendWorkerMessagesRequest;
 import com.google.api.services.dataflow.model.SendWorkerMessagesResponse;
 import com.google.api.services.dataflow.model.SeqMapTask;
 import com.google.api.services.dataflow.model.StreamingScalingReport;
+import com.google.api.services.dataflow.model.StreamingScalingReportResponse;
 import com.google.api.services.dataflow.model.WorkItem;
 import com.google.api.services.dataflow.model.WorkerMessage;
+import com.google.api.services.dataflow.model.WorkerMessageResponse;
 import java.io.IOException;
 import java.util.Collections;
+import java.util.List;
 import java.util.Optional;
 import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions;
 import 
org.apache.beam.runners.dataflow.worker.logging.DataflowWorkerLoggingMDC;
@@ -253,6 +256,12 @@ public class DataflowWorkUnitClientTest {
     MockLowLevelHttpResponse response = new MockLowLevelHttpResponse();
     response.setContentType(Json.MEDIA_TYPE);
     SendWorkerMessagesResponse workerMessage = new 
SendWorkerMessagesResponse();
+    StreamingScalingReportResponse streamingScalingReportResponse =
+        new StreamingScalingReportResponse().setMaximumThreadCount(10);
+    WorkerMessageResponse workerMessageResponse =
+        new WorkerMessageResponse()
+            .setStreamingScalingReportResponse(streamingScalingReportResponse);
+    
workerMessage.setWorkerMessageResponses(Collections.singletonList(workerMessageResponse));
     workerMessage.setFactory(Transport.getJsonFactory());
     response.setContent(workerMessage.toPrettyString());
 
@@ -271,12 +280,14 @@ public class DataflowWorkUnitClientTest {
             .setMaximumBundleCount(5)
             .setMaximumBytes(6L);
     WorkerMessage msg = 
client.createWorkerMessageFromStreamingScalingReport(activeThreadsReport);
-    client.reportWorkerMessage(Collections.singletonList(msg));
+    List<WorkerMessageResponse> responses =
+        client.reportWorkerMessage(Collections.singletonList(msg));
 
     SendWorkerMessagesRequest actualRequest =
         Transport.getJsonFactory()
             .fromString(request.getContentAsString(), 
SendWorkerMessagesRequest.class);
     assertEquals(ImmutableList.of(msg), actualRequest.getWorkerMessages());
+    assertEquals(ImmutableList.of(workerMessageResponse), responses);
   }
 
   @Test
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporterTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporterTest.java
new file mode 100644
index 00000000000..bdf0f0031d6
--- /dev/null
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporterTest.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming.harness;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.api.services.dataflow.model.StreamingScalingReportResponse;
+import com.google.api.services.dataflow.model.WorkerMessageResponse;
+import java.util.Collections;
+import java.util.concurrent.Executors;
+import org.apache.beam.runners.dataflow.worker.WorkUnitClient;
+import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
+import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.FailureTracker;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mockito;
+
+@RunWith(JUnit4.class)
+public class StreamingWorkerStatusReporterTest {
+  private final long DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME = 1000;
+
+  private BoundedQueueExecutor mockExecutor;
+  private WorkUnitClient mockWorkUnitClient;
+  private FailureTracker mockFailureTracker;
+  private MemoryMonitor mockMemoryMonitor;
+
+  @Before
+  public void setUp() {
+    this.mockExecutor = mock(BoundedQueueExecutor.class);
+    this.mockWorkUnitClient = mock(WorkUnitClient.class);
+    this.mockFailureTracker = mock(FailureTracker.class);
+    this.mockMemoryMonitor = mock(MemoryMonitor.class);
+  }
+
+  @Test
+  public void testOverrideMaximumThreadCount() throws Exception {
+    StreamingWorkerStatusReporter reporter =
+        StreamingWorkerStatusReporter.forTesting(
+            true,
+            mockWorkUnitClient,
+            () -> DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME,
+            () -> Collections.emptyList(),
+            mockFailureTracker,
+            StreamingCounters.create(),
+            mockMemoryMonitor,
+            mockExecutor,
+            (threadName) -> Executors.newSingleThreadScheduledExecutor());
+    StreamingScalingReportResponse streamingScalingReportResponse =
+        new StreamingScalingReportResponse().setMaximumThreadCount(10);
+    WorkerMessageResponse workerMessageResponse =
+        new WorkerMessageResponse()
+            .setStreamingScalingReportResponse(streamingScalingReportResponse);
+    when(mockWorkUnitClient.reportWorkerMessage(any()))
+        .thenReturn(Collections.singletonList(workerMessageResponse));
+    reporter.reportPeriodicWorkerMessage();
+    verify(mockExecutor).setMaximumPoolSize(10, 110);
+  }
+
+  @Test
+  public void testHandleEmptyWorkerMessageResponse() throws Exception {
+    StreamingWorkerStatusReporter reporter =
+        StreamingWorkerStatusReporter.forTesting(
+            true,
+            mockWorkUnitClient,
+            () -> DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME,
+            () -> Collections.emptyList(),
+            mockFailureTracker,
+            StreamingCounters.create(),
+            mockMemoryMonitor,
+            mockExecutor,
+            (threadName) -> Executors.newSingleThreadScheduledExecutor());
+    WorkerMessageResponse workerMessageResponse = new WorkerMessageResponse();
+    when(mockWorkUnitClient.reportWorkerMessage(any()))
+        .thenReturn(Collections.singletonList(workerMessageResponse));
+    reporter.reportPeriodicWorkerMessage();
+    verify(mockExecutor, Mockito.times(0)).setMaximumPoolSize(anyInt(), 
anyInt());
+  }
+}
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java
new file mode 100644
index 00000000000..c0620952ef9
--- /dev/null
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.util;
+
+import static org.hamcrest.Matchers.greaterThan;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link 
org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor}. */
+@RunWith(JUnit4.class)
+// TODO(https://github.com/apache/beam/issues/21230): Remove when new version 
of errorprone is
+// released (2.11.0)
+@SuppressWarnings("unused")
+public class BoundedQueueExecutorTest {
+  @Rule public transient Timeout globalTimeout = Timeout.seconds(300);
+  private static final long MAXIMUM_BYTES_OUTSTANDING = 10000000;
+  private static final int DEFAULT_MAX_THREADS = 2;
+  private static final int DEFAULT_THREAD_EXPIRATION_SEC = 60;
+
+  private BoundedQueueExecutor executor;
+
+  private Runnable createSleepProcessWorkFn(CountDownLatch start, 
CountDownLatch stop) {
+    Runnable runnable =
+        () -> {
+          start.countDown();
+          try {
+            stop.await();
+          } catch (Exception e) {
+            throw new RuntimeException(e);
+          }
+        };
+    return runnable;
+  }
+
+  @Before
+  public void setUp() {
+    this.executor =
+        new BoundedQueueExecutor(
+            DEFAULT_MAX_THREADS,
+            DEFAULT_THREAD_EXPIRATION_SEC,
+            TimeUnit.SECONDS,
+            DEFAULT_MAX_THREADS + 100,
+            MAXIMUM_BYTES_OUTSTANDING,
+            new ThreadFactoryBuilder()
+                .setNameFormat("DataflowWorkUnits-%d")
+                .setDaemon(true)
+                .build());
+  }
+
+  @Test
+  public void testScheduleWorkWhenExceedMaximumPoolSize() throws Exception {
+    CountDownLatch processStart1 = new CountDownLatch(1);
+    CountDownLatch processStop1 = new CountDownLatch(1);
+    CountDownLatch processStart2 = new CountDownLatch(1);
+    CountDownLatch processStop2 = new CountDownLatch(1);
+    CountDownLatch processStart3 = new CountDownLatch(1);
+    CountDownLatch processStop3 = new CountDownLatch(1);
+    Runnable m1 = createSleepProcessWorkFn(processStart1, processStop1);
+    Runnable m2 = createSleepProcessWorkFn(processStart2, processStop2);
+    Runnable m3 = createSleepProcessWorkFn(processStart3, processStop3);
+
+    executor.execute(m1, 1);
+    processStart1.await();
+    executor.execute(m2, 1);
+    processStart2.await();
+    // m1 and m2 have started and all threads are occupied so m3 will be 
queued and not executed.
+    executor.execute(m3, 1);
+    assertFalse(processStart3.await(1000, TimeUnit.MILLISECONDS));
+    assertFalse(executor.executorQueueIsEmpty());
+
+    // Stop m1 so there is an available thread for m3 to run.
+    processStop1.countDown();
+    processStart3.await();
+    // m3 started.
+    assertTrue(executor.executorQueueIsEmpty());
+    processStop2.countDown();
+    processStop3.countDown();
+    executor.shutdown();
+  }
+
+  @Test
+  public void testScheduleWorkWhenExceedMaximumBytesOutstanding() throws 
Exception {
+    CountDownLatch processStart1 = new CountDownLatch(1);
+    CountDownLatch processStop1 = new CountDownLatch(1);
+    CountDownLatch processStart2 = new CountDownLatch(1);
+    CountDownLatch processStop2 = new CountDownLatch(1);
+    Runnable m1 = createSleepProcessWorkFn(processStart1, processStop1);
+    Runnable m2 = createSleepProcessWorkFn(processStart2, processStop2);
+
+    executor.execute(m1, 10000000);
+    processStart1.await();
+    // m1 has started and reached the maximumBytesOutstanding. Though the 
executor has available
+    // threads, the new task will be blocked until the bytes are available.
+    // Start a new thread since executor.execute() is a blocking function.
+    Thread m2Runner =
+        new Thread(
+            () -> {
+              executor.execute(m2, 1000);
+            });
+    m2Runner.start();
+    assertFalse(processStart2.await(1000, TimeUnit.MILLISECONDS));
+    // m2 will wait for monitor instead of being queued.
+    assertEquals(Thread.State.WAITING, m2Runner.getState());
+    assertTrue(executor.executorQueueIsEmpty());
+
+    // Stop m1 so there are available bytes for m2 to run.
+    processStop1.countDown();
+    processStart2.await();
+    // m2 started.
+    assertEquals(Thread.State.TERMINATED, m2Runner.getState());
+    processStop2.countDown();
+    executor.shutdown();
+  }
+
+  @Test
+  public void testOverrideMaximumPoolSize() throws Exception {
+    CountDownLatch processStart1 = new CountDownLatch(1);
+    CountDownLatch processStart2 = new CountDownLatch(1);
+    CountDownLatch processStart3 = new CountDownLatch(1);
+    CountDownLatch stop = new CountDownLatch(1);
+    Runnable m1 = createSleepProcessWorkFn(processStart1, stop);
+    Runnable m2 = createSleepProcessWorkFn(processStart2, stop);
+    Runnable m3 = createSleepProcessWorkFn(processStart3, stop);
+
+    // Initial state.
+    assertEquals(0, executor.activeCount());
+    assertEquals(2, executor.getMaximumPoolSize());
+
+    // m1 and m2 are accepted.
+    executor.execute(m1, 1);
+    processStart1.await();
+    assertEquals(1, executor.activeCount());
+    executor.execute(m2, 1);
+    processStart2.await();
+    assertEquals(2, executor.activeCount());
+
+    // Max pool size was reached so new work is queued.
+    executor.execute(m3, 1);
+    assertFalse(processStart3.await(1000, TimeUnit.MILLISECONDS));
+
+    // Increase the max thread count
+    executor.setMaximumPoolSize(3, 103);
+    assertEquals(3, executor.getMaximumPoolSize());
+
+    // m3 is accepted
+    processStart3.await();
+    assertEquals(3, executor.activeCount());
+
+    stop.countDown();
+    executor.shutdown();
+  }
+
+  @Test
+  public void testRecordTotalTimeMaxActiveThreadsUsed() throws Exception {
+    CountDownLatch processStart1 = new CountDownLatch(1);
+    CountDownLatch processStart2 = new CountDownLatch(1);
+    CountDownLatch processStart3 = new CountDownLatch(1);
+    CountDownLatch stop = new CountDownLatch(1);
+    Runnable m1 = createSleepProcessWorkFn(processStart1, stop);
+    Runnable m2 = createSleepProcessWorkFn(processStart2, stop);
+    Runnable m3 = createSleepProcessWorkFn(processStart3, stop);
+
+    // Initial state.
+    assertEquals(0, executor.activeCount());
+    assertEquals(2, executor.getMaximumPoolSize());
+
+    // m1 and m2 are accepted.
+    executor.execute(m1, 1);
+    processStart1.await();
+    assertEquals(1, executor.activeCount());
+    executor.execute(m2, 1);
+    processStart2.await();
+    assertEquals(2, executor.activeCount());
+
+    // Max pool size was reached so no new work is accepted.
+    executor.execute(m3, 1);
+    assertFalse(processStart3.await(1000, TimeUnit.MILLISECONDS));
+
+    assertEquals(0l, executor.allThreadsActiveTime());
+    stop.countDown();
+    while (executor.activeCount() != 0) {
+      // Waiting for all threads to be ended.
+      Thread.sleep(200);
+    }
+    // Max pool size was reached so the allThreadsActiveTime() was updated.
+    assertThat(executor.allThreadsActiveTime(), greaterThan(0l));
+
+    executor.shutdown();
+  }
+
+  @Test
+  public void 
testRecordTotalTimeMaxActiveThreadsUsedWhenMaximumPoolSizeUpdated() throws 
Exception {
+    CountDownLatch processStart1 = new CountDownLatch(1);
+    CountDownLatch processStart2 = new CountDownLatch(1);
+    CountDownLatch processStart3 = new CountDownLatch(1);
+    CountDownLatch stop = new CountDownLatch(1);
+    Runnable m1 = createSleepProcessWorkFn(processStart1, stop);
+    Runnable m2 = createSleepProcessWorkFn(processStart2, stop);
+    Runnable m3 = createSleepProcessWorkFn(processStart3, stop);
+
+    // Initial state.
+    assertEquals(0, executor.activeCount());
+    assertEquals(2, executor.getMaximumPoolSize());
+
+    // m1 and m2 are accepted.
+    executor.execute(m1, 1);
+    processStart1.await();
+    assertEquals(1, executor.activeCount());
+    executor.execute(m2, 1);
+    processStart2.await();
+    assertEquals(2, executor.activeCount());
+
+    // Max pool size was reached so no new work is accepted.
+    executor.execute(m3, 1);
+    assertFalse(processStart3.await(1000, TimeUnit.MILLISECONDS));
+
+    assertEquals(0l, executor.allThreadsActiveTime());
+    // Increase the max thread count
+    executor.setMaximumPoolSize(5, 105);
+    stop.countDown();
+    while (executor.activeCount() != 0) {
+      // Waiting for all threads to be ended.
+      Thread.sleep(200);
+    }
+    // Max pool size was updated during execution but allThreadsActiveTime() 
was still recorded
+    // for the thread which reached the old max pool size.
+    assertThat(executor.allThreadsActiveTime(), greaterThan(0l));
+
+    executor.shutdown();
+  }
+
+  @Test
+  public void testRenderSummaryHtml() throws Exception {
+    String expectedSummaryHtml =
+        "Worker Threads: 0/2<br>/n"
+            + "Active Threads: 0<br>/n"
+            + "Work Queue Size: 0/102<br>/n"
+            + "Work Queue Bytes: 0/10000000<br>/n";
+    assertEquals(expectedSummaryHtml, executor.summaryHtml());
+  }
+}


Reply via email to