This is an automated email from the ASF dual-hosted git repository.
scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new f437a783e19 Tune maximum thread count for streaming dataflow worker
executor dynamically. (#30439)
f437a783e19 is described below
commit f437a783e191592917bd417495bc6317142e6a43
Author: Melody Shen <[email protected]>
AuthorDate: Fri Apr 5 02:12:23 2024 -0700
Tune maximum thread count for streaming dataflow worker executor
dynamically. (#30439)
Workers will read the StreamingScalingReportResponse from worker messages and
configure the executor pool size based on the specified value.
---
.../org/apache/beam/gradle/BeamModulePlugin.groovy | 2 +-
.../dataflow/worker/DataflowWorkUnitClient.java | 10 +-
.../runners/dataflow/worker/WorkUnitClient.java | 4 +-
.../harness/StreamingWorkerStatusReporter.java | 57 ++++-
.../dataflow/worker/util/BoundedQueueExecutor.java | 76 ++++--
.../worker/DataflowWorkUnitClientTest.java | 13 +-
.../harness/StreamingWorkerStatusReporterTest.java | 100 ++++++++
.../worker/util/BoundedQueueExecutorTest.java | 268 +++++++++++++++++++++
8 files changed, 503 insertions(+), 27 deletions(-)
diff --git
a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
index 8be8d73fbcb..d4a16325641 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
@@ -732,7 +732,7 @@ class BeamModulePlugin implements Plugin<Project> {
google_api_common :
"com.google.api:api-common", // google_cloud_platform_libraries_bom sets version
google_api_services_bigquery :
"com.google.apis:google-api-services-bigquery:v2-rev20240124-2.0.0", //
[bomupgrader] sets version
google_api_services_cloudresourcemanager :
"com.google.apis:google-api-services-cloudresourcemanager:v1-rev20240128-2.0.0",
// [bomupgrader] sets version
- google_api_services_dataflow :
"com.google.apis:google-api-services-dataflow:v1b3-rev20240113-$google_clients_version",
+ google_api_services_dataflow :
"com.google.apis:google-api-services-dataflow:v1b3-rev20240218-$google_clients_version",
google_api_services_healthcare :
"com.google.apis:google-api-services-healthcare:v1-rev20240130-$google_clients_version",
google_api_services_pubsub :
"com.google.apis:google-api-services-pubsub:v1-rev20220904-$google_clients_version",
google_api_services_storage :
"com.google.apis:google-api-services-storage:v1-rev20240205-2.0.0", //
[bomupgrader] sets version
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClient.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClient.java
index f3caa8d0f3a..af8e7dd50c9 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClient.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClient.java
@@ -39,6 +39,7 @@ import com.google.api.services.dataflow.model.WorkItem;
import com.google.api.services.dataflow.model.WorkItemServiceState;
import com.google.api.services.dataflow.model.WorkItemStatus;
import com.google.api.services.dataflow.model.WorkerMessage;
+import com.google.api.services.dataflow.model.WorkerMessageResponse;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
@@ -312,7 +313,8 @@ class DataflowWorkUnitClient implements WorkUnitClient {
* perworkermetrics with this path.
*/
@Override
- public void reportWorkerMessage(List<WorkerMessage> messages) throws
IOException {
+ public List<WorkerMessageResponse> reportWorkerMessage(List<WorkerMessage>
messages)
+ throws IOException {
SendWorkerMessagesRequest request =
new SendWorkerMessagesRequest()
.setLocation(options.getRegion())
@@ -327,6 +329,10 @@ class DataflowWorkUnitClient implements WorkUnitClient {
logger.warn("Worker Message response is null");
throw new IOException("Got null Worker Message response");
}
- // Currently no response is expected
+ if (result.getWorkerMessageResponses() == null) {
+ logger.debug("Worker Message response is empty.");
+ return Collections.emptyList();
+ }
+ return result.getWorkerMessageResponses();
}
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkUnitClient.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkUnitClient.java
index d75d91d0088..26b1dc55ead 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkUnitClient.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkUnitClient.java
@@ -23,6 +23,7 @@ import com.google.api.services.dataflow.model.WorkItem;
import com.google.api.services.dataflow.model.WorkItemServiceState;
import com.google.api.services.dataflow.model.WorkItemStatus;
import com.google.api.services.dataflow.model.WorkerMessage;
+import com.google.api.services.dataflow.model.WorkerMessageResponse;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
@@ -75,6 +76,7 @@ public interface WorkUnitClient {
* perworkermetrics with this path.
*
* @param msg the WorkerMessages to report
+ * @return a list of {@link WorkerMessageResponse}
*/
- void reportWorkerMessage(List<WorkerMessage> messages) throws IOException;
+ List<WorkerMessageResponse> reportWorkerMessage(List<WorkerMessage>
messages) throws IOException;
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java
index 409f0337eeb..8e950546ae6 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java
@@ -23,8 +23,10 @@ import com.google.api.services.dataflow.model.CounterUpdate;
import com.google.api.services.dataflow.model.PerStepNamespaceMetrics;
import com.google.api.services.dataflow.model.PerWorkerMetrics;
import com.google.api.services.dataflow.model.StreamingScalingReport;
+import com.google.api.services.dataflow.model.StreamingScalingReportResponse;
import com.google.api.services.dataflow.model.WorkItemStatus;
import com.google.api.services.dataflow.model.WorkerMessage;
+import com.google.api.services.dataflow.model.WorkerMessageResponse;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
@@ -34,6 +36,7 @@ import java.util.Optional;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.function.Supplier;
@@ -70,6 +73,8 @@ public final class StreamingWorkerStatusReporter {
private static final String GLOBAL_WORKER_UPDATE_REPORTER_THREAD =
"GlobalWorkerUpdates";
private final boolean publishCounters;
+ private final int initialMaxThreadCount;
+ private final int initialMaxBundlesOutstanding;
private final WorkUnitClient dataflowServiceClient;
private final Supplier<Long> windmillQuotaThrottleTime;
private final Supplier<Collection<StageInfo>> allStageInfo;
@@ -78,6 +83,7 @@ public final class StreamingWorkerStatusReporter {
private final MemoryMonitor memoryMonitor;
private final BoundedQueueExecutor workExecutor;
private final AtomicLong previousTimeAtMaxThreads;
+ private final AtomicInteger maxThreadCountOverride;
private final ScheduledExecutorService globalWorkerUpdateReporter;
private final ScheduledExecutorService workerMessageReporter;
@@ -99,7 +105,10 @@ public final class StreamingWorkerStatusReporter {
this.streamingCounters = streamingCounters;
this.memoryMonitor = memoryMonitor;
this.workExecutor = workExecutor;
+ this.initialMaxThreadCount = workExecutor.getMaximumPoolSize();
+ this.initialMaxBundlesOutstanding =
workExecutor.maximumElementsOutstanding();
this.previousTimeAtMaxThreads = new AtomicLong();
+ this.maxThreadCountOverride = new AtomicInteger();
this.globalWorkerUpdateReporter =
executorFactory.apply(GLOBAL_WORKER_UPDATE_REPORTER_THREAD);
this.workerMessageReporter =
executorFactory.apply(WORKER_MESSAGE_REPORTER_THREAD);
}
@@ -299,9 +308,12 @@ public final class StreamingWorkerStatusReporter {
}
}
- private void reportPeriodicWorkerMessage() {
+ @VisibleForTesting
+ public void reportPeriodicWorkerMessage() {
try {
- dataflowServiceClient.reportWorkerMessage(createWorkerMessage());
+ List<WorkerMessageResponse> workerMessageResponses =
+ dataflowServiceClient.reportWorkerMessage(createWorkerMessage());
+
readAndSaveWorkerMessageResponseForStreamingScalingReportResponse(workerMessageResponses);
} catch (IOException e) {
LOG.warn("Failed to send worker messages", e);
} catch (Exception e) {
@@ -346,6 +358,47 @@ public final class StreamingWorkerStatusReporter {
dataflowServiceClient.createWorkerMessageFromPerWorkerMetrics(perWorkerMetrics));
}
+ private void
readAndSaveWorkerMessageResponseForStreamingScalingReportResponse(
+ List<WorkerMessageResponse> responses) {
+ Optional<StreamingScalingReportResponse> streamingScalingReportResponse =
Optional.empty();
+ for (WorkerMessageResponse response : responses) {
+ if (response.getStreamingScalingReportResponse() != null) {
+ streamingScalingReportResponse =
Optional.of(response.getStreamingScalingReportResponse());
+ }
+ }
+ if (streamingScalingReportResponse.isPresent()) {
+ int oldMaximumThreadCount = getMaxThreads();
+
maxThreadCountOverride.set(streamingScalingReportResponse.get().getMaximumThreadCount());
+ int newMaximumThreadCount = getMaxThreads();
+ if (newMaximumThreadCount != oldMaximumThreadCount) {
+ LOG.info(
+ "Setting maximum thread count to {}, old value is {}",
+ newMaximumThreadCount,
+ oldMaximumThreadCount);
+ workExecutor.setMaximumPoolSize(newMaximumThreadCount,
getMaxBundlesOutstanding());
+ }
+ }
+ }
+
+ private int getMaxThreads() {
+ int currentMaxThreadCountOverride = maxThreadCountOverride.get();
+ if (currentMaxThreadCountOverride != 0) {
+ return currentMaxThreadCountOverride;
+ }
+ return initialMaxThreadCount;
+ }
+
+ private int getMaxBundlesOutstanding() {
+ int currentMaxThreadCountOverride = maxThreadCountOverride.get();
+ if (currentMaxThreadCountOverride != 0) {
+ return currentMaxThreadCountOverride + 100;
+ }
+ if (initialMaxBundlesOutstanding > 0) {
+ return initialMaxBundlesOutstanding;
+ }
+ return getMaxThreads() + 100;
+ }
+
@VisibleForTesting
public void reportPeriodicWorkerUpdates() {
updateVMMetrics();
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java
index f7f6fd91a8c..9a481169350 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java
@@ -21,7 +21,7 @@ import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicInteger;
+import javax.annotation.concurrent.GuardedBy;
import org.apache.beam.runners.dataflow.worker.streaming.Work;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor.Guard;
@@ -32,15 +32,26 @@ import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurren
})
public class BoundedQueueExecutor {
private final ThreadPoolExecutor executor;
- private final int maximumElementsOutstanding;
private final long maximumBytesOutstanding;
- private final int maximumPoolSize;
+ // Used to guard elementsOutstanding and bytesOutstanding.
private final Monitor monitor = new Monitor();
private int elementsOutstanding = 0;
private long bytesOutstanding = 0;
- private final AtomicInteger activeCount = new AtomicInteger();
+
+ @GuardedBy("this")
+ private int maximumElementsOutstanding;
+
+ @GuardedBy("this")
+ private int activeCount;
+
+ @GuardedBy("this")
+ private int maximumPoolSize;
+
+ @GuardedBy("this")
private long startTimeMaxActiveThreadsUsed;
+
+ @GuardedBy("this")
private long totalTimeMaxActiveThreadsUsed;
public BoundedQueueExecutor(
@@ -62,8 +73,8 @@ public class BoundedQueueExecutor {
@Override
protected void beforeExecute(Thread t, Runnable r) {
super.beforeExecute(t, r);
- synchronized (this) {
- if (activeCount.getAndIncrement() >= maximumPoolSize - 1) {
+ synchronized (BoundedQueueExecutor.this) {
+ if (++activeCount >= maximumPoolSize &&
startTimeMaxActiveThreadsUsed == 0) {
startTimeMaxActiveThreadsUsed = System.currentTimeMillis();
}
}
@@ -72,8 +83,8 @@ public class BoundedQueueExecutor {
@Override
protected void afterExecute(Runnable r, Throwable t) {
super.afterExecute(r, t);
- synchronized (this) {
- if (activeCount.getAndDecrement() == maximumPoolSize) {
+ synchronized (BoundedQueueExecutor.this) {
+ if (--activeCount < maximumPoolSize &&
startTimeMaxActiveThreadsUsed > 0) {
totalTimeMaxActiveThreadsUsed +=
(System.currentTimeMillis() -
startTimeMaxActiveThreadsUsed);
startTimeMaxActiveThreadsUsed = 0;
@@ -95,16 +106,31 @@ public class BoundedQueueExecutor {
public boolean isSatisfied() {
return elementsOutstanding == 0
|| (bytesAvailable() >= workBytes
- && elementsOutstanding < maximumElementsOutstanding);
+ && elementsOutstanding < maximumElementsOutstanding());
}
});
- executeLockHeld(work, workBytes);
+ executeMonitorHeld(work, workBytes);
}
// Forcibly add something to the queue, ignoring the length limit.
public void forceExecute(Runnable work, long workBytes) {
monitor.enter();
- executeLockHeld(work, workBytes);
+ executeMonitorHeld(work, workBytes);
+ }
+
+ // Set the maximum/core pool size of the executor.
+ public synchronized void setMaximumPoolSize(int maximumPoolSize, int
maximumElementsOutstanding) {
+ // For ThreadPoolExecutor, the maximum pool size should always greater
than or equal to core
+ // pool size.
+ if (maximumPoolSize > executor.getCorePoolSize()) {
+ executor.setMaximumPoolSize(maximumPoolSize);
+ executor.setCorePoolSize(maximumPoolSize);
+ } else {
+ executor.setCorePoolSize(maximumPoolSize);
+ executor.setMaximumPoolSize(maximumPoolSize);
+ }
+ this.maximumPoolSize = maximumPoolSize;
+ this.maximumElementsOutstanding = maximumElementsOutstanding;
}
public void shutdown() throws InterruptedException {
@@ -118,31 +144,41 @@ public class BoundedQueueExecutor {
return executor.getQueue().isEmpty();
}
- public long allThreadsActiveTime() {
+ public synchronized long allThreadsActiveTime() {
return totalTimeMaxActiveThreadsUsed;
}
- public int activeCount() {
- return activeCount.intValue();
+ public synchronized int activeCount() {
+ return activeCount;
}
public long bytesOutstanding() {
- return bytesOutstanding;
+ monitor.enter();
+ try {
+ return bytesOutstanding;
+ } finally {
+ monitor.leave();
+ }
}
public int elementsOutstanding() {
- return elementsOutstanding;
+ monitor.enter();
+ try {
+ return elementsOutstanding;
+ } finally {
+ monitor.leave();
+ }
}
public long maximumBytesOutstanding() {
return maximumBytesOutstanding;
}
- public int maximumElementsOutstanding() {
+ public synchronized int maximumElementsOutstanding() {
return maximumElementsOutstanding;
}
- public final int getMaximumPoolSize() {
+ public synchronized int getMaximumPoolSize() {
return maximumPoolSize;
}
@@ -163,7 +199,7 @@ public class BoundedQueueExecutor {
builder.append("Work Queue Size: ");
builder.append(elementsOutstanding);
builder.append("/");
- builder.append(maximumElementsOutstanding);
+ builder.append(maximumElementsOutstanding());
builder.append("<br>/n");
builder.append("Work Queue Bytes: ");
@@ -178,7 +214,7 @@ public class BoundedQueueExecutor {
}
}
- private void executeLockHeld(Runnable work, long workBytes) {
+ private void executeMonitorHeld(Runnable work, long workBytes) {
bytesOutstanding += workBytes;
++elementsOutstanding;
monitor.leave();
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClientTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClientTest.java
index fac56890f49..85d79e6be3c 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClientTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClientTest.java
@@ -34,10 +34,13 @@ import
com.google.api.services.dataflow.model.SendWorkerMessagesRequest;
import com.google.api.services.dataflow.model.SendWorkerMessagesResponse;
import com.google.api.services.dataflow.model.SeqMapTask;
import com.google.api.services.dataflow.model.StreamingScalingReport;
+import com.google.api.services.dataflow.model.StreamingScalingReportResponse;
import com.google.api.services.dataflow.model.WorkItem;
import com.google.api.services.dataflow.model.WorkerMessage;
+import com.google.api.services.dataflow.model.WorkerMessageResponse;
import java.io.IOException;
import java.util.Collections;
+import java.util.List;
import java.util.Optional;
import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions;
import
org.apache.beam.runners.dataflow.worker.logging.DataflowWorkerLoggingMDC;
@@ -253,6 +256,12 @@ public class DataflowWorkUnitClientTest {
MockLowLevelHttpResponse response = new MockLowLevelHttpResponse();
response.setContentType(Json.MEDIA_TYPE);
SendWorkerMessagesResponse workerMessage = new
SendWorkerMessagesResponse();
+ StreamingScalingReportResponse streamingScalingReportResponse =
+ new StreamingScalingReportResponse().setMaximumThreadCount(10);
+ WorkerMessageResponse workerMessageResponse =
+ new WorkerMessageResponse()
+ .setStreamingScalingReportResponse(streamingScalingReportResponse);
+
workerMessage.setWorkerMessageResponses(Collections.singletonList(workerMessageResponse));
workerMessage.setFactory(Transport.getJsonFactory());
response.setContent(workerMessage.toPrettyString());
@@ -271,12 +280,14 @@ public class DataflowWorkUnitClientTest {
.setMaximumBundleCount(5)
.setMaximumBytes(6L);
WorkerMessage msg =
client.createWorkerMessageFromStreamingScalingReport(activeThreadsReport);
- client.reportWorkerMessage(Collections.singletonList(msg));
+ List<WorkerMessageResponse> responses =
+ client.reportWorkerMessage(Collections.singletonList(msg));
SendWorkerMessagesRequest actualRequest =
Transport.getJsonFactory()
.fromString(request.getContentAsString(),
SendWorkerMessagesRequest.class);
assertEquals(ImmutableList.of(msg), actualRequest.getWorkerMessages());
+ assertEquals(ImmutableList.of(workerMessageResponse), responses);
}
@Test
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporterTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporterTest.java
new file mode 100644
index 00000000000..bdf0f0031d6
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporterTest.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming.harness;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.api.services.dataflow.model.StreamingScalingReportResponse;
+import com.google.api.services.dataflow.model.WorkerMessageResponse;
+import java.util.Collections;
+import java.util.concurrent.Executors;
+import org.apache.beam.runners.dataflow.worker.WorkUnitClient;
+import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
+import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.FailureTracker;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mockito;
+
+@RunWith(JUnit4.class)
+public class StreamingWorkerStatusReporterTest {
+ private final long DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME = 1000;
+
+ private BoundedQueueExecutor mockExecutor;
+ private WorkUnitClient mockWorkUnitClient;
+ private FailureTracker mockFailureTracker;
+ private MemoryMonitor mockMemoryMonitor;
+
+ @Before
+ public void setUp() {
+ this.mockExecutor = mock(BoundedQueueExecutor.class);
+ this.mockWorkUnitClient = mock(WorkUnitClient.class);
+ this.mockFailureTracker = mock(FailureTracker.class);
+ this.mockMemoryMonitor = mock(MemoryMonitor.class);
+ }
+
+ @Test
+ public void testOverrideMaximumThreadCount() throws Exception {
+ StreamingWorkerStatusReporter reporter =
+ StreamingWorkerStatusReporter.forTesting(
+ true,
+ mockWorkUnitClient,
+ () -> DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME,
+ () -> Collections.emptyList(),
+ mockFailureTracker,
+ StreamingCounters.create(),
+ mockMemoryMonitor,
+ mockExecutor,
+ (threadName) -> Executors.newSingleThreadScheduledExecutor());
+ StreamingScalingReportResponse streamingScalingReportResponse =
+ new StreamingScalingReportResponse().setMaximumThreadCount(10);
+ WorkerMessageResponse workerMessageResponse =
+ new WorkerMessageResponse()
+ .setStreamingScalingReportResponse(streamingScalingReportResponse);
+ when(mockWorkUnitClient.reportWorkerMessage(any()))
+ .thenReturn(Collections.singletonList(workerMessageResponse));
+ reporter.reportPeriodicWorkerMessage();
+ verify(mockExecutor).setMaximumPoolSize(10, 110);
+ }
+
+ @Test
+ public void testHandleEmptyWorkerMessageResponse() throws Exception {
+ StreamingWorkerStatusReporter reporter =
+ StreamingWorkerStatusReporter.forTesting(
+ true,
+ mockWorkUnitClient,
+ () -> DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME,
+ () -> Collections.emptyList(),
+ mockFailureTracker,
+ StreamingCounters.create(),
+ mockMemoryMonitor,
+ mockExecutor,
+ (threadName) -> Executors.newSingleThreadScheduledExecutor());
+ WorkerMessageResponse workerMessageResponse = new WorkerMessageResponse();
+ when(mockWorkUnitClient.reportWorkerMessage(any()))
+ .thenReturn(Collections.singletonList(workerMessageResponse));
+ reporter.reportPeriodicWorkerMessage();
+ verify(mockExecutor, Mockito.times(0)).setMaximumPoolSize(anyInt(),
anyInt());
+ }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java
new file mode 100644
index 00000000000..c0620952ef9
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.util;
+
+import static org.hamcrest.Matchers.greaterThan;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link
org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor}. */
+@RunWith(JUnit4.class)
+// TODO(https://github.com/apache/beam/issues/21230): Remove when new version
of errorprone is
+// released (2.11.0)
+@SuppressWarnings("unused")
+public class BoundedQueueExecutorTest {
+ @Rule public transient Timeout globalTimeout = Timeout.seconds(300);
+ private static final long MAXIMUM_BYTES_OUTSTANDING = 10000000;
+ private static final int DEFAULT_MAX_THREADS = 2;
+ private static final int DEFAULT_THREAD_EXPIRATION_SEC = 60;
+
+ private BoundedQueueExecutor executor;
+
+ private Runnable createSleepProcessWorkFn(CountDownLatch start,
CountDownLatch stop) {
+ Runnable runnable =
+ () -> {
+ start.countDown();
+ try {
+ stop.await();
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ };
+ return runnable;
+ }
+
+ @Before
+ public void setUp() {
+ this.executor =
+ new BoundedQueueExecutor(
+ DEFAULT_MAX_THREADS,
+ DEFAULT_THREAD_EXPIRATION_SEC,
+ TimeUnit.SECONDS,
+ DEFAULT_MAX_THREADS + 100,
+ MAXIMUM_BYTES_OUTSTANDING,
+ new ThreadFactoryBuilder()
+ .setNameFormat("DataflowWorkUnits-%d")
+ .setDaemon(true)
+ .build());
+ }
+
+ @Test
+ public void testScheduleWorkWhenExceedMaximumPoolSize() throws Exception {
+ CountDownLatch processStart1 = new CountDownLatch(1);
+ CountDownLatch processStop1 = new CountDownLatch(1);
+ CountDownLatch processStart2 = new CountDownLatch(1);
+ CountDownLatch processStop2 = new CountDownLatch(1);
+ CountDownLatch processStart3 = new CountDownLatch(1);
+ CountDownLatch processStop3 = new CountDownLatch(1);
+ Runnable m1 = createSleepProcessWorkFn(processStart1, processStop1);
+ Runnable m2 = createSleepProcessWorkFn(processStart2, processStop2);
+ Runnable m3 = createSleepProcessWorkFn(processStart3, processStop3);
+
+ executor.execute(m1, 1);
+ processStart1.await();
+ executor.execute(m2, 1);
+ processStart2.await();
+ // m1 and m2 have started and all threads are occupied so m3 will be
queued and not executed.
+ executor.execute(m3, 1);
+ assertFalse(processStart3.await(1000, TimeUnit.MILLISECONDS));
+ assertFalse(executor.executorQueueIsEmpty());
+
+ // Stop m1 so there is an available thread for m3 to run.
+ processStop1.countDown();
+ processStart3.await();
+ // m3 started.
+ assertTrue(executor.executorQueueIsEmpty());
+ processStop2.countDown();
+ processStop3.countDown();
+ executor.shutdown();
+ }
+
+ @Test
+ public void testScheduleWorkWhenExceedMaximumBytesOutstanding() throws
Exception {
+ CountDownLatch processStart1 = new CountDownLatch(1);
+ CountDownLatch processStop1 = new CountDownLatch(1);
+ CountDownLatch processStart2 = new CountDownLatch(1);
+ CountDownLatch processStop2 = new CountDownLatch(1);
+ Runnable m1 = createSleepProcessWorkFn(processStart1, processStop1);
+ Runnable m2 = createSleepProcessWorkFn(processStart2, processStop2);
+
+ executor.execute(m1, 10000000);
+ processStart1.await();
+ // m1 has started and reached the maximumBytesOutstanding. Though the
executor has available
+ // threads, the new task will be blocked until the bytes are available.
+ // Start a new thread since executor.execute() is a blocking function.
+ Thread m2Runner =
+ new Thread(
+ () -> {
+ executor.execute(m2, 1000);
+ });
+ m2Runner.start();
+ assertFalse(processStart2.await(1000, TimeUnit.MILLISECONDS));
+ // m2 will wait for monitor instead of being queued.
+ assertEquals(Thread.State.WAITING, m2Runner.getState());
+ assertTrue(executor.executorQueueIsEmpty());
+
+ // Stop m1 so there are available bytes for m2 to run.
+ processStop1.countDown();
+ processStart2.await();
+ // m2 started.
+ assertEquals(Thread.State.TERMINATED, m2Runner.getState());
+ processStop2.countDown();
+ executor.shutdown();
+ }
+
+ @Test
+ public void testOverrideMaximumPoolSize() throws Exception {
+ CountDownLatch processStart1 = new CountDownLatch(1);
+ CountDownLatch processStart2 = new CountDownLatch(1);
+ CountDownLatch processStart3 = new CountDownLatch(1);
+ CountDownLatch stop = new CountDownLatch(1);
+ Runnable m1 = createSleepProcessWorkFn(processStart1, stop);
+ Runnable m2 = createSleepProcessWorkFn(processStart2, stop);
+ Runnable m3 = createSleepProcessWorkFn(processStart3, stop);
+
+ // Initial state.
+ assertEquals(0, executor.activeCount());
+ assertEquals(2, executor.getMaximumPoolSize());
+
+ // m1 and m2 are accepted.
+ executor.execute(m1, 1);
+ processStart1.await();
+ assertEquals(1, executor.activeCount());
+ executor.execute(m2, 1);
+ processStart2.await();
+ assertEquals(2, executor.activeCount());
+
+ // Max pool size was reached so new work is queued.
+ executor.execute(m3, 1);
+ assertFalse(processStart3.await(1000, TimeUnit.MILLISECONDS));
+
+ // Increase the max thread count
+ executor.setMaximumPoolSize(3, 103);
+ assertEquals(3, executor.getMaximumPoolSize());
+
+ // m3 is accepted
+ processStart3.await();
+ assertEquals(3, executor.activeCount());
+
+ stop.countDown();
+ executor.shutdown();
+ }
+
+ @Test
+ public void testRecordTotalTimeMaxActiveThreadsUsed() throws Exception {
+ CountDownLatch processStart1 = new CountDownLatch(1);
+ CountDownLatch processStart2 = new CountDownLatch(1);
+ CountDownLatch processStart3 = new CountDownLatch(1);
+ CountDownLatch stop = new CountDownLatch(1);
+ Runnable m1 = createSleepProcessWorkFn(processStart1, stop);
+ Runnable m2 = createSleepProcessWorkFn(processStart2, stop);
+ Runnable m3 = createSleepProcessWorkFn(processStart3, stop);
+
+ // Initial state.
+ assertEquals(0, executor.activeCount());
+ assertEquals(2, executor.getMaximumPoolSize());
+
+ // m1 and m2 are accepted.
+ executor.execute(m1, 1);
+ processStart1.await();
+ assertEquals(1, executor.activeCount());
+ executor.execute(m2, 1);
+ processStart2.await();
+ assertEquals(2, executor.activeCount());
+
+ // Max pool size was reached so no new work is accepted.
+ executor.execute(m3, 1);
+ assertFalse(processStart3.await(1000, TimeUnit.MILLISECONDS));
+
+ assertEquals(0l, executor.allThreadsActiveTime());
+ stop.countDown();
+ while (executor.activeCount() != 0) {
+ // Waiting for all threads to be ended.
+ Thread.sleep(200);
+ }
+ // Max pool size was reached so the allThreadsActiveTime() was updated.
+ assertThat(executor.allThreadsActiveTime(), greaterThan(0l));
+
+ executor.shutdown();
+ }
+
+ @Test
+ public void
testRecordTotalTimeMaxActiveThreadsUsedWhenMaximumPoolSizeUpdated() throws
Exception {
+ CountDownLatch processStart1 = new CountDownLatch(1);
+ CountDownLatch processStart2 = new CountDownLatch(1);
+ CountDownLatch processStart3 = new CountDownLatch(1);
+ CountDownLatch stop = new CountDownLatch(1);
+ Runnable m1 = createSleepProcessWorkFn(processStart1, stop);
+ Runnable m2 = createSleepProcessWorkFn(processStart2, stop);
+ Runnable m3 = createSleepProcessWorkFn(processStart3, stop);
+
+ // Initial state.
+ assertEquals(0, executor.activeCount());
+ assertEquals(2, executor.getMaximumPoolSize());
+
+ // m1 and m2 are accepted.
+ executor.execute(m1, 1);
+ processStart1.await();
+ assertEquals(1, executor.activeCount());
+ executor.execute(m2, 1);
+ processStart2.await();
+ assertEquals(2, executor.activeCount());
+
+ // Max pool size was reached so no new work is accepted.
+ executor.execute(m3, 1);
+ assertFalse(processStart3.await(1000, TimeUnit.MILLISECONDS));
+
+ assertEquals(0l, executor.allThreadsActiveTime());
+ // Increase the max thread count
+ executor.setMaximumPoolSize(5, 105);
+ stop.countDown();
+ while (executor.activeCount() != 0) {
+ // Waiting for all threads to be ended.
+ Thread.sleep(200);
+ }
+ // Max pool size was updated during execution but allThreadsActiveTime()
was still recorded
+ // for the thread which reached the old max pool size.
+ assertThat(executor.allThreadsActiveTime(), greaterThan(0l));
+
+ executor.shutdown();
+ }
+
+ @Test
+ public void testRenderSummaryHtml() throws Exception {
+ String expectedSummaryHtml =
+ "Worker Threads: 0/2<br>/n"
+ + "Active Threads: 0<br>/n"
+ + "Work Queue Size: 0/102<br>/n"
+ + "Work Queue Bytes: 0/10000000<br>/n";
+ assertEquals(expectedSummaryHtml, executor.summaryHtml());
+ }
+}