This is an automated email from the ASF dual-hosted git repository.
scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 93fb2045ea1 Add interfaces for direct path, and StreamingEngineClient (#28835)
93fb2045ea1 is described below
commit 93fb2045ea1ddf42e5da49048302390a2db6e492
Author: martin trieu
AuthorDate: Thu Nov 30 02:19:13 2023 -0800
Add interfaces for direct path, and StreamingEngineClient (#28835)
Add interfaces/classes for direct path:
ProcessWorkItemClient
Exposed to WorkItemProcessor so that GetData and CommitWork stream RPCs can be
routed to the same backend workers that GetWork was called on (currently
StreamingDataflowWorker#process).
WorkItemProcessor
Replaces WorkItemReceiver; it is the same, but takes and exposes a
ProcessWorkItemClient instead of a WorkItem. Since ProcessWorkItemClient needs a
way to get data for work, refresh work, get side input data, and commit work,
the place where it is created (GrpcGetWorkStream) needs to be modified to accept
a GetDataStream (for keyed/state data), a GetDataStream (for global side input
data), and a CommitWorkStream.
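The routing idea behind ProcessWorkItemClient and WorkItemProcessor can be sketched in plain Java. This is a hypothetical illustration, not the classes added by this commit: `DataStream`, `CommitStream`, `DirectWorkItemClient`, and `DirectWorkProcessor` are invented names, and a string token stands in for the real WorkItem. The processor receives a client that is already bound to the streams of the backend worker that handed out the work, so follow-up GetData and CommitWork calls cannot be routed to a different worker.

```java
// Hypothetical stand-in for a GetData stream connected to one backend worker.
interface DataStream {
  String backendId();

  String getKeyedData(String key);
}

// Hypothetical stand-in for a CommitWork stream connected to one backend worker.
interface CommitStream {
  String backendId();

  void commit(String workToken);
}

// Binds one unit of work to the streams of the backend worker that handed it out,
// so every follow-up RPC for that work goes back to the same worker.
final class DirectWorkItemClient {
  private final String workToken;
  private final DataStream dataStream;
  private final CommitStream commitStream;

  DirectWorkItemClient(String workToken, DataStream dataStream, CommitStream commitStream) {
    // Reject mixed routing up front instead of failing later on the wrong backend.
    if (!dataStream.backendId().equals(commitStream.backendId())) {
      throw new IllegalArgumentException(
          "GetData and CommitWork streams must target the same backend");
    }
    this.workToken = workToken;
    this.dataStream = dataStream;
    this.commitStream = commitStream;
  }

  String workToken() {
    return workToken;
  }

  String getData(String key) {
    return dataStream.getKeyedData(key);
  }

  void commit() {
    commitStream.commit(workToken);
  }
}

// Hypothetical analogue of WorkItemProcessor: it receives the bound client, not a bare work item.
@FunctionalInterface
interface DirectWorkProcessor {
  void processWork(String computationId, DirectWorkItemClient client);
}
```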
GetWorkBudget
A struct that models the item and byte budgets for how much work a user worker
can handle. It is passed in GetWorkRequest(s) to Windmill to control how many
items/bytes of work are returned.
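A minimal sketch of an item/byte budget, assuming the operations that appear in the diff below (`noBudget()`, `apply`, `subtract`) and its note that budgets cannot hold negative values. `Budget` is a hypothetical, simplified stand-in for GetWorkBudget; clamping negative results to zero is this sketch's assumption about how that rule is enforced.

```java
import java.util.Objects;

// Hypothetical, simplified item/byte budget. Instances are immutable and never negative:
// any operation that would go below zero is clamped to zero.
final class Budget {
  private static final Budget NO_BUDGET = new Budget(0, 0);

  private final long items;
  private final long bytes;

  private Budget(long items, long bytes) {
    this.items = items;
    this.bytes = bytes;
  }

  static Budget of(long items, long bytes) {
    return new Budget(Math.max(0, items), Math.max(0, bytes));
  }

  static Budget noBudget() {
    return NO_BUDGET;
  }

  long items() {
    return items;
  }

  long bytes() {
    return bytes;
  }

  // Adds signed deltas, clamping each dimension at zero.
  Budget apply(long itemsDelta, long bytesDelta) {
    return of(items + itemsDelta, bytes + bytesDelta);
  }

  // Adds another budget to this one.
  Budget apply(Budget other) {
    return apply(other.items, other.bytes);
  }

  // Removes items/bytes, e.g. once a work item has been fully received.
  Budget subtract(long itemsToRemove, long bytesToRemove) {
    return apply(-itemsToRemove, -bytesToRemove);
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof Budget)) {
      return false;
    }
    Budget other = (Budget) o;
    return items == other.items && bytes == other.bytes;
  }

  @Override
  public int hashCode() {
    return Objects.hash(items, bytes);
  }

  @Override
  public String toString() {
    return "Budget{items=" + items + ", bytes=" + bytes + "}";
  }
}
```

Because each instance is immutable, a stream can keep its budget in an `AtomicReference` and update it with `updateAndGet`, which is how GrpcDirectGetWorkStream tracks its in-flight budget.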
GetWorkBudgetDistributor
Given a set of WindmillStreamSender(s) and GetWorkBudget, distributes the
budgets to the WindmillStreamSender(s) in some manner.
EvenGetWorkBudgetDistributor
A GetWorkBudgetDistributor implementation that distributes the budget evenly.
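One way an even distributor could work, sketched with hypothetical `BudgetSink`, `BudgetDistributor`, and `EvenBudgetDistributor` types in place of WindmillStreamSender and the real interfaces. How the actual EvenGetWorkBudgetDistributor rounds is not described here; this sketch assumes the integer-division remainder is handed out one unit at a time so that the shares always sum to the total.

```java
import java.util.List;

// Hypothetical stand-in for anything that can be handed a budget (e.g. a stream sender).
interface BudgetSink {
  void setBudget(long items, long bytes);
}

// Hypothetical analogue of GetWorkBudgetDistributor.
interface BudgetDistributor {
  void distribute(List<? extends BudgetSink> sinks, long totalItems, long totalBytes);
}

// Splits the totals evenly. The remainder of the integer division is handed out one unit
// at a time to the first sinks, so the shares always add up to exactly the totals.
final class EvenBudgetDistributor implements BudgetDistributor {
  @Override
  public void distribute(List<? extends BudgetSink> sinks, long totalItems, long totalBytes) {
    int count = sinks.size();
    if (count == 0) {
      return; // Nothing to distribute to.
    }
    for (int i = 0; i < count; i++) {
      sinks.get(i).setBudget(share(totalItems, count, i), share(totalBytes, count, i));
    }
  }

  private static long share(long total, int count, int index) {
    long base = total / count;
    long remainder = total % count;
    return base + (index < remainder ? 1 : 0);
  }
}
```

For example, 10 items over 3 senders yields shares of 4, 3, and 3.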
WindmillStreamSender
When the Grpc*Stream(s) are created, they immediately start the underlying gRPC
stream (startStream is called, and has protected access). To be able to assign
budgets and get the streams ready to be started later (similar to
GetWorkClientSender), WindmillStreamSender wraps the 3 WorkItem API RPC
streams and exposes startStream and closeAllStreams to manage the underlying
streams. Once the streams are started, they are cached (via thread-safe
memoization). Once certain endpoints are stale [...]
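The deferred start and thread-safe memoization described above can be sketched as follows. `LazyStreamSender`, `StartedStream`, and `Memoized` are hypothetical names; the diff itself relies on Guava's `Suppliers.memoize`, and the hand-written double-checked-locking supplier here only keeps the sketch dependency-free.

```java
import java.util.List;
import java.util.function.Supplier;

// Thread-safe memoizing supplier using double-checked locking on a volatile field.
final class Memoized<T> implements Supplier<T> {
  private final Supplier<T> delegate;
  private volatile T value;

  Memoized(Supplier<T> delegate) {
    this.delegate = delegate;
  }

  @Override
  public T get() {
    T result = value;
    if (result == null) {
      synchronized (this) {
        result = value;
        if (result == null) {
          result = delegate.get(); // Runs at most once as long as the delegate returns non-null.
          value = result;
        }
      }
    }
    return result;
  }

  boolean isInitialized() {
    return value != null;
  }
}

// Hypothetical handle for a stream that has been started.
interface StartedStream {
  void close();
}

// Hypothetical analogue of WindmillStreamSender: it holds stream factories, but creates
// (and thereby starts) the streams only on the first start() call.
final class LazyStreamSender {
  private final Memoized<StartedStream> getWorkStream;
  private final Memoized<StartedStream> getDataStream;
  private final Memoized<StartedStream> commitWorkStream;

  LazyStreamSender(
      Supplier<StartedStream> getWork,
      Supplier<StartedStream> getData,
      Supplier<StartedStream> commitWork) {
    this.getWorkStream = new Memoized<>(getWork);
    this.getDataStream = new Memoized<>(getData);
    this.commitWorkStream = new Memoized<>(commitWork);
  }

  // Idempotent: repeated calls reuse the cached streams.
  void start() {
    getWorkStream.get();
    getDataStream.get();
    commitWorkStream.get();
  }

  // Closes only the streams that were actually started.
  void closeAll() {
    for (Memoized<StartedStream> stream : List.of(getWorkStream, getDataStream, commitWorkStream)) {
      if (stream.isInitialized()) {
        stream.get().close();
      }
    }
  }
}
```

Deferring creation is what lets budgets be assigned before any RPC is opened: constructing the sender does not start a stream, and only the first `start()` call does.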
DispatcherClient
Manages and vends out stubs and the dispatcher.
Thread-safe via synchronization on reads and writes.
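A minimal sketch of the "synchronize both reads and writes" approach, using a hypothetical `StubHolder` class with a generic stub type. Picking a stub at random is an assumption made for illustration; the commit message does not say how a stub is selected.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical, simplified dispatcher-style holder: the stub set can be replaced (e.g. on new
// worker metadata) while other threads read from it. Both paths lock the same monitor, so a
// reader never observes a half-replaced list.
final class StubHolder<S> {
  private final List<S> stubs = new ArrayList<>();
  private final Random random;

  StubHolder(Random random) {
    this.random = random;
  }

  synchronized void replaceStubs(List<S> newStubs) {
    stubs.clear();
    stubs.addAll(newStubs);
  }

  // Returns a randomly chosen stub to spread load across endpoints (illustrative policy).
  synchronized S getStub() {
    if (stubs.isEmpty()) {
      throw new IllegalStateException("No stubs available yet");
    }
    return stubs.get(random.nextInt(stubs.size()));
  }

  synchronized int size() {
    return stubs.size();
  }
}
```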
Add StreamingEngineClient
Manages the available backend Windmill workers via the GetWorkerMetadata stream,
which is never closed. WorkerMetadata updates are submitted to a single-threaded
executor, which consumes them and updates the StreamingEngineClient's internal
connection state.
Given a total budget, divides it amongst the available backend Windmill
workers (represented as Endpoints, Connections, and WindmillStreamSenders) and
starts GetWorkStream(s). Closes streams via
WindmillStreamSender#closeAllStreams when the endpoint for a stream is no longer
present in the updated worker metadata.
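A sketch of how worker metadata updates might be applied on a single-threaded executor, with a hypothetical `ConnectionManager` and `EndpointSender` standing in for StreamingEngineClient and WindmillStreamSender, and plain strings standing in for endpoints. Stale endpoints have their streams closed, new endpoints get a sender, and existing ones are reused; the budget redistribution that would follow is only noted in a comment.

```java
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Function;

// Hypothetical stand-in for a per-endpoint sender (e.g. a WindmillStreamSender).
interface EndpointSender {
  void closeAllStreams();
}

// Hypothetical, simplified consumer of worker metadata. Updates go through a single-threaded
// executor, so they are applied one at a time and in submission order.
final class ConnectionManager implements AutoCloseable {
  private final ExecutorService metadataConsumer = Executors.newSingleThreadExecutor();
  private final Map<String, EndpointSender> senders = new ConcurrentHashMap<>();
  private final Function<String, EndpointSender> senderFactory;

  ConnectionManager(Function<String, EndpointSender> senderFactory) {
    this.senderFactory = senderFactory;
  }

  // Called for every metadata response; the future completes once the update is applied.
  CompletableFuture<Void> onWorkerMetadata(Set<String> endpoints) {
    Set<String> snapshot = Set.copyOf(endpoints);
    return CompletableFuture.runAsync(() -> apply(snapshot), metadataConsumer);
  }

  private void apply(Set<String> endpoints) {
    // Close the streams of endpoints that are missing from the latest metadata.
    senders.entrySet().removeIf(entry -> {
      if (endpoints.contains(entry.getKey())) {
        return false;
      }
      entry.getValue().closeAllStreams();
      return true;
    });
    // Create senders for new endpoints; senders for existing endpoints are reused.
    for (String endpoint : endpoints) {
      senders.computeIfAbsent(endpoint, senderFactory);
    }
    // A real client would redistribute the total GetWork budget across the senders here.
  }

  Set<String> activeEndpoints() {
    return Set.copyOf(senders.keySet());
  }

  @Override
  public void close() {
    metadataConsumer.shutdownNow();
  }
}
```

Funneling every update through one thread means two metadata responses are never applied concurrently, so the endpoint-to-sender map never reflects a mix of two versions.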
Contains a single-threaded executor for triggered budget refreshes. Budget
refreshes are triggered when new worker metadata is consumed (implemented) or
when work has completed processing (either committed back to Windmill or put in
an inactive state). Uses a SynchronousQueue to implement a publish/subscribe
pattern: put blocks until another thread calls take on the queue.
Contains a single-threaded executor for periodic budget refreshes.
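The two refresh mechanisms can be sketched together with a hypothetical `BudgetRefresher`: a SynchronousQueue hand-off for triggered refreshes and a scheduled executor for periodic ones. The refresh action itself is passed in as a `Runnable`; the class name and the fixed-delay schedule are assumptions, not taken from GetWorkBudgetRefresher.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.TimeUnit;

// Hypothetical refresher combining triggered (SynchronousQueue hand-off) and periodic refreshes.
final class BudgetRefresher implements AutoCloseable {
  private final SynchronousQueue<Boolean> refreshRequests = new SynchronousQueue<>();
  private final ExecutorService triggeredRefresher = Executors.newSingleThreadExecutor();
  private final ScheduledExecutorService periodicRefresher =
      Executors.newSingleThreadScheduledExecutor();
  private final Runnable redistributeBudget;
  private final long periodMillis;

  BudgetRefresher(Runnable redistributeBudget, long periodMillis) {
    this.redistributeBudget = redistributeBudget;
    this.periodMillis = periodMillis;
  }

  void start() {
    // Subscriber: waits for a published request, then refreshes.
    triggeredRefresher.execute(
        () -> {
          while (!Thread.currentThread().isInterrupted()) {
            try {
              refreshRequests.take();
              redistributeBudget.run();
            } catch (InterruptedException e) {
              Thread.currentThread().interrupt(); // Shutting down; the loop exits.
            }
          }
        });
    // Periodic refreshes on a fixed delay, independent of triggers.
    periodicRefresher.scheduleWithFixedDelay(
        redistributeBudget, periodMillis, periodMillis, TimeUnit.MILLISECONDS);
  }

  // Publisher: put blocks until the subscriber thread takes the request.
  void requestRefresh() {
    try {
      refreshRequests.put(Boolean.TRUE);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    }
  }

  @Override
  public void close() {
    triggeredRefresher.shutdownNow();
    periodicRefresher.shutdownNow();
  }
}
```

Because `put` on a SynchronousQueue blocks until the consumer calls `take`, a caller requesting a refresh waits while a previous refresh is still running. Both executors may also run the refresh action at the same time, so in this sketch it must be thread-safe.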
Future changes still needed:
Have GrpcGetWorkStream accept a GrpcGetDataStream and GrpcCommitWorkStream
so that it can construct a ProcessWorkItemClient and pass it on to the
WorkItemProcessor (replacing the current behavior where a WorkItem is passed to
the WorkItemReceiver).
Integrate with StreamingDataflowWorker; it might be worth having 2 different
implementations of StreamingDataflowWorker, since the current
MetricTrackingWindmillServer is used for GetData (keyed and global) fetches.
Figure out a way to batch commits, since they need to go to the same origin
worker.
---
.../worker/windmill/WindmillConnection.java | 57 +++
.../client/grpc/GetWorkTimingInfosTracker.java | 10 +-
.../windmill/client/grpc/GrpcCommitWorkStream.java | 4 +-
.../client/grpc/GrpcDirectGetWorkStream.java | 320 ++++++++++++++++
.../windmill/client/grpc/GrpcGetDataStream.java | 4 +-
.../client/grpc/GrpcGetWorkerMetadataStream.java | 3 +-
.../client/grpc/GrpcWindmillStreamFactory.java | 28 +-
.../client/grpc/StreamingEngineClient.java | 401 ++++++++++++++++++++
.../grpc/StreamingEngineConnectionState.java | 64 ++++
.../windmill/client/grpc/WindmillStreamSender.java | 156 ++++++++
.../windmill/work/ProcessWorkItemClient.java | 52 +++
.../worker/windmill/work/WorkItemProcessor.java | 57 +++
.../work/budget/EvenGetWorkBudgetDistributor.java | 101 +++++
.../worker/windmill/work/budget/GetWorkBudget.java | 35 +-
.../work/budget/GetWorkBudgetDistributor.java | 33 ++
.../work/budget/GetWorkBudgetDistributors.java | 29 ++
.../work/budget/GetWorkBudgetRefresher.java | 133 +++++++
.../client/grpc/StreamingEngineClientTest.java | 417 +++++++++++++++++++++
.../client/grpc/WindmillStreamSenderTest.java | 239 ++++++++++++
.../budget/EvenGetWorkBudgetDistributorTest.java | 265 +++++++++++++
.../work/budget/GetWorkBudgetRefresherTest.java | 102 +++++
.../windmill/work/budget/GetWorkBudgetTest.java | 26 +-
22 files changed, 2482 insertions(+), 54 deletions(-)
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillConnection.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillConnection.java
new file mode 100644
index 00000000000..e49a04a7a54
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillConnection.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill;
+
+import com.google.auto.value.AutoValue;
+import java.util.Optional;
+import java.util.function.Function;
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
+import org.apache.beam.sdk.annotations.Internal;
+
+@AutoValue
+@Internal
+public abstract class WindmillConnection {
+ public static WindmillConnection from(
+ Endpoint windmillEndpoint,
+ Function<Endpoint, CloudWindmillServiceV1Alpha1Stub> endpointToStubFn) {
+ WindmillConnection.Builder windmillWorkerConnection = WindmillConnection.builder();
+
+ windmillEndpoint.workerToken().ifPresent(windmillWorkerConnection::setBackendWorkerToken);
+ windmillWorkerConnection.setStub(endpointToStubFn.apply(windmillEndpoint));
+
+ return windmillWorkerConnection.build();
+ }
+
+ public static Builder builder() {
+ return new AutoValue_WindmillConnection.Builder();
+ }
+
+ public abstract Optional<String> backendWorkerToken();
+
+ public abstract CloudWindmillServiceV1Alpha1Stub stub();
+
+ @AutoValue.Builder
+ abstract static class Builder {
+ abstract Builder setBackendWorkerToken(String backendWorkerToken);
+
+ abstract Builder setStub(CloudWindmillServiceV1Alpha1Stub stub);
+
+ abstract WindmillConnection build();
+ }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GetWorkTimingInfosTracker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GetWorkTimingInfosTracker.java
index 221b18be164..dc3486d743a 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GetWorkTimingInfosTracker.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GetWorkTimingInfosTracker.java
@@ -35,8 +35,7 @@ import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-class GetWorkTimingInfosTracker {
-
+final class GetWorkTimingInfosTracker {
private static final Logger LOG = LoggerFactory.getLogger(GetWorkTimingInfosTracker.class);
private final Map<State, SumAndMaxDurations> aggregatedGetWorkStreamLatencies;
@@ -53,7 +52,7 @@ class GetWorkTimingInfosTracker {
workItemCreationLatency = null;
}
- public void addTimingInfo(Collection<GetWorkStreamTimingInfo> infos) {
+ void addTimingInfo(Collection<GetWorkStreamTimingInfo> infos) {
// We want to record duration for each stage and also be reflective on total work item
// processing time. It can be tricky because timings of different
// StreamingGetWorkResponseChunks can be interleaved. Current strategy is to record the
@@ -170,7 +169,7 @@ class GetWorkTimingInfosTracker {
return latencyAttributions;
}
- public void reset() {
+ void reset() {
this.aggregatedGetWorkStreamLatencies.clear();
this.workItemCreationEndTime = Instant.EPOCH;
this.workItemLastChunkReceivedByWorkerTime = Instant.EPOCH;
@@ -178,11 +177,10 @@ class GetWorkTimingInfosTracker {
}
private static class SumAndMaxDurations {
-
private Duration sum;
private Duration max;
- public SumAndMaxDurations(Duration sum, Duration max) {
+ private SumAndMaxDurations(Duration sum, Duration max) {
this.sum = sum;
this.max = max;
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java
index 5d0a5085fe1..9350b89f182 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java
@@ -43,7 +43,7 @@ import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-final class GrpcCommitWorkStream
+public final class GrpcCommitWorkStream
extends AbstractWindmillStream<StreamingCommitWorkRequest, StreamingCommitResponse>
implements CommitWorkStream {
private static final Logger LOG = LoggerFactory.getLogger(GrpcCommitWorkStream.class);
@@ -82,7 +82,7 @@ final class GrpcCommitWorkStream
this.streamingRpcBatchLimit = streamingRpcBatchLimit;
}
- static GrpcCommitWorkStream create(
+ public static GrpcCommitWorkStream create(
Function<StreamObserver<StreamingCommitResponse>, StreamObserver<StreamingCommitWorkRequest>>
startCommitWorkRpcFn,
BackOff backoff,
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
new file mode 100644
index 00000000000..683f94eb71e
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
@@ -0,0 +1,320 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import com.google.auto.value.AutoValue;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationWorkItemMetadata;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkResponseChunk;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
+import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory;
+import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.ProcessWorkItemClient;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.util.BackOff;
+import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Implementation of {@link GetWorkStream} that passes along a specific {@link
+ * org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream} and {@link
+ * org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream} to the
+ * processing context {@link ProcessWorkItemClient}. During the work item processing lifecycle,
+ * these direct streams are used to facilitate these RPC calls to specific backend workers.
+ */
+@Internal
+public final class GrpcDirectGetWorkStream
+ extends AbstractWindmillStream<StreamingGetWorkRequest, StreamingGetWorkResponseChunk>
+ implements GetWorkStream {
+ private static final Logger LOG = LoggerFactory.getLogger(GrpcDirectGetWorkStream.class);
+ private static final StreamingGetWorkRequest HEALTH_CHECK_REQUEST =
+ StreamingGetWorkRequest.newBuilder()
+ .setRequestExtension(
+ Windmill.StreamingGetWorkRequestExtension.newBuilder()
+ .setMaxItems(0)
+ .setMaxBytes(0)
+ .build())
+ .build();
+
+ private final AtomicReference<GetWorkBudget> inFlightBudget;
+ private final AtomicReference<GetWorkBudget> nextBudgetAdjustment;
+ private final AtomicReference<GetWorkBudget> pendingResponseBudget;
+ private final GetWorkRequest request;
+ private final WorkItemProcessor workItemProcessorFn;
+ private final ThrottleTimer getWorkThrottleTimer;
+ private final Supplier<GetDataStream> getDataStream;
+ private final Supplier<CommitWorkStream> commitWorkStream;
+ /**
+ * Map of stream IDs to their buffers. Used to aggregate streaming gRPC response chunks as they
+ * come in. Once all chunks for a work item have been received, the work item is processed and
+ * the buffer is cleared.
+ */
+ private final ConcurrentMap<Long, WorkItemBuffer> workItemBuffers;
+
+ private GrpcDirectGetWorkStream(
+ Function<
+ StreamObserver<StreamingGetWorkResponseChunk>,
+ StreamObserver<StreamingGetWorkRequest>>
+ startGetWorkRpcFn,
+ GetWorkRequest request,
+ BackOff backoff,
+ StreamObserverFactory streamObserverFactory,
+ Set<AbstractWindmillStream<?, ?>> streamRegistry,
+ int logEveryNStreamFailures,
+ ThrottleTimer getWorkThrottleTimer,
+ Supplier<GetDataStream> getDataStream,
+ Supplier<CommitWorkStream> commitWorkStream,
+ WorkItemProcessor workItemProcessorFn) {
+ super(
+ startGetWorkRpcFn, backoff, streamObserverFactory, streamRegistry, logEveryNStreamFailures);
+ this.request = request;
+ this.getWorkThrottleTimer = getWorkThrottleTimer;
+ this.workItemProcessorFn = workItemProcessorFn;
+ this.workItemBuffers = new ConcurrentHashMap<>();
+ // Use the same GetDataStream and CommitWorkStream instances to process all the work in this
+ // stream.
+ this.getDataStream = Suppliers.memoize(getDataStream::get);
+ this.commitWorkStream = Suppliers.memoize(commitWorkStream::get);
+ this.inFlightBudget = new AtomicReference<>(GetWorkBudget.noBudget());
+ this.nextBudgetAdjustment = new AtomicReference<>(GetWorkBudget.noBudget());
+ this.pendingResponseBudget = new AtomicReference<>(GetWorkBudget.noBudget());
+ }
+
+ public static GrpcDirectGetWorkStream create(
+ Function<
+ StreamObserver<StreamingGetWorkResponseChunk>,
+ StreamObserver<StreamingGetWorkRequest>>
+ startGetWorkRpcFn,
+ GetWorkRequest request,
+ BackOff backoff,
+ StreamObserverFactory streamObserverFactory,
+ Set<AbstractWindmillStream<?, ?>> streamRegistry,
+ int logEveryNStreamFailures,
+ ThrottleTimer getWorkThrottleTimer,
+ Supplier<GetDataStream> getDataStream,
+ Supplier<CommitWorkStream> commitWorkStream,
+ WorkItemProcessor workItemProcessorFn) {
+ GrpcDirectGetWorkStream getWorkStream =
+ new GrpcDirectGetWorkStream(
+ startGetWorkRpcFn,
+ request,
+ backoff,
+ streamObserverFactory,
+ streamRegistry,
+ logEveryNStreamFailures,
+ getWorkThrottleTimer,
+ getDataStream,
+ commitWorkStream,
+ workItemProcessorFn);
+ getWorkStream.startStream();
+ return getWorkStream;
+ }
+
+ private synchronized GetWorkBudget getThenResetBudgetAdjustment() {
+ return nextBudgetAdjustment.getAndUpdate(unused -> GetWorkBudget.noBudget());
+ }
+
+ private void sendRequestExtension() {
+ // Just sent the request extension, reset the nextBudgetAdjustment. This will be set when
+ // adjustBudget is called.
+ GetWorkBudget adjustment = getThenResetBudgetAdjustment();
+ StreamingGetWorkRequest extension =
+ StreamingGetWorkRequest.newBuilder()
+ .setRequestExtension(
+ Windmill.StreamingGetWorkRequestExtension.newBuilder()
+ .setMaxItems(adjustment.items())
+ .setMaxBytes(adjustment.bytes()))
+ .build();
+
+ executor()
+ .execute(
+ () -> {
+ try {
+ send(extension);
+ } catch (IllegalStateException e) {
+ // Stream was closed.
+ }
+ });
+ }
+
+ @Override
+ protected synchronized void onNewStream() {
+ workItemBuffers.clear();
+ // Add the current in-flight budget to the next adjustment. Only positive values are allowed
+ // here, with negatives defaulting to 0, since GetWorkBudgets cannot be created with negative
+ // values.
+ GetWorkBudget budgetAdjustment = nextBudgetAdjustment.get().apply(inFlightBudget.get());
+ inFlightBudget.set(budgetAdjustment);
+ send(
+ StreamingGetWorkRequest.newBuilder()
+ .setRequest(
+ request
+ .toBuilder()
+ .setMaxBytes(budgetAdjustment.bytes())
+ .setMaxItems(budgetAdjustment.items()))
+ .build());
+
+ // We just sent the budget, reset it.
+ nextBudgetAdjustment.set(GetWorkBudget.noBudget());
+ }
+
+ @Override
+ protected boolean hasPendingRequests() {
+ return false;
+ }
+
+ @Override
+ public void appendSpecificHtml(PrintWriter writer) {
+ // The number of buffers equals the number of distinct workers that sent work on this stream.
+ writer.format(
+ "GetWorkStream: %d buffers, %s inflight budget allowed.",
+ workItemBuffers.size(), inFlightBudget.get());
+ }
+
+ @Override
+ public void sendHealthCheck() {
+ send(HEALTH_CHECK_REQUEST);
+ }
+
+ @Override
+ protected void onResponse(StreamingGetWorkResponseChunk chunk) {
+ getWorkThrottleTimer.stop();
+ WorkItemBuffer workItemBuffer =
+ workItemBuffers.computeIfAbsent(chunk.getStreamId(), unused -> new WorkItemBuffer());
+ workItemBuffer.append(chunk);
+
+ // The entire WorkItem has been received; it is ready to be processed.
+ if (chunk.getRemainingBytesForWorkItem() == 0) {
+ // Capture the buffered size first: runAndReset() clears the buffer.
+ long bufferedSize = workItemBuffer.bufferedSize();
+ workItemBuffer.runAndReset();
+ // Record the fact that there are now fewer outstanding messages and bytes on the stream.
+ inFlightBudget.updateAndGet(budget -> budget.subtract(1, bufferedSize));
+ }
+ }
+
+ @Override
+ protected void startThrottleTimer() {
+ getWorkThrottleTimer.start();
+ }
+
+ @Override
+ public synchronized void adjustBudget(long itemsDelta, long bytesDelta) {
+ nextBudgetAdjustment.set(nextBudgetAdjustment.get().apply(itemsDelta, bytesDelta));
+ sendRequestExtension();
+ }
+
+ @Override
+ public GetWorkBudget remainingBudget() {
+ // Snapshot the current budgets.
+ GetWorkBudget currentPendingResponseBudget = pendingResponseBudget.get();
+ GetWorkBudget currentNextBudgetAdjustment = nextBudgetAdjustment.get();
+ GetWorkBudget currentInflightBudget = inFlightBudget.get();
+
+ return currentPendingResponseBudget
+ .apply(currentNextBudgetAdjustment)
+ .apply(currentInflightBudget);
+ }
+
+ private synchronized void updatePendingResponseBudget(long itemsDelta, long bytesDelta) {
+ pendingResponseBudget.set(pendingResponseBudget.get().apply(itemsDelta, bytesDelta));
+ }
+
+ @AutoValue
+ abstract static class ComputationMetadata {
+ private static ComputationMetadata fromProto(ComputationWorkItemMetadata metadataProto) {
+ return new AutoValue_GrpcDirectGetWorkStream_ComputationMetadata(
+ metadataProto.getComputationId(),
+ WindmillTimeUtils.windmillToHarnessWatermark(metadataProto.getInputDataWatermark()),
+ WindmillTimeUtils.windmillToHarnessWatermark(
+ metadataProto.getDependentRealtimeInputWatermark()));
+ }
+
+ abstract String computationId();
+
+ abstract Instant inputDataWatermark();
+
+ abstract Instant synchronizedProcessingTime();
+ }
+
+ private class WorkItemBuffer {
+ private final GetWorkTimingInfosTracker workTimingInfosTracker;
+ private ByteString data;
+ private @Nullable ComputationMetadata metadata;
+
+ private WorkItemBuffer() {
+ workTimingInfosTracker = new GetWorkTimingInfosTracker(System::currentTimeMillis);
+ data = ByteString.EMPTY;
+ this.metadata = null;
+ }
+
+ private void append(StreamingGetWorkResponseChunk chunk) {
+ if (chunk.hasComputationMetadata()) {
+ this.metadata = ComputationMetadata.fromProto(chunk.getComputationMetadata());
+ }
+
+ this.data = data.concat(chunk.getSerializedWorkItem());
+ workTimingInfosTracker.addTimingInfo(chunk.getPerWorkItemTimingInfosList());
+ }
+
+ private long bufferedSize() {
+ return data.size();
+ }
+
+ private void runAndReset() {
+ try {
+ WorkItem workItem = WorkItem.parseFrom(data.newInput());
+ updatePendingResponseBudget(1, workItem.getSerializedSize());
+ Preconditions.checkNotNull(metadata);
+ workItemProcessorFn.processWork(
+ metadata.computationId(),
+ metadata.inputDataWatermark(),
+ metadata.synchronizedProcessingTime(),
+ ProcessWorkItemClient.create(
+ WorkItem.parseFrom(data.newInput()), getDataStream.get(), commitWorkStream.get()),
+ // After the work item is successfully queued or dropped by ActiveWorkState, remove it
+ // from the pendingResponseBudget.
+ queuedWorkItem -> updatePendingResponseBudget(-1, -workItem.getSerializedSize()),
+ workTimingInfosTracker.getLatencyAttributions());
+ } catch (IOException e) {
+ LOG.error("Failed to parse work item from stream: ", e);
+ }
+ workTimingInfosTracker.reset();
+ data = ByteString.EMPTY;
+ }
+ }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
index ea9cd7f0fa3..a04a961ca9c 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
@@ -53,7 +53,7 @@ import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-final class GrpcGetDataStream
+public final class GrpcGetDataStream
extends AbstractWindmillStream<StreamingGetDataRequest, StreamingGetDataResponse>
implements GetDataStream {
private static final Logger LOG = LoggerFactory.getLogger(GrpcGetDataStream.class);
@@ -86,7 +86,7 @@ final class GrpcGetDataStream
this.pending = new ConcurrentHashMap<>();
}
- static GrpcGetDataStream create(
+ public static GrpcGetDataStream create(
Function<StreamObserver<StreamingGetDataResponse>, StreamObserver<StreamingGetDataRequest>>
startGetDataRpcFn,
BackOff backoff,
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java
index a403feddb45..35524dbd2ee 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java
@@ -36,7 +36,7 @@ import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-final class GrpcGetWorkerMetadataStream
+public final class GrpcGetWorkerMetadataStream
extends AbstractWindmillStream<WorkerMetadataRequest, WorkerMetadataResponse>
implements GetWorkerMetadataStream {
private static final Logger LOG = LoggerFactory.getLogger(GrpcGetWorkerMetadataStream.class);
@@ -100,6 +100,7 @@ final class GrpcGetWorkerMetadataStream
metadataVersion,
getWorkerMetadataThrottleTimer,
serverMappingUpdater);
+ LOG.info("Started GetWorkerMetadataStream. {}", getWorkerMetadataStream);
getWorkerMetadataStream.startStream();
return getWorkerMetadataStream;
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
index e474ebf18b2..099be8db0fd 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
@@ -42,7 +42,9 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.Ge
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory;
import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor;
import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver;
+import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.sdk.util.FluentBackoff;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
@@ -54,7 +56,8 @@ import org.joda.time.Instant;
* RPC streams for health check/heartbeat requests to keep the streams alive.
*/
@ThreadSafe
-public final class GrpcWindmillStreamFactory implements StatusDataProvider {
+@Internal
+public class GrpcWindmillStreamFactory implements StatusDataProvider {
private static final Duration MIN_BACKOFF = Duration.millis(1);
private static final Duration DEFAULT_MAX_BACKOFF = Duration.standardSeconds(30);
private static final int DEFAULT_LOG_EVERY_N_STREAM_FAILURES = 1;
@@ -128,6 +131,26 @@ public final class GrpcWindmillStreamFactory implements StatusDataProvider {
processWorkItem);
}
+ public GetWorkStream createDirectGetWorkStream(
+ CloudWindmillServiceV1Alpha1Stub stub,
+ GetWorkRequest request,
+ ThrottleTimer getWorkThrottleTimer,
+ Supplier<GetDataStream> getDataStream,
+ Supplier<CommitWorkStream> commitWorkStream,
+ WorkItemProcessor workItemProcessor) {
+ return GrpcDirectGetWorkStream.create(
+ responseObserver -> withDeadline(stub).getWorkStream(responseObserver),
+ request,
+ grpcBackOff.get(),
+ newStreamObserverFactory(),
+ streamRegistry,
+ logEveryNStreamFailures,
+ getWorkThrottleTimer,
+ getDataStream,
+ commitWorkStream,
+ workItemProcessor);
+ }
+
public GetDataStream createGetDataStream(
CloudWindmillServiceV1Alpha1Stub stub, ThrottleTimer getDataThrottleTimer) {
return GrpcGetDataStream.create(
@@ -210,8 +233,9 @@ public final class GrpcWindmillStreamFactory implements StatusDataProvider {
}
}
+ @Internal
@AutoBuilder(ofClass = GrpcWindmillStreamFactory.class)
- interface Builder {
+ public interface Builder {
Builder setJobHeader(JobHeader jobHeader);
Builder setLogEveryNStreamFailures(int logEveryNStreamFailures);
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java
new file mode 100644
index 00000000000..01783f6aa4d
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java
@@ -0,0 +1,401 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map.Entry;
+import java.util.Optional;
+import java.util.Queue;
+import java.util.Random;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import javax.annotation.CheckReturnValue;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory;
+import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.EvictingQueue;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Queues;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Client for StreamingEngine. Given a {@link GetWorkBudget}, divides the budget and starts the
+ * {@link WindmillStream.GetWorkStream}(s).
+ */
+@Internal
+@CheckReturnValue
+@ThreadSafe
+public final class StreamingEngineClient {
+ private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineClient.class);
+ private static final String PUBLISH_NEW_WORKER_METADATA_THREAD = "PublishNewWorkerMetadataThread";
+ private static final String CONSUME_NEW_WORKER_METADATA_THREAD = "ConsumeNewWorkerMetadataThread";
+
+ private final AtomicBoolean started;
+ private final JobHeader jobHeader;
+ private final GrpcWindmillStreamFactory streamFactory;
+ private final WorkItemProcessor workItemProcessor;
+ private final WindmillStubFactory stubFactory;
+ private final GrpcDispatcherClient dispatcherClient;
+ private final AtomicBoolean isBudgetRefreshPaused;
+ private final GetWorkBudgetRefresher getWorkBudgetRefresher;
+ private final AtomicReference<Instant> lastBudgetRefresh;
+ private final ThrottleTimer getWorkerMetadataThrottleTimer;
+ private final ExecutorService newWorkerMetadataPublisher;
+ private final ExecutorService newWorkerMetadataConsumer;
+ private final long clientId;
+ private final Supplier<GetWorkerMetadataStream> getWorkerMetadataStream;
+ private final Queue<WindmillEndpoints> newWindmillEndpoints;
+ /** Writes are guarded by synchronization, reads are lock free. */
+ private final AtomicReference<StreamingEngineConnectionState> connections;
+
+ @SuppressWarnings("FutureReturnValueIgnored")
+ private StreamingEngineClient(
+ JobHeader jobHeader,
+ GetWorkBudget totalGetWorkBudget,
+ AtomicReference<StreamingEngineConnectionState> connections,
+ GrpcWindmillStreamFactory streamFactory,
+ WorkItemProcessor workItemProcessor,
+ WindmillStubFactory stubFactory,
+ GetWorkBudgetDistributor getWorkBudgetDistributor,
+ GrpcDispatcherClient dispatcherClient,
+ long clientId) {
+ this.jobHeader = jobHeader;
+ this.started = new AtomicBoolean();
+ this.streamFactory = streamFactory;
+ this.workItemProcessor = workItemProcessor;
+ this.connections = connections;
+ this.stubFactory = stubFactory;
+ this.dispatcherClient = dispatcherClient;
+ this.isBudgetRefreshPaused = new AtomicBoolean(false);
+ this.getWorkerMetadataThrottleTimer = new ThrottleTimer();
+ this.newWorkerMetadataPublisher =
+ singleThreadedExecutorServiceOf(PUBLISH_NEW_WORKER_METADATA_THREAD);
+ this.newWorkerMetadataConsumer =
+ singleThreadedExecutorServiceOf(CONSUME_NEW_WORKER_METADATA_THREAD);
+ this.clientId = clientId;
+ this.lastBudgetRefresh = new AtomicReference<>(Instant.EPOCH);
+ this.newWindmillEndpoints = Queues.synchronizedQueue(EvictingQueue.create(1));
+ this.getWorkBudgetRefresher =
+ new GetWorkBudgetRefresher(
+ isBudgetRefreshPaused::get,
+ () -> {
+ getWorkBudgetDistributor.distributeBudget(
+ connections.get().windmillStreams().values(), totalGetWorkBudget);
+ lastBudgetRefresh.set(Instant.now());
+ });
+ this.getWorkerMetadataStream =
+ Suppliers.memoize(
+ () ->
+ streamFactory.createGetWorkerMetadataStream(
+ dispatcherClient.getDispatcherStub(),
+ getWorkerMetadataThrottleTimer,
+ endpoints ->
+ // Run this on a separate thread from the gRPC stream thread.
+ newWorkerMetadataPublisher.submit(
+ () -> newWindmillEndpoints.add(endpoints))));
+ }
+
+ private static ExecutorService singleThreadedExecutorServiceOf(String threadName) {
+ return Executors.newSingleThreadScheduledExecutor(
+ new ThreadFactoryBuilder()
+ .setNameFormat(threadName)
+ .setUncaughtExceptionHandler(
+ (t, e) -> {
+ LOG.error(
+ "{} failed due to uncaught exception during execution.", t.getName(), e);
+ throw new StreamingEngineClientException(e);
+ })
+ .build());
+ }
+
+ /**
+ * Creates an instance of {@link StreamingEngineClient} and starts the {@link
+ * GetWorkerMetadataStream} with an RPC to the StreamingEngine backend. {@link
+ * GetWorkerMetadataStream} will populate {@link #connections} when a response is received.
+ *
+ * @implNote Does not block the calling thread.
+ */
+ public static StreamingEngineClient create(
+ JobHeader jobHeader,
+ GetWorkBudget totalGetWorkBudget,
+ GrpcWindmillStreamFactory streamingEngineStreamFactory,
+ WorkItemProcessor processWorkItem,
+ WindmillStubFactory windmillGrpcStubFactory,
+ GetWorkBudgetDistributor getWorkBudgetDistributor,
+ GrpcDispatcherClient dispatcherClient) {
+ StreamingEngineClient streamingEngineClient =
+ new StreamingEngineClient(
+ jobHeader,
+ totalGetWorkBudget,
+ new AtomicReference<>(StreamingEngineConnectionState.EMPTY),
+ streamingEngineStreamFactory,
+ processWorkItem,
+ windmillGrpcStubFactory,
+ getWorkBudgetDistributor,
+ dispatcherClient,
+ new Random().nextLong());
+ streamingEngineClient.startGetWorkerMetadataStream();
+ streamingEngineClient.startWorkerMetadataConsumer();
+ streamingEngineClient.getWorkBudgetRefresher.start();
+ return streamingEngineClient;
+ }
+
+ @VisibleForTesting
+ static StreamingEngineClient forTesting(
+ JobHeader jobHeader,
+ GetWorkBudget totalGetWorkBudget,
+ AtomicReference<StreamingEngineConnectionState> connections,
+ GrpcWindmillStreamFactory streamFactory,
+ WorkItemProcessor processWorkItem,
+ WindmillStubFactory stubFactory,
+ GetWorkBudgetDistributor getWorkBudgetDistributor,
+ GrpcDispatcherClient dispatcherClient,
+ long clientId) {
+ StreamingEngineClient streamingEngineClient =
+ new StreamingEngineClient(
+ jobHeader,
+ totalGetWorkBudget,
+ connections,
+ streamFactory,
+ processWorkItem,
+ stubFactory,
+ getWorkBudgetDistributor,
+ dispatcherClient,
+ clientId);
+ streamingEngineClient.startGetWorkerMetadataStream();
+ streamingEngineClient.startWorkerMetadataConsumer();
+ streamingEngineClient.getWorkBudgetRefresher.start();
+ return streamingEngineClient;
+ }
+
+ @SuppressWarnings("FutureReturnValueIgnored")
+ private void startWorkerMetadataConsumer() {
+ newWorkerMetadataConsumer.submit(
+ () -> {
+ while (true) {
+ Optional.ofNullable(newWindmillEndpoints.poll())
+ .ifPresent(this::consumeWindmillWorkerEndpoints);
+ }
+ });
+ }
+
+ @VisibleForTesting
+ boolean isWorkerMetadataReady() {
+ return !connections.get().equals(StreamingEngineConnectionState.EMPTY);
+ }
+
+ @VisibleForTesting
+ void finish() {
+ if (!started.compareAndSet(true, false)) {
+ return;
+ }
+ getWorkerMetadataStream.get().close();
+ getWorkBudgetRefresher.stop();
+ newWorkerMetadataPublisher.shutdownNow();
+ newWorkerMetadataConsumer.shutdownNow();
+ }
+
+ /**
+ * Used as a {@link java.util.function.Consumer} of {@link WindmillEndpoints} to update {@link
+ * #connections} when new backend worker metadata arrives.
+ */
+ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWindmillEndpoints) {
+ isBudgetRefreshPaused.set(true);
+ LOG.info("Consuming new windmill endpoints: {}", newWindmillEndpoints);
+ ImmutableMap<Endpoint, WindmillConnection> newWindmillConnections =
+ createNewWindmillConnections(newWindmillEndpoints.windmillEndpoints());
+
+ StreamingEngineConnectionState newConnectionsState =
+ StreamingEngineConnectionState.builder()
+ .setWindmillConnections(newWindmillConnections)
+ .setWindmillStreams(
+ closeStaleStreamsAndCreateNewStreams(newWindmillConnections.values()))
+ .setGlobalDataStreams(
+ createNewGlobalDataStreams(newWindmillEndpoints.globalDataEndpoints()))
+ .build();
+
+ LOG.info(
+ "Setting new connections: {}. Previous connections: {}.",
+ newConnectionsState,
+ connections.get());
+ connections.set(newConnectionsState);
+ isBudgetRefreshPaused.set(false);
+ getWorkBudgetRefresher.requestBudgetRefresh();
+ }
+
+ public final ImmutableList<Long> getAndResetThrottleTimes() {
+ StreamingEngineConnectionState currentConnections = connections.get();
+
+ ImmutableList<Long> keyedWorkStreamThrottleTimes =
+ currentConnections.windmillStreams().values().stream()
+ .map(WindmillStreamSender::getAndResetThrottleTime)
+ .collect(toImmutableList());
+
+ return ImmutableList.<Long>builder()
+ .add(getWorkerMetadataThrottleTimer.getAndResetThrottleTime())
+ .addAll(keyedWorkStreamThrottleTimes)
+ .build();
+ }
+
+ /** Starts {@link GetWorkerMetadataStream}. */
+ @SuppressWarnings({
+ "ReturnValueIgnored", // starts the stream, this value is memoized.
+ })
+ private void startGetWorkerMetadataStream() {
+ started.set(true);
+ getWorkerMetadataStream.get();
+ }
+
+ private synchronized ImmutableMap<Endpoint, WindmillConnection> createNewWindmillConnections(
+ List<Endpoint> newWindmillEndpoints) {
+ ImmutableMap<Endpoint, WindmillConnection> currentConnections =
+ connections.get().windmillConnections();
+ return newWindmillEndpoints.stream()
+ .collect(
+ toImmutableMap(
+ Function.identity(),
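+ // Reuse existing connections (and their stubs) if they exist. Only build a new connection,
+ // lazily, for endpoints that are not already known.
+ endpoint ->
+ Optional.ofNullable(currentConnections.get(endpoint))
+ .orElseGet(() -> WindmillConnection.from(endpoint, this::createWindmillStub))));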
+ }
+
+ private synchronized ImmutableMap<WindmillConnection, WindmillStreamSender>
+ closeStaleStreamsAndCreateNewStreams(Collection<WindmillConnection> newWindmillConnections) {
+ ImmutableMap<WindmillConnection, WindmillStreamSender> currentStreams =
+ connections.get().windmillStreams();
+
+ // Close the streams that are no longer valid.
+ currentStreams.entrySet().stream()
+ .filter(
+ connectionAndStream -> !newWindmillConnections.contains(connectionAndStream.getKey()))
+ .map(Entry::getValue)
+ .forEach(WindmillStreamSender::closeAllStreams);
+
+ return newWindmillConnections.stream()
+ .collect(
+ toImmutableMap(
+ Function.identity(),
+ newConnection ->
+ Optional.ofNullable(currentStreams.get(newConnection))
+ .orElseGet(() -> createAndStartWindmillStreamSenderFor(newConnection))));
+ }
+
+ private ImmutableMap<String, Supplier<GetDataStream>> createNewGlobalDataStreams(
+ ImmutableMap<String, Endpoint> newGlobalDataEndpoints) {
+ ImmutableMap<String, Supplier<GetDataStream>> currentGlobalDataStreams =
+ connections.get().globalDataStreams();
+ return newGlobalDataEndpoints.entrySet().stream()
+ .collect(
+ toImmutableMap(
+ Entry::getKey,
+ keyedEndpoint ->
+ existingOrNewGetDataStreamFor(keyedEndpoint, currentGlobalDataStreams)));
+ }
+
+ private Supplier<GetDataStream> existingOrNewGetDataStreamFor(
+ Entry<String, Endpoint> keyedEndpoint,
+ ImmutableMap<String, Supplier<GetDataStream>> currentGlobalDataStreams) {
+ return Preconditions.checkNotNull(
+ currentGlobalDataStreams.getOrDefault(
+ keyedEndpoint.getKey(),
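+ // Memoize so that every get() returns the same stream instead of opening a new one.
+ Suppliers.memoize(
+ () ->
+ streamFactory.createGetDataStream(
+ newOrExistingStubFor(keyedEndpoint.getValue()), new ThrottleTimer()))));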
+ }
+
+ private CloudWindmillServiceV1Alpha1Stub newOrExistingStubFor(Endpoint endpoint) {
+ return Optional.ofNullable(connections.get().windmillConnections().get(endpoint))
+ .map(WindmillConnection::stub)
+ .orElseGet(() -> createWindmillStub(endpoint));
+ }
+
+ private WindmillStreamSender createAndStartWindmillStreamSenderFor(
+ WindmillConnection connection) {
+ // Initially create each stream with no budget. The budget will eventually be assigned by the
+ // GetWorkBudgetDistributor.
+ WindmillStreamSender windmillStreamSender =
+ WindmillStreamSender.create(
+ connection.stub(),
+ GetWorkRequest.newBuilder()
+ .setClientId(clientId)
+ .setJobId(jobHeader.getJobId())
+ .setProjectId(jobHeader.getProjectId())
+ .setWorkerId(jobHeader.getWorkerId())
+ .build(),
+ GetWorkBudget.noBudget(),
+ streamFactory,
+ workItemProcessor);
+ windmillStreamSender.startStreams();
+ return windmillStreamSender;
+ }
+
+ private CloudWindmillServiceV1Alpha1Stub createWindmillStub(Endpoint endpoint) {
+ switch (stubFactory.getKind()) {
+ // This is only used in tests.
+ case IN_PROCESS:
+ return stubFactory.inProcess().get();
+ // Create stub for direct_endpoint or just default to Dispatcher stub.
+ case REMOTE:
+ return endpoint
+ .directEndpoint()
+ .map(stubFactory.remote())
+ .orElseGet(dispatcherClient::getDispatcherStub);
+ // Should never be reached; this switch statement is exhaustive.
+ default:
+ throw new UnsupportedOperationException(
+ "Only remote or in-process stub factories are available.");
+ }
+ }
+
+ private static class StreamingEngineClientException extends IllegalStateException {
+
+ private StreamingEngineClientException(Throwable exception) {
+ super(exception);
+ }
+ }
+}
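The heart of consumeWindmillWorkerEndpoints is a reconciliation step. Senders whose connection is no longer advertised are closed. Senders for surviving connections are reused. New senders are created only for connections seen for the first time. The following is a minimal, self-contained sketch of that pattern. It assumes plain String keys in place of WindmillConnection and a hypothetical Sender class in place of WindmillStreamSender. It illustrates the idea and is not the Beam implementation.

```java
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;

/** Hypothetical stand-in for WindmillStreamSender. */
final class Sender {
  final String connection;
  boolean closed = false;

  Sender(String connection) {
    this.connection = connection;
  }

  void closeAllStreams() {
    closed = true;
  }
}

/** Hypothetical helper mirroring closeStaleStreamsAndCreateNewStreams. */
final class StreamReconciler {
  /**
   * Closes senders whose connection is absent from {@code desired}, reuses senders whose
   * connection is still present, and creates senders only for first-seen connections.
   */
  static Map<String, Sender> reconcile(
      Map<String, Sender> current,
      Collection<String> desired,
      Function<String, Sender> createAndStart) {
    // Close stale senders first so they stop pulling work from removed backends.
    for (Map.Entry<String, Sender> entry : current.entrySet()) {
      if (!desired.contains(entry.getKey())) {
        entry.getValue().closeAllStreams();
      }
    }
    // Build the next mapping, reusing existing senders where possible.
    Map<String, Sender> next = new LinkedHashMap<>();
    for (String connection : desired) {
      Sender existing = current.get(connection);
      next.put(connection, existing != null ? existing : createAndStart.apply(connection));
    }
    return Collections.unmodifiableMap(next);
  }
}
```

Reusing surviving senders matters because, as the WindmillStreamSender comments note, starting streams is expensive and may involve network calls.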
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineConnectionState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineConnectionState.java
new file mode 100644
index 00000000000..8d784456d65
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineConnectionState.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import com.google.auto.value.AutoValue;
+import java.util.function.Supplier;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+
+/**
+ * Represents the current state of connections to Streaming Engine. Connections are updated when
+ * backend workers assigned to the key ranges being processed by this user worker change during
+ * pipeline execution. For example, changes can happen via autoscaling, load-balancing, or other
+ * backend updates.
+ */
+@AutoValue
+abstract class StreamingEngineConnectionState {
+ static final StreamingEngineConnectionState EMPTY = builder().build();
+
+ static Builder builder() {
+ return new AutoValue_StreamingEngineConnectionState.Builder()
+ .setWindmillConnections(ImmutableMap.of())
+ .setWindmillStreams(ImmutableMap.of())
+ .setGlobalDataStreams(ImmutableMap.of());
+ }
+
+ abstract ImmutableMap<Endpoint, WindmillConnection> windmillConnections();
+
+ abstract ImmutableMap<WindmillConnection, WindmillStreamSender> windmillStreams();
+
+ /** Mapping of GlobalDataIds to the direct GetDataStreams used to fetch them. */
+ abstract ImmutableMap<String, Supplier<GetDataStream>> globalDataStreams();
+
+ @AutoValue.Builder
+ abstract static class Builder {
+ public abstract Builder setWindmillConnections(
+ ImmutableMap<Endpoint, WindmillConnection> value);
+
+ public abstract Builder setWindmillStreams(
+ ImmutableMap<WindmillConnection, WindmillStreamSender> value);
+
+ public abstract Builder setGlobalDataStreams(
+ ImmutableMap<String, Supplier<GetDataStream>> value);
+
+ public abstract StreamingEngineConnectionState build();
+ }
+}
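StreamingEngineConnectionState is an immutable snapshot. StreamingEngineClient publishes it through an AtomicReference whose writes are serialized by synchronization, while readers call get() without locking. The sketch below shows that copy-on-write pattern. It uses a hypothetical Snapshot class holding an unmodifiable Map in place of the AutoValue type, and a hypothetical SnapshotHolder in place of the client's connections field.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.UnaryOperator;

/** Hypothetical immutable snapshot standing in for StreamingEngineConnectionState. */
final class Snapshot {
  static final Snapshot EMPTY = new Snapshot(Collections.emptyMap());

  private final Map<String, String> connections;

  Snapshot(Map<String, String> connections) {
    // Defensive copy so the snapshot cannot change after it is published.
    this.connections = Collections.unmodifiableMap(new HashMap<>(connections));
  }

  Map<String, String> connections() {
    return connections;
  }
}

/** Hypothetical holder: synchronized writes, lock-free reads. */
final class SnapshotHolder {
  private final AtomicReference<Snapshot> current = new AtomicReference<>(Snapshot.EMPTY);

  /** Lock-free read: callers see either the old or the new snapshot, never a partial update. */
  Snapshot read() {
    return current.get();
  }

  /** Writers are serialized so that a read-modify-write cannot lose a concurrent update. */
  synchronized Snapshot update(UnaryOperator<Map<String, String>> change) {
    Map<String, String> copy = new HashMap<>(current.get().connections());
    Snapshot next = new Snapshot(change.apply(copy));
    current.set(next);
    return next;
  }
}
```

Readers that captured an older snapshot keep a consistent view. This is why the client can log the previous and the new connection state side by side before swapping them.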
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java
new file mode 100644
index 00000000000..bef710329ff
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.StreamingEngineThrottleTimers;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
+
+/**
+ * Owns and maintains a set of streams used to communicate with a specific Windmill worker.
+ * Underlying streams are "cached" in a threadsafe manner so that once {@link Supplier#get} is
+ * called, a stream that is already started is returned.
+ *
+ * <p>Holds references to {@link Supplier}s of {@link
+ * org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream} because initializing the
+ * streams automatically starts them, and we want to start them lazily, once the {@link
+ * GetWorkBudget} is set.
+ *
+ * <p>Once started, the underlying streams are "alive" until they are manually closed via {@link
+ * #closeAllStreams()}.
+ *
+ * <p>A closed sender means that its backend endpoint is no longer in the worker set. Closed
+ * instances are not reused.
+ *
+ * @implNote Does not manage streams for fetching {@link
+ * org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData} for side inputs.
+ */
+@Internal
+@ThreadSafe
+public class WindmillStreamSender {
+ private final AtomicBoolean started;
+ private final AtomicReference<GetWorkBudget> getWorkBudget;
+ private final Supplier<GetWorkStream> getWorkStream;
+ private final Supplier<GetDataStream> getDataStream;
+ private final Supplier<CommitWorkStream> commitWorkStream;
+ private final StreamingEngineThrottleTimers streamingEngineThrottleTimers;
+
+ private WindmillStreamSender(
+ CloudWindmillServiceV1Alpha1Stub stub,
+ GetWorkRequest getWorkRequest,
+ AtomicReference<GetWorkBudget> getWorkBudget,
+ GrpcWindmillStreamFactory streamingEngineStreamFactory,
+ WorkItemProcessor workItemProcessor) {
+ this.started = new AtomicBoolean(false);
+ this.getWorkBudget = getWorkBudget;
+ this.streamingEngineThrottleTimers = StreamingEngineThrottleTimers.create();
+
+ // All streams are memoized/cached since they are expensive to create and some implementations
+ // perform side effects on construction (e.g. sending initial requests to the stream server to
+ // initiate the streaming RPC connection). Stream instances connect/reconnect internally, so we
+ // can reuse the same instance through the entire lifecycle of WindmillStreamSender.
+ this.getDataStream =
+ Suppliers.memoize(
+ () ->
+ streamingEngineStreamFactory.createGetDataStream(
+ stub, streamingEngineThrottleTimers.getDataThrottleTimer()));
+ this.commitWorkStream =
+ Suppliers.memoize(
+ () ->
+ streamingEngineStreamFactory.createCommitWorkStream(
+ stub, streamingEngineThrottleTimers.commitWorkThrottleTimer()));
+ this.getWorkStream =
+ Suppliers.memoize(
+ () ->
+ streamingEngineStreamFactory.createDirectGetWorkStream(
+ stub,
+ withRequestBudget(getWorkRequest, getWorkBudget.get()),
+ streamingEngineThrottleTimers.getWorkThrottleTimer(),
+ getDataStream,
+ commitWorkStream,
+ workItemProcessor));
+ }
+
+ public static WindmillStreamSender create(
+ CloudWindmillServiceV1Alpha1Stub stub,
+ GetWorkRequest getWorkRequest,
+ GetWorkBudget getWorkBudget,
+ GrpcWindmillStreamFactory streamingEngineStreamFactory,
+ WorkItemProcessor workItemProcessor) {
+ return new WindmillStreamSender(
+ stub,
+ getWorkRequest,
+ new AtomicReference<>(getWorkBudget),
+ streamingEngineStreamFactory,
+ workItemProcessor);
+ }
+
+ private static GetWorkRequest withRequestBudget(GetWorkRequest request, GetWorkBudget budget) {
+ return request.toBuilder().setMaxItems(budget.items()).setMaxBytes(budget.bytes()).build();
+ }
+
+ @SuppressWarnings("ReturnValueIgnored")
+ void startStreams() {
+ getWorkStream.get();
+ getDataStream.get();
+ commitWorkStream.get();
+ // *stream.get() is all memoized in a threadsafe manner.
+ started.set(true);
+ }
+
+ void closeAllStreams() {
+ // Supplier<Stream>.get() starts the stream, which is an expensive operation as it initiates the
+ // streaming RPCs by possibly making calls over the network. Do not close the streams unless
+ // they have already been started.
+ if (started.get()) {
+ getWorkStream.get().close();
+ getDataStream.get().close();
+ commitWorkStream.get().close();
+ }
+ }
+
+ public void adjustBudget(long itemsDelta, long bytesDelta) {
+ getWorkBudget.updateAndGet(budget -> budget.apply(itemsDelta, bytesDelta));
+ if (started.get()) {
+ getWorkStream.get().adjustBudget(itemsDelta, bytesDelta);
+ }
+ }
+
+ public void adjustBudget(GetWorkBudget adjustment) {
+ adjustBudget(adjustment.items(), adjustment.bytes());
+ }
+
+ public GetWorkBudget remainingGetWorkBudget() {
+ return started.get() ? getWorkStream.get().remainingBudget() : getWorkBudget.get();
+ }
+
+ public long getAndResetThrottleTime() {
+ return streamingEngineThrottleTimers.getAndResetThrottleTime();
+ }
+}
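WindmillStreamSender keeps its budget in an AtomicReference and forwards deltas through adjustBudget. A distributor such as EvenGetWorkBudgetDistributor must split a total budget across senders, and its full body is outside this excerpt. The sketch below is a simplified, hypothetical take on that split. Budget and BudgetedSender stand in for GetWorkBudget and WindmillStreamSender. Clamping negative results at zero in apply() is an assumption. The ceiling rounding is suggested by the distributor's LongMath.divide and RoundingMode imports but is also an assumption. Deltas are applied with AtomicReference.updateAndGet, so concurrent adjustments are not lost.

```java
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical immutable budget, standing in for GetWorkBudget. */
final class Budget {
  final long items;
  final long bytes;

  Budget(long items, long bytes) {
    this.items = items;
    this.bytes = bytes;
  }

  /** Returns a new budget with the deltas added; clamping at zero is an assumption. */
  Budget apply(long itemsDelta, long bytesDelta) {
    return new Budget(Math.max(0, items + itemsDelta), Math.max(0, bytes + bytesDelta));
  }
}

/** Hypothetical stand-in for WindmillStreamSender's budget bookkeeping. */
final class BudgetedSender {
  private final AtomicReference<Budget> budget = new AtomicReference<>(new Budget(0, 0));

  void adjustBudget(long itemsDelta, long bytesDelta) {
    // updateAndGet retries on contention, so concurrent adjustments are never overwritten.
    budget.updateAndGet(b -> b.apply(itemsDelta, bytesDelta));
  }

  Budget remaining() {
    return budget.get();
  }
}

/** Hypothetical even splitter; not the EvenGetWorkBudgetDistributor implementation. */
final class EvenSplit {
  /** Ceiling division for a non-negative dividend and a positive divisor. */
  static long ceilDiv(long dividend, long divisor) {
    return -Math.floorDiv(-dividend, divisor);
  }

  /** Tops up each sender so that it holds roughly 1/n of the total budget. */
  static void distribute(List<BudgetedSender> senders, Budget total) {
    if (senders.isEmpty()) {
      return;
    }
    long n = senders.size();
    Budget share = new Budget(ceilDiv(total.items, n), ceilDiv(total.bytes, n));
    for (BudgetedSender sender : senders) {
      Budget remaining = sender.remaining();
      // Only grow a sender's budget; budget already handed out is left alone.
      sender.adjustBudget(
          Math.max(0, share.items - remaining.items), Math.max(0, share.bytes - remaining.bytes));
    }
  }
}
```

With ceiling rounding, the per-sender shares can exceed the total by up to n - 1 units (10 items over 3 senders gives 4 each). The real distributor also appears to refill only senders that fall below 50% of their target, judging from isBelowFiftyPercentOfTarget. This sketch omits that check.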
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/ProcessWorkItemClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/ProcessWorkItemClient.java
new file mode 100644
index 00000000000..1adfe02f45f
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/ProcessWorkItemClient.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.work;
+
+import com.google.auto.value.AutoValue;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import org.apache.beam.sdk.annotations.Internal;
+
+/**
+ * A client context to process {@link WorkItem} and route all subsequent Windmill WorkItem API calls
+ * to the same backend worker. Wraps the {@link WorkItem}.
+ */
+@AutoValue
+@Internal
+public abstract class ProcessWorkItemClient {
+ public static ProcessWorkItemClient create(
+ WorkItem workItem, GetDataStream getDataStream, CommitWorkStream commitWorkStream) {
+ return new AutoValue_ProcessWorkItemClient(workItem, getDataStream, commitWorkStream);
+ }
+
+ /** {@link WorkItem} being processed. */
+ public abstract WorkItem workItem();
+
+ /**
+ * {@link GetDataStream} that connects to the backend Windmill worker handling the {@link
+ * WorkItem}.
+ */
+ public abstract GetDataStream getDataStream();
+
+ /**
+ * {@link CommitWorkStream} that connects to the backend Windmill worker handling the {@link
+ * WorkItem}.
+ */
+ public abstract CommitWorkStream commitWorkStream();
+}
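ProcessWorkItemClient pairs a WorkItem with the GetDataStream and CommitWorkStream of the backend worker that returned it, so follow-up RPCs go back to that same worker. Below is a toy sketch of the routing idea. Backend and WorkContext are hypothetical types that stand in for the Windmill streams, and String items replace WorkItem protos.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical backend worker that records the commits it receives. */
final class Backend {
  final String name;
  final List<String> commits = new ArrayList<>();

  Backend(String name) {
    this.name = name;
  }

  /** Simulates GetWork: every item handed out is bound to this backend. */
  WorkContext getWork(String workItem) {
    return new WorkContext(workItem, this);
  }

  void commit(String workItem) {
    commits.add(workItem);
  }
}

/** Hypothetical analogue of ProcessWorkItemClient: the item plus its originating backend. */
final class WorkContext {
  private final String workItem;
  private final Backend origin;

  WorkContext(String workItem, Backend origin) {
    this.workItem = workItem;
    this.origin = origin;
  }

  String workItem() {
    return workItem;
  }

  /** Commits go to the backend that handed out the item, never to an arbitrary one. */
  void commit() {
    origin.commit(workItem);
  }
}
```

Binding the streams to the item at GetWork time keeps the processor free of any routing logic. The processor only calls methods on the context it was given.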
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/WorkItemProcessor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/WorkItemProcessor.java
new file mode 100644
index 00000000000..4ebc77775fc
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/WorkItemProcessor.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.work;
+
+import java.util.Collection;
+import java.util.function.Consumer;
+import javax.annotation.CheckReturnValue;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
+import org.apache.beam.sdk.annotations.Internal;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Instant;
+
+@FunctionalInterface
+@CheckReturnValue
+@Internal
+public interface WorkItemProcessor {
+ /**
+ * Receives and processes {@link WorkItem}(s) wrapped in their {@link ProcessWorkItemClient}
+ * processing context.
+ *
+ * @param computation the Computation that the Work belongs to.
+ * @param inputDataWatermark Watermark of when the input data was received by the computation.
+ * @param synchronizedProcessingTime Aggregate system watermark that also depends on each
+ * computation's received dependent system watermark value to propagate the system watermark
+ * downstream.
+ * @param wrappedWorkItem A WorkItem and its processing context, used to route subsequent
+ * WorkItem API (GetData, CommitWork) RPC calls to the same backend worker that returned the
+ * WorkItem from GetWork.
+ * @param ackWorkItemQueued Called after an attempt to queue the work item for processing. Used to
+ * free up pending budget.
+ * @param getWorkStreamLatencies Latencies per processing stage for the WorkItem, for reporting
+ * back to the Streaming Engine backend.
+ */
+ void processWork(
+ String computation,
+ @Nullable Instant inputDataWatermark,
+ @Nullable Instant synchronizedProcessingTime,
+ ProcessWorkItemClient wrappedWorkItem,
+ Consumer<WorkItem> ackWorkItemQueued,
+ Collection<LatencyAttribution> getWorkStreamLatencies);
+}
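The ackWorkItemQueued callback lets the stream release budget it reserved for an item once the processor has tried to queue that item. The sketch below is minimal. SimpleProcessor is a hypothetical, simplified analogue of WorkItemProcessor that keeps only the computation, the item and the ack callback. PendingBudget is a hypothetical counter for items that are received but not yet queued.

```java
import java.util.Queue;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;

/** Simplified, hypothetical analogue of WorkItemProcessor. */
@FunctionalInterface
interface SimpleProcessor {
  void processWork(String computation, String workItem, Consumer<String> ackWorkItemQueued);
}

/** Hypothetical tracker for items that were received but not yet handed to the work queue. */
final class PendingBudget {
  private final AtomicLong pendingItems = new AtomicLong();

  void reserve() {
    pendingItems.incrementAndGet();
  }

  void release() {
    pendingItems.decrementAndGet();
  }

  long pending() {
    return pendingItems.get();
  }
}

/** Hypothetical factory for a processor that enqueues work and always acks. */
final class Processors {
  static SimpleProcessor queueingProcessor(Queue<String> queue) {
    return (computation, workItem, ack) -> {
      try {
        queue.add(computation + "/" + workItem);
      } finally {
        // Ack after the attempt, whether or not queueing succeeded, so the budget is freed.
        ack.accept(workItem);
      }
    };
  }
}
```

Acking in a finally block matches the contract described above ("Called after an attempt to queue"). A failed enqueue therefore cannot leak pending budget.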
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java
new file mode 100644
index 00000000000..3a17222d3e6
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.work.budget;
+
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.DoubleMath.roundToLong;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath.divide;
+
+import java.math.RoundingMode;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.WindmillStreamSender;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Evenly distributes the provided budget across the available {@link WindmillStreamSender}(s). */
+@Internal
+final class EvenGetWorkBudgetDistributor implements GetWorkBudgetDistributor {
+ private static final Logger LOG = LoggerFactory.getLogger(EvenGetWorkBudgetDistributor.class);
+ private final Supplier<GetWorkBudget> activeWorkBudgetSupplier;
+
+ EvenGetWorkBudgetDistributor(Supplier<GetWorkBudget> activeWorkBudgetSupplier) {
+ this.activeWorkBudgetSupplier = activeWorkBudgetSupplier;
+ }
+
+ private static boolean isBelowFiftyPercentOfTarget(
+ GetWorkBudget remaining, GetWorkBudget target) {
+ return remaining.items() < roundToLong(target.items() * 0.5, RoundingMode.CEILING)
+ || remaining.bytes() < roundToLong(target.bytes() * 0.5, RoundingMode.CEILING);
+ }
+
+ @Override
+ public void distributeBudget(
+ ImmutableCollection<WindmillStreamSender> streams, GetWorkBudget getWorkBudget) {
+ if (streams.isEmpty()) {
+ LOG.debug("Cannot distribute budget to no streams.");
+ return;
+ }
+
+ if (getWorkBudget.equals(GetWorkBudget.noBudget())) {
+ LOG.debug("Cannot distribute 0 budget.");
+ return;
+ }
+
+ Map<WindmillStreamSender, GetWorkBudget> desiredBudgets =
+ computeDesiredBudgets(streams, getWorkBudget);
+
+ for (Entry<WindmillStreamSender, GetWorkBudget> streamAndDesiredBudget :
+ desiredBudgets.entrySet()) {
+ WindmillStreamSender stream = streamAndDesiredBudget.getKey();
+ GetWorkBudget desired = streamAndDesiredBudget.getValue();
+ GetWorkBudget remaining = stream.remainingGetWorkBudget();
+ if (isBelowFiftyPercentOfTarget(remaining, desired)) {
+ GetWorkBudget adjustment = desired.subtract(remaining);
+ stream.adjustBudget(adjustment);
+ }
+ }
+ }
+
+ private ImmutableMap<WindmillStreamSender, GetWorkBudget> computeDesiredBudgets(
+ ImmutableCollection<WindmillStreamSender> streams, GetWorkBudget totalGetWorkBudget) {
+ GetWorkBudget activeWorkBudget = activeWorkBudgetSupplier.get();
+ LOG.info("Current active work budget: {}", activeWorkBudget);
+ // TODO: Fix possibly non-deterministic handing out of budgets.
+ // Rounding up here will drift upwards over the lifetime of the streams.
+ GetWorkBudget budgetPerStream =
+ GetWorkBudget.builder()
+ .setItems(
+ divide(
+ totalGetWorkBudget.items() - activeWorkBudget.items(),
+ streams.size(),
+ RoundingMode.CEILING))
+ .setBytes(
+ divide(
+ totalGetWorkBudget.bytes() - activeWorkBudget.bytes(),
+ streams.size(),
+ RoundingMode.CEILING))
+ .build();
+ return streams.stream().collect(toImmutableMap(Function.identity(), unused -> budgetPerStream));
+ }
+}
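The arithmetic in computeDesiredBudgets and isBelowFiftyPercentOfTarget can be illustrated with a simplified model. This is a hypothetical sketch, not Beam code: EvenSplitSketch, perStream, isBelowHalf and topUps are illustrative names, plain longs stand in for GetWorkBudget, only one dimension (items) is modeled (the real class tops a stream up when either items or bytes fall below half), and a negative remainder is assumed to clamp to 0 per the GetWorkBudget Javadoc ("Does not drop below 0").

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical single-dimension model of EvenGetWorkBudgetDistributor (not Beam code). */
final class EvenSplitSketch {
  private EvenSplitSketch() {}

  /** Per-stream share of (total - active), rounded up; clamped at 0 like GetWorkBudget. */
  static long perStream(long total, long active, int numStreams) {
    if (numStreams <= 0) {
      throw new IllegalArgumentException("numStreams must be positive");
    }
    long available = Math.max(0L, total - active);
    // Ceiling division for a non-negative numerator.
    return (available + numStreams - 1) / numStreams;
  }

  /** Mirrors isBelowFiftyPercentOfTarget: remaining < ceil(target / 2). */
  static boolean isBelowHalf(long remaining, long target) {
    return remaining < (target + 1) / 2;
  }

  /** Top-up each stream would receive, given the budget it still has remaining. */
  static List<Long> topUps(long total, long active, long[] remaining) {
    long desired = perStream(total, active, remaining.length);
    List<Long> adjustments = new ArrayList<>();
    for (long r : remaining) {
      // Only streams below half of their share are refilled back up to the full share.
      adjustments.add(isBelowHalf(r, desired) ? desired - r : 0L);
    }
    return adjustments;
  }

  public static void main(String[] args) {
    // 10 items over 3 streams rounds up to 4 per stream, i.e. 12 handed out in total.
    System.out.println(perStream(10, 0, 3));
    System.out.println(topUps(10, 1, new long[] {0, 2, 3}));
  }
}
```

The first call in main shows the upward drift the TODO warns about: rounding each share up means the streams together can be offered more than the total budget.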
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudget.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudget.java
index 0038e3e9cc6..bc82b622ce6 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudget.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudget.java
@@ -20,13 +20,14 @@ package org.apache.beam.runners.dataflow.worker.windmill.work.budget;
import com.google.auto.value.AutoValue;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.apache.beam.sdk.annotations.Internal;
/**
* Budget of items and bytes for fetching {@link
* org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem}(s) via {@link
* WindmillStream.GetWorkStream}. Used to control how "much" work is returned from Windmill.
*/
+@Internal
@AutoValue
public abstract class GetWorkBudget {
public static GetWorkBudget.Builder builder() {
@@ -46,29 +47,26 @@ public abstract class GetWorkBudget {
}
/**
- * Adds the given bytes and items or the current budget, returning a new {@link GetWorkBudget}.
- * Does not drop below 0.
+ * Applies the given bytes and items delta to the current budget, returning a new {@link
+ * GetWorkBudget}. Does not drop below 0.
*/
- public GetWorkBudget add(long items, long bytes) {
- Preconditions.checkArgument(items >= 0 && bytes >= 0);
- return GetWorkBudget.builder().setBytes(bytes() + bytes).setItems(items() + items).build();
+ public GetWorkBudget apply(long itemsDelta, long bytesDelta) {
+ return GetWorkBudget.builder()
+ .setBytes(bytes() + bytesDelta)
+ .setItems(items() + itemsDelta)
+ .build();
}
- public GetWorkBudget add(GetWorkBudget other) {
- return add(other.items(), other.bytes());
+ public GetWorkBudget apply(GetWorkBudget other) {
+ return apply(other.items(), other.bytes());
}
- /**
- * Subtracts the given bytes and items or the current budget, returning a new {@link
- * GetWorkBudget}. Does not drop below 0.
- */
- public GetWorkBudget subtract(long items, long bytes) {
- Preconditions.checkArgument(items >= 0 && bytes >= 0);
- return GetWorkBudget.builder().setBytes(bytes() - bytes).setItems(items() - items).build();
+ public GetWorkBudget subtract(GetWorkBudget other) {
+ return apply(-other.items(), -other.bytes());
}
- public GetWorkBudget subtract(GetWorkBudget other) {
- return subtract(other.items(), other.bytes());
+ public GetWorkBudget subtract(long items, long bytes) {
+ return apply(-items, -bytes);
}
/** Budget of bytes for GetWork. Does not drop below 0. */
@@ -77,6 +75,9 @@ public abstract class GetWorkBudget {
/** Budget of items for GetWork. Does not drop below 0. */
public abstract long items();
+ public abstract GetWorkBudget.Builder toBuilder();
+
+ @Internal
@AutoValue.Builder
public abstract static class Builder {
public abstract Builder setBytes(long bytes);
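Replacing the checked add/subtract pair with a single signed apply(...) only works if the budget floors at zero. The hunk does not show where that clamp lives, so the hypothetical stand-in below (BudgetSketch is an illustrative name, not Beam code) assumes it happens on construction and shows subtraction expressed as apply() with negated components.

```java
import java.util.Objects;

/** Hypothetical stand-in for GetWorkBudget: an immutable item/byte budget floored at 0. */
final class BudgetSketch {
  final long items;
  final long bytes;

  BudgetSketch(long items, long bytes) {
    // Assumption: clamping happens on construction, matching "Does not drop below 0".
    this.items = Math.max(0L, items);
    this.bytes = Math.max(0L, bytes);
  }

  /** Applies a signed delta; negative deltas shrink the budget, floored at 0. */
  BudgetSketch apply(long itemsDelta, long bytesDelta) {
    return new BudgetSketch(items + itemsDelta, bytes + bytesDelta);
  }

  BudgetSketch apply(BudgetSketch other) {
    return apply(other.items, other.bytes);
  }

  /** Subtraction is apply() with negated components. */
  BudgetSketch subtract(BudgetSketch other) {
    return apply(-other.items, -other.bytes);
  }

  @Override
  public boolean equals(Object o) {
    if (!(o instanceof BudgetSketch)) {
      return false;
    }
    BudgetSketch that = (BudgetSketch) o;
    return items == that.items && bytes == that.bytes;
  }

  @Override
  public int hashCode() {
    return Objects.hash(items, bytes);
  }

  @Override
  public String toString() {
    return "BudgetSketch{items=" + items + ", bytes=" + bytes + "}";
  }
}
```

For example, new BudgetSketch(10, 100).subtract(new BudgetSketch(4, 150)) yields 6 items and 0 bytes rather than a negative byte count, which is what lets EvenGetWorkBudgetDistributor call desired.subtract(remaining) without pre-checking each dimension.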
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributor.java
new file mode 100644
index 00000000000..3ec9718e041
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributor.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.work.budget;
+
+import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.WindmillStreamSender;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
+
+/**
+ * Distributes the total {@link GetWorkBudget} amongst the {@link
+ * org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream}(s) to
+ * Windmill.
+ */
+@Internal
+public interface GetWorkBudgetDistributor {
+ void distributeBudget(
+ ImmutableCollection<WindmillStreamSender> streams, GetWorkBudget getWorkBudget);
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java
new file mode 100644
index 00000000000..43c0d46139d
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.work.budget;
+
+import java.util.function.Supplier;
+import org.apache.beam.sdk.annotations.Internal;
+
+@Internal
+public final class GetWorkBudgetDistributors {
+ public static GetWorkBudgetDistributor distributeEvenly(
+ Supplier<GetWorkBudget> activeWorkBudgetSupplier) {
+ return new EvenGetWorkBudgetDistributor(activeWorkBudgetSupplier);
+ }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresher.java
new file mode 100644
index 00000000000..e39aa8dbc8a
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresher.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.work.budget;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.function.Supplier;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.fn.stream.AdvancingPhaser;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Handles refreshing the budget either via triggered or scheduled execution, using a {@link
+ * java.util.concurrent.Phaser} to emulate a publish/subscribe pattern.
+ */
+@Internal
+@ThreadSafe
+public final class GetWorkBudgetRefresher {
+ @VisibleForTesting public static final int SCHEDULED_BUDGET_REFRESH_MILLIS = 100;
+ private static final int INITIAL_BUDGET_REFRESH_PHASE = 0;
+ private static final String BUDGET_REFRESH_THREAD = "GetWorkBudgetRefreshThread";
+ private static final Logger LOG = LoggerFactory.getLogger(GetWorkBudgetRefresher.class);
+
+ private final AdvancingPhaser budgetRefreshTrigger;
+ private final ExecutorService budgetRefreshExecutor;
+ private final Supplier<Boolean> isBudgetRefreshPaused;
+ private final Runnable redistributeBudget;
+
+ public GetWorkBudgetRefresher(
+ Supplier<Boolean> isBudgetRefreshPaused, Runnable redistributeBudget) {
+ this.budgetRefreshTrigger = new AdvancingPhaser(1);
+ this.budgetRefreshExecutor =
+ Executors.newSingleThreadScheduledExecutor(
+ new ThreadFactoryBuilder()
+ .setNameFormat(BUDGET_REFRESH_THREAD)
+ .setUncaughtExceptionHandler(
+ (t, e) ->
+ LOG.error(
+ "{} failed due to uncaught exception during execution.",
+ t.getName(),
+ e))
+ .build());
+ this.isBudgetRefreshPaused = isBudgetRefreshPaused;
+ this.redistributeBudget = redistributeBudget;
+ }
+
+ @SuppressWarnings("FutureReturnValueIgnored")
+ public void start() {
+ budgetRefreshExecutor.submit(this::subscribeToRefreshBudget);
+ }
+
+ /** Publishes an event to trigger a budget refresh. */
+ public void requestBudgetRefresh() {
+ budgetRefreshTrigger.arrive();
+ }
+
+ public void stop() {
+ budgetRefreshTrigger.arriveAndDeregister();
+ // Put the budgetRefreshTrigger in a terminated state. #waitForBudgetRefreshEventWithTimeout
+ // will subsequently return a negative phase, and #subscribeToRefreshBudget will return,
+ // completing the task.
+ budgetRefreshTrigger.forceTermination();
+ budgetRefreshExecutor.shutdownNow();
+ }
+
+ private void subscribeToRefreshBudget() {
+ int currentBudgetRefreshPhase = INITIAL_BUDGET_REFRESH_PHASE;
+ // Runs forever until #stop is called.
+ while (true) {
+ currentBudgetRefreshPhase = waitForBudgetRefreshEventWithTimeout(currentBudgetRefreshPhase);
+ // Phaser.awaitAdvanceInterruptibly(...) returns a negative value if the phaser is
+ // terminated; otherwise it returns when either a budget refresh has been manually triggered
+ // or SCHEDULED_BUDGET_REFRESH_MILLIS have elapsed.
+ if (currentBudgetRefreshPhase < 0) {
+ return;
+ }
+ // Budget refreshes are paused during endpoint updates.
+ if (!isBudgetRefreshPaused.get()) {
+ redistributeBudget.run();
+ }
+ }
+ }
+
+ /**
+ * Waits for a budget refresh trigger event, with a timeout. Returns the phase to wait on next:
+ * the advanced phase of {@link #budgetRefreshTrigger} if a refresh was triggered, the unchanged
+ * currentBudgetRefreshPhase if the wait timed out, or a negative value if {@link
+ * #budgetRefreshTrigger} has been terminated.
+ *
+ * <p>A budget refresh event is triggered when {@link #budgetRefreshTrigger} advances past the
+ * given currentBudgetRefreshPhase.
+ */
+ private int waitForBudgetRefreshEventWithTimeout(int currentBudgetRefreshPhase) {
+ try {
+ // Wait for budgetRefreshTrigger to advance FROM the current phase.
+ return budgetRefreshTrigger.awaitAdvanceInterruptibly(
+ currentBudgetRefreshPhase, SCHEDULED_BUDGET_REFRESH_MILLIS, TimeUnit.MILLISECONDS);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new BudgetRefreshException("Error occurred waiting for budget refresh.", e);
+ } catch (TimeoutException ignored) {
+ // Intentionally do nothing since we trigger the budget refresh on the timeout.
+ }
+
+ return currentBudgetRefreshPhase;
+ }
+
+ private static class BudgetRefreshException extends RuntimeException {
+ private BudgetRefreshException(String msg, Throwable sourceException) {
+ super(msg, sourceException);
+ }
+ }
+}
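The trigger-or-timeout loop above can be reproduced with the JDK's java.util.concurrent.Phaser alone. This is a hypothetical sketch, not Beam code: RefreshLoopSketch, requestRefresh and stop are illustrative names; it substitutes a plain Phaser with one permanently registered party (so it never self-terminates) for Beam's AdvancingPhaser; and on interrupt it simply exits instead of throwing as BudgetRefreshException does.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Phaser;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

/** Hypothetical refresher: runs action on request or every periodMillis, until stop(). */
final class RefreshLoopSketch {
  // One registered party that never deregisters: each arrive() advances the phase, and the
  // phaser only terminates through forceTermination().
  private final Phaser trigger = new Phaser(1);
  private final ExecutorService executor =
      Executors.newSingleThreadExecutor(
          runnable -> {
            Thread thread = new Thread(runnable, "refresh-loop-sketch");
            thread.setDaemon(true);
            return thread;
          });
  private final long periodMillis;
  private final Runnable action;

  RefreshLoopSketch(long periodMillis, Runnable action) {
    this.periodMillis = periodMillis;
    this.action = action;
  }

  void start() {
    executor.execute(this::loop);
  }

  /** Publishes a refresh event by advancing the phase. */
  void requestRefresh() {
    trigger.arrive();
  }

  void stop() {
    trigger.forceTermination(); // pending and future waits now see a negative phase
    executor.shutdownNow();
  }

  private void loop() {
    int phase = 0; // Phasers start at phase 0; a request made before the loop starts is not lost
    while (true) {
      try {
        phase = trigger.awaitAdvanceInterruptibly(phase, periodMillis, TimeUnit.MILLISECONDS);
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        return;
      } catch (TimeoutException e) {
        // Scheduled refresh: the phase is unchanged; fall through and run the action.
      }
      if (phase < 0) {
        return; // terminated by stop()
      }
      action.run();
    }
  }
}
```

One consequence of this design, visible in the sketch: several requestRefresh() calls made while the action is running coalesce into a single extra run, because the loop only observes that the phase moved, not by how much.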
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java
new file mode 100644
index 00000000000..8a2c643a5b7
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java
@@ -0,0 +1,417 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.atLeast;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress;
+import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory;
+import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.Server;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.inprocess.InProcessSocketAddress;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.util.MutableHandlerRegistry;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class StreamingEngineClientTest {
+ private static final WindmillServiceAddress DEFAULT_WINDMILL_SERVICE_ADDRESS =
+ WindmillServiceAddress.create(HostAndPort.fromParts(WindmillChannelFactory.LOCALHOST, 443));
+ private static final ImmutableMap<String, WorkerMetadataResponse.Endpoint> DEFAULT =
+ ImmutableMap.of(
+ "global_data",
+ WorkerMetadataResponse.Endpoint.newBuilder()
+ .setDirectEndpoint(DEFAULT_WINDMILL_SERVICE_ADDRESS.gcpServiceAddress().toString())
+ .build());
+ private static final long CLIENT_ID = 1L;
+ private static final String JOB_ID = "jobId";
+ private static final String PROJECT_ID = "projectId";
+ private static final String WORKER_ID = "workerId";
+ private static final JobHeader JOB_HEADER =
+ JobHeader.newBuilder()
+ .setJobId(JOB_ID)
+ .setProjectId(PROJECT_ID)
+ .setWorkerId(WORKER_ID)
+ .build();
+ @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+ private final Set<ManagedChannel> channels = new HashSet<>();
+ private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
+
+ private final GrpcWindmillStreamFactory streamFactory =
+ spy(GrpcWindmillStreamFactory.of(JOB_HEADER).build());
+ private final WindmillStubFactory stubFactory =
+ WindmillStubFactory.inProcessStubFactory(
+ "StreamingEngineClientTest",
+ name -> {
+ ManagedChannel channel =
+ grpcCleanup.register(WindmillChannelFactory.inProcessChannel(name));
+ channels.add(channel);
+ return channel;
+ });
+ private final GrpcDispatcherClient dispatcherClient =
+ GrpcDispatcherClient.forTesting(stubFactory, new ArrayList<>(), new HashSet<>());
+ private final GetWorkBudgetDistributor getWorkBudgetDistributor =
+ spy(new TestGetWorkBudgetDistributor());
+ private final AtomicReference<StreamingEngineConnectionState> connections =
+ new AtomicReference<>(StreamingEngineConnectionState.EMPTY);
+ private Server fakeStreamingEngineServer;
+ private CountDownLatch getWorkerMetadataReady;
+ private GetWorkerMetadataTestStub fakeGetWorkerMetadataStub;
+
+ private StreamingEngineClient streamingEngineClient;
+
+ private static WorkItemProcessor noOpProcessWorkItemFn() {
+ return (computation,
+ inputDataWatermark,
+ synchronizedProcessingTime,
+ workItem,
+ ackQueuedWorkItem,
+ getWorkStreamLatencies) -> {};
+ }
+
+ private static GetWorkRequest getWorkRequest(long items, long bytes) {
+ return GetWorkRequest.newBuilder()
+ .setJobId(JOB_ID)
+ .setProjectId(PROJECT_ID)
+ .setWorkerId(WORKER_ID)
+ .setClientId(CLIENT_ID)
+ .setMaxItems(items)
+ .setMaxBytes(bytes)
+ .build();
+ }
+
+ private static WorkerMetadataResponse.Endpoint metadataResponseEndpoint(String workerToken) {
+ return WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken).build();
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ channels.forEach(ManagedChannel::shutdownNow);
+ channels.clear();
+ fakeStreamingEngineServer =
+ grpcCleanup.register(
+ InProcessServerBuilder.forName("StreamingEngineClientTest")
+ .fallbackHandlerRegistry(serviceRegistry)
+ .executor(Executors.newFixedThreadPool(1))
+ .build());
+
+ fakeStreamingEngineServer.start();
+ dispatcherClient.consumeWindmillDispatcherEndpoints(
+ ImmutableSet.of(
+ HostAndPort.fromString(
+ new InProcessSocketAddress("StreamingEngineClientTest").toString())));
+ getWorkerMetadataReady = new CountDownLatch(1);
+ fakeGetWorkerMetadataStub = new GetWorkerMetadataTestStub(getWorkerMetadataReady);
+ serviceRegistry.addService(fakeGetWorkerMetadataStub);
+ }
+
+ @After
+ public void cleanUp() {
+ fakeGetWorkerMetadataStub.close();
+ fakeStreamingEngineServer.shutdownNow();
+ channels.forEach(ManagedChannel::shutdownNow);
+ Preconditions.checkNotNull(streamingEngineClient).finish();
+ }
+
+ private StreamingEngineClient newStreamingEngineClient(
+ GetWorkBudget getWorkBudget, WorkItemProcessor workItemProcessor) {
+ return StreamingEngineClient.forTesting(
+ JOB_HEADER,
+ getWorkBudget,
+ connections,
+ streamFactory,
+ workItemProcessor,
+ stubFactory,
+ getWorkBudgetDistributor,
+ dispatcherClient,
+ CLIENT_ID);
+ }
+
+ @Test
+ public void testStreamsStartCorrectly() throws InterruptedException {
+ long items = 10L;
+ long bytes = 10L;
+
+ streamingEngineClient =
+ newStreamingEngineClient(
+ GetWorkBudget.builder().setItems(items).setBytes(bytes).build(),
+ noOpProcessWorkItemFn());
+
+ String workerToken = "workerToken1";
+ String workerToken2 = "workerToken2";
+
+ WorkerMetadataResponse firstWorkerMetadata =
+ WorkerMetadataResponse.newBuilder()
+ .setMetadataVersion(1)
+ .addWorkEndpoints(metadataResponseEndpoint(workerToken))
+ .addWorkEndpoints(metadataResponseEndpoint(workerToken2))
+ .putAllGlobalDataEndpoints(DEFAULT)
+ .build();
+
+ getWorkerMetadataReady.await();
+ fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata);
+ StreamingEngineConnectionState currentConnections = waitForWorkerMetadataToBeConsumed(1);
+
+ assertEquals(2, currentConnections.windmillConnections().size());
+ assertEquals(2, currentConnections.windmillStreams().size());
+ Set<String> workerTokens =
+ connections.get().windmillConnections().values().stream()
+ .map(WindmillConnection::backendWorkerToken)
+ .filter(Optional::isPresent)
+ .map(Optional::get)
+ .collect(Collectors.toSet());
+
+ assertTrue(workerTokens.contains(workerToken));
+ assertTrue(workerTokens.contains(workerToken2));
+
+ verify(getWorkBudgetDistributor, atLeast(1))
+ .distributeBudget(
+ any(), eq(GetWorkBudget.builder().setItems(items).setBytes(bytes).build()));
+
+ verify(streamFactory, times(2))
+ .createDirectGetWorkStream(
+ any(), eq(getWorkRequest(0, 0)), any(), any(), any(), eq(noOpProcessWorkItemFn()));
+
+ verify(streamFactory, times(2)).createGetDataStream(any(), any());
+ verify(streamFactory, times(2)).createCommitWorkStream(any(), any());
+ }
+
+ @Test
+ public void testScheduledBudgetRefresh() throws InterruptedException {
+ streamingEngineClient =
+ newStreamingEngineClient(
+ GetWorkBudget.builder().setItems(1L).setBytes(1L).build(), noOpProcessWorkItemFn());
+
+ getWorkerMetadataReady.await();
+ fakeGetWorkerMetadataStub.injectWorkerMetadata(
+ WorkerMetadataResponse.newBuilder()
+ .setMetadataVersion(1)
+ .addWorkEndpoints(metadataResponseEndpoint("workerToken"))
+ .putAllGlobalDataEndpoints(DEFAULT)
+ .build());
+ waitForWorkerMetadataToBeConsumed(1);
+ Thread.sleep(GetWorkBudgetRefresher.SCHEDULED_BUDGET_REFRESH_MILLIS);
+ verify(getWorkBudgetDistributor, atLeast(2)).distributeBudget(any(), any());
+ }
+
+ @Test
+ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers()
+ throws InterruptedException {
+ streamingEngineClient =
+ newStreamingEngineClient(
+ GetWorkBudget.builder().setItems(1).setBytes(1).build(), noOpProcessWorkItemFn());
+
+ String workerToken = "workerToken1";
+ String workerToken2 = "workerToken2";
+ String workerToken3 = "workerToken3";
+
+ WorkerMetadataResponse firstWorkerMetadata =
+ WorkerMetadataResponse.newBuilder()
+ .setMetadataVersion(1)
+ .addWorkEndpoints(
+ WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken).build())
+ .addWorkEndpoints(
+ WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken2).build())
+ .putAllGlobalDataEndpoints(DEFAULT)
+ .build();
+ WorkerMetadataResponse secondWorkerMetadata =
+ WorkerMetadataResponse.newBuilder()
+ .setMetadataVersion(2)
+ .addWorkEndpoints(
+ WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken3).build())
+ .putAllGlobalDataEndpoints(DEFAULT)
+ .build();
+
+ getWorkerMetadataReady.await();
+ fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata);
+ fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata);
+
+ StreamingEngineConnectionState currentConnections = waitForWorkerMetadataToBeConsumed(2);
+
+ assertEquals(1, currentConnections.windmillConnections().size());
+ assertEquals(1, currentConnections.windmillStreams().size());
+ Set<String> workerTokens =
+ connections.get().windmillConnections().values().stream()
+ .map(WindmillConnection::backendWorkerToken)
+ .filter(Optional::isPresent)
+ .map(Optional::get)
+ .collect(Collectors.toSet());
+
+ assertFalse(workerTokens.contains(workerToken));
+ assertFalse(workerTokens.contains(workerToken2));
+ }
+
+ @Test
+ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedException {
+ streamingEngineClient =
+ newStreamingEngineClient(
+ GetWorkBudget.builder().setItems(1).setBytes(1).build(), noOpProcessWorkItemFn());
+
+ String workerToken = "workerToken1";
+ String workerToken2 = "workerToken2";
+ String workerToken3 = "workerToken3";
+
+ WorkerMetadataResponse firstWorkerMetadata =
+ WorkerMetadataResponse.newBuilder()
+ .setMetadataVersion(1)
+ .addWorkEndpoints(
+ WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken).build())
+ .putAllGlobalDataEndpoints(DEFAULT)
+ .build();
+ WorkerMetadataResponse secondWorkerMetadata =
+ WorkerMetadataResponse.newBuilder()
+ .setMetadataVersion(2)
+ .addWorkEndpoints(
+ WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken2).build())
+ .putAllGlobalDataEndpoints(DEFAULT)
+ .build();
+ WorkerMetadataResponse thirdWorkerMetadata =
+ WorkerMetadataResponse.newBuilder()
+ .setMetadataVersion(3)
+ .addWorkEndpoints(
+ WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken3).build())
+ .putAllGlobalDataEndpoints(DEFAULT)
+ .build();
+
+ getWorkerMetadataReady.await();
+ fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata);
+ Thread.sleep(50);
+ fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata);
+ Thread.sleep(50);
+ fakeGetWorkerMetadataStub.injectWorkerMetadata(thirdWorkerMetadata);
+ Thread.sleep(50);
+ verify(getWorkBudgetDistributor, atLeast(3)).distributeBudget(any(), any());
+ }
+
+ private StreamingEngineConnectionState waitForWorkerMetadataToBeConsumed(
+ int expectedMetadataConsumed) throws InterruptedException {
+ int currentMetadataConsumed = 0;
+ StreamingEngineConnectionState currentConsumedMetadata = StreamingEngineConnectionState.EMPTY;
+ while (true) {
+ if (!connections.get().equals(currentConsumedMetadata)) {
+ ++currentMetadataConsumed;
+ if (currentMetadataConsumed == expectedMetadataConsumed) {
+ break;
+ }
+ currentConsumedMetadata = connections.get();
+ }
+ }
+ // Wait for metadata to be consumed and budgets to be redistributed.
+ Thread.sleep(GetWorkBudgetRefresher.SCHEDULED_BUDGET_REFRESH_MILLIS);
+ return connections.get();
+ }
+
+ private static class GetWorkerMetadataTestStub
+ extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase {
+ private static final WorkerMetadataResponse CLOSE_ALL_STREAMS =
+ WorkerMetadataResponse.newBuilder().setMetadataVersion(100).build();
+ private final CountDownLatch ready;
+ private @Nullable StreamObserver<WorkerMetadataResponse> responseObserver;
+
+ private GetWorkerMetadataTestStub(CountDownLatch ready) {
+ this.ready = ready;
+ }
+
+ @Override
+ public StreamObserver<WorkerMetadataRequest> getWorkerMetadataStream(
+ StreamObserver<WorkerMetadataResponse> responseObserver) {
+ if (this.responseObserver == null) {
+ ready.countDown();
+ this.responseObserver = responseObserver;
+ }
+
+ return new StreamObserver<WorkerMetadataRequest>() {
+ @Override
+ public void onNext(WorkerMetadataRequest workerMetadataRequest) {}
+
+ @Override
+ public void onError(Throwable throwable) {
+ if (responseObserver != null) {
+ responseObserver.onError(throwable);
+ }
+ }
+
+ @Override
+ public void onCompleted() {}
+ };
+ }
+
+ private void injectWorkerMetadata(WorkerMetadataResponse response) {
+ if (responseObserver != null) {
+ responseObserver.onNext(response);
+ }
+ }
+
+ private void close() {
+ if (responseObserver != null) {
+ // Send a response with no work or global data endpoints to close out all the streams and
+ // channels currently open in StreamingEngineClient.
+ responseObserver.onNext(CLOSE_ALL_STREAMS);
+ }
+ }
+ }
+
+ private static class TestGetWorkBudgetDistributor implements GetWorkBudgetDistributor {
+ @Override
+ public void distributeBudget(
+ ImmutableCollection<WindmillStreamSender> streams, GetWorkBudget getWorkBudget) {
+ streams.forEach(stream -> stream.adjustBudget(getWorkBudget.items(), getWorkBudget.bytes()));
+ }
+ }
+}
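The `waitForWorkerMetadataToBeConsumed` helper in the test above spins in a tight loop with no upper bound, so a regression that stops worker metadata from being consumed would hang the test instead of failing it. A bounded alternative could poll with a deadline, as in the minimal sketch below. `BoundedChangeWaiter` and `awaitChanges` are hypothetical names, not part of Beam. Like the original loop, the sketch only counts changes it actually observes, so several updates that land between two polls count as one.

```java
import java.time.Duration;
import java.util.Objects;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;

/** Hypothetical test helper; not part of Beam. */
final class BoundedChangeWaiter {
  private BoundedChangeWaiter() {}

  /**
   * Polls {@code source} until its value has changed {@code expectedChanges} times, starting from
   * {@code initial}, and returns the last observed value. Throws if {@code timeout} elapses first.
   */
  static <T> T awaitChanges(Supplier<T> source, T initial, int expectedChanges, Duration timeout)
      throws InterruptedException, TimeoutException {
    long deadlineNanos = System.nanoTime() + timeout.toNanos();
    T observed = initial;
    int changes = 0;
    while (changes < expectedChanges) {
      T current = source.get();
      if (!Objects.equals(current, observed)) {
        // Several updates that land between two polls are counted as a single change.
        changes++;
        observed = current;
        continue;
      }
      if (System.nanoTime() - deadlineNanos > 0) {
        throw new TimeoutException(
            "Observed " + changes + " of " + expectedChanges + " expected changes");
      }
      Thread.sleep(1); // Back off briefly instead of spinning on the CPU.
    }
    return observed;
  }

  public static void main(String[] args) throws Exception {
    int[] counter = {0};
    // This source changes on every read, so three polls yield three changes.
    Integer last = awaitChanges(() -> ++counter[0], 0, 3, Duration.ofSeconds(1));
    System.out.println(last);
  }
}
```

In the test, a call might look like `awaitChanges(connections::get, StreamingEngineConnectionState.EMPTY, 3, Duration.ofSeconds(5))`; the 5-second timeout is an arbitrary illustrative value.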
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSenderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSenderTest.java
new file mode 100644
index 00000000000..c8d2974f923
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSenderTest.java
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
+
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.inprocess.InProcessChannelBuilder;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.testing.GrpcCleanupRule;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class WindmillStreamSenderTest {
+ private static final GetWorkRequest GET_WORK_REQUEST =
+ GetWorkRequest.newBuilder().setClientId(1L).setJobId("job").setProjectId("project").build();
+ @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+
+ private final GrpcWindmillStreamFactory streamFactory =
+ spy(
+ GrpcWindmillStreamFactory.of(
+ JobHeader.newBuilder()
+ .setJobId("job")
+ .setProjectId("project")
+ .setWorkerId("worker")
+ .build())
+ .build());
+ private final WorkItemProcessor workItemProcessor =
+ (computation,
+ inputDataWatermark,
+ synchronizedProcessingTime,
+ workItem,
+ ackQueuedWorkItem,
+ getWorkStreamLatencies) -> {};
+ private ManagedChannel inProcessChannel;
+ private CloudWindmillServiceV1Alpha1Stub stub;
+
+ @Before
+ public void setUp() {
+ inProcessChannel =
+ grpcCleanup.register(
+ InProcessChannelBuilder.forName("WindmillStreamSenderTest").directExecutor().build());
+ grpcCleanup.register(inProcessChannel);
+ stub = CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel);
+ }
+
+ @After
+ public void cleanUp() {
+ inProcessChannel.shutdownNow();
+ }
+
+ @Test
+ public void testStartStream_startsAllStreams() {
+ long itemBudget = 1L;
+ long byteBudget = 1L;
+
+ WindmillStreamSender windmillStreamSender =
+ newWindmillStreamSender(
+ GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build());
+
+ windmillStreamSender.startStreams();
+
+ verify(streamFactory)
+ .createDirectGetWorkStream(
+ eq(stub),
+ eq(
+ GET_WORK_REQUEST
+ .toBuilder()
+ .setMaxItems(itemBudget)
+ .setMaxBytes(byteBudget)
+ .build()),
+ any(ThrottleTimer.class),
+ any(),
+ any(),
+ eq(workItemProcessor));
+
+ verify(streamFactory).createGetDataStream(eq(stub), any(ThrottleTimer.class));
+ verify(streamFactory).createCommitWorkStream(eq(stub), any(ThrottleTimer.class));
+ }
+
+ @Test
+ public void testStartStream_onlyStartsStreamsOnce() {
+ long itemBudget = 1L;
+ long byteBudget = 1L;
+
+ WindmillStreamSender windmillStreamSender =
+ newWindmillStreamSender(
+ GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build());
+
+ windmillStreamSender.startStreams();
+ windmillStreamSender.startStreams();
+ windmillStreamSender.startStreams();
+
+ verify(streamFactory, times(1))
+ .createDirectGetWorkStream(
+ eq(stub),
+ eq(
+ GET_WORK_REQUEST
+ .toBuilder()
+ .setMaxItems(itemBudget)
+ .setMaxBytes(byteBudget)
+ .build()),
+ any(ThrottleTimer.class),
+ any(),
+ any(),
+ eq(workItemProcessor));
+
+ verify(streamFactory, times(1)).createGetDataStream(eq(stub), any(ThrottleTimer.class));
+ verify(streamFactory, times(1)).createCommitWorkStream(eq(stub), any(ThrottleTimer.class));
+ }
+
+ @Test
+ public void testStartStream_onlyStartsStreamsOnceConcurrent() throws InterruptedException {
+ long itemBudget = 1L;
+ long byteBudget = 1L;
+
+ WindmillStreamSender windmillStreamSender =
+ newWindmillStreamSender(
+ GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build());
+
+ Thread startStreamThread = new Thread(windmillStreamSender::startStreams);
+ startStreamThread.start();
+
+ windmillStreamSender.startStreams();
+
+ startStreamThread.join();
+
+ verify(streamFactory, times(1))
+ .createDirectGetWorkStream(
+ eq(stub),
+ eq(
+ GET_WORK_REQUEST
+ .toBuilder()
+ .setMaxItems(itemBudget)
+ .setMaxBytes(byteBudget)
+ .build()),
+ any(ThrottleTimer.class),
+ any(),
+ any(),
+ eq(workItemProcessor));
+
+ verify(streamFactory, times(1)).createGetDataStream(eq(stub), any(ThrottleTimer.class));
+ verify(streamFactory, times(1)).createCommitWorkStream(eq(stub), any(ThrottleTimer.class));
+ }
+
+ @Test
+ public void testCloseAllStreams_doesNotCloseUnstartedStreams() {
+ WindmillStreamSender windmillStreamSender =
+ newWindmillStreamSender(GetWorkBudget.builder().setBytes(1L).setItems(1L).build());
+
+ windmillStreamSender.closeAllStreams();
+
+ verifyNoInteractions(streamFactory);
+ }
+
+ @Test
+ public void testCloseAllStreams_closesAllStreams() {
+ long itemBudget = 1L;
+ long byteBudget = 1L;
+ GetWorkRequest getWorkRequestWithBudget =
+ GET_WORK_REQUEST.toBuilder().setMaxItems(itemBudget).setMaxBytes(byteBudget).build();
+ GrpcWindmillStreamFactory mockStreamFactory = mock(GrpcWindmillStreamFactory.class);
+ GetWorkStream mockGetWorkStream = mock(GetWorkStream.class);
+ GetDataStream mockGetDataStream = mock(GetDataStream.class);
+ CommitWorkStream mockCommitWorkStream = mock(CommitWorkStream.class);
+
+ when(mockStreamFactory.createDirectGetWorkStream(
+ eq(stub),
+ eq(getWorkRequestWithBudget),
+ any(ThrottleTimer.class),
+ any(),
+ any(),
+ eq(workItemProcessor)))
+ .thenReturn(mockGetWorkStream);
+
+ when(mockStreamFactory.createGetDataStream(eq(stub), any(ThrottleTimer.class)))
+ .thenReturn(mockGetDataStream);
+ when(mockStreamFactory.createCommitWorkStream(eq(stub), any(ThrottleTimer.class)))
+ .thenReturn(mockCommitWorkStream);
+
+ WindmillStreamSender windmillStreamSender =
+ newWindmillStreamSender(
+ GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build(),
+ mockStreamFactory);
+
+ windmillStreamSender.startStreams();
+ windmillStreamSender.closeAllStreams();
+
+ verify(mockGetWorkStream).close();
+ verify(mockGetDataStream).close();
+ verify(mockCommitWorkStream).close();
+ }
+
+ private WindmillStreamSender newWindmillStreamSender(GetWorkBudget budget) {
+ return newWindmillStreamSender(budget, streamFactory);
+ }
+
+ private WindmillStreamSender newWindmillStreamSender(
+ GetWorkBudget budget, GrpcWindmillStreamFactory streamFactory) {
+ return WindmillStreamSender.create(
+ stub, GET_WORK_REQUEST, budget, streamFactory, workItemProcessor);
+ }
+}
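The `testStartStream_onlyStartsStreamsOnce*` and `testCloseAllStreams_*` tests pin down a start-once contract: repeated or concurrent `startStreams()` calls create each stream exactly once, and `closeAllStreams()` never touches streams that were never started. One way to satisfy that contract is sketched below, assuming a plain `synchronized` guard. `StartOnceStreamGroup` and its nested `Stream` interface are hypothetical stand-ins for `WindmillStreamSender` and the Windmill stream types, not the actual implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

/** Hypothetical sketch of a start-once stream group; not Beam's WindmillStreamSender. */
final class StartOnceStreamGroup {
  /** Stand-in for a Windmill stream that can be closed. */
  interface Stream {
    void close();
  }

  private final List<Supplier<Stream>> streamFactories;
  private final List<Stream> startedStreams = new ArrayList<>();
  private boolean started = false; // Guarded by "this".

  StartOnceStreamGroup(List<Supplier<Stream>> streamFactories) {
    this.streamFactories = new ArrayList<>(streamFactories);
  }

  /** Creates every stream on the first call; repeated or concurrent calls are no-ops. */
  synchronized void startStreams() {
    if (started) {
      return;
    }
    for (Supplier<Stream> factory : streamFactories) {
      startedStreams.add(factory.get());
    }
    started = true;
  }

  /** Closes only streams that were actually started; the factories are never invoked here. */
  synchronized void closeAllStreams() {
    for (Stream stream : startedStreams) {
      stream.close();
    }
    startedStreams.clear();
  }

  public static void main(String[] args) throws InterruptedException {
    List<Supplier<Stream>> factories = new ArrayList<>();
    for (int i = 0; i < 3; i++) {
      int id = i;
      factories.add(() -> () -> System.out.println("closed stream " + id));
    }
    StartOnceStreamGroup group = new StartOnceStreamGroup(factories);
    group.closeAllStreams(); // Nothing was started, so nothing is closed.
    Thread other = new Thread(group::startStreams);
    other.start();
    group.startStreams(); // Races with the other thread; streams are still created once.
    other.join();
    group.closeAllStreams(); // Closes the three started streams.
  }
}
```

The `synchronized` guard keeps the sketch short; an implementation could instead use an `AtomicBoolean` together with lazily initialized stream suppliers to avoid holding a lock while streams are created.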
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java
new file mode 100644
index 00000000000..14da55fe238
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.work.budget;
+
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
+import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillStreamFactory;
+import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.WindmillStreamSender;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.inprocess.InProcessChannelBuilder;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class EvenGetWorkBudgetDistributorTest {
+ @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+
+ private ManagedChannel inProcessChannel;
+ private CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub;
+
+ private static GetWorkBudgetDistributor createBudgetDistributor(GetWorkBudget activeWorkBudget) {
+ return GetWorkBudgetDistributors.distributeEvenly(() -> activeWorkBudget);
+ }
+
+ private static GetWorkBudgetDistributor createBudgetDistributor(long activeWorkItemsAndBytes) {
+ return createBudgetDistributor(
+ GetWorkBudget.builder()
+ .setItems(activeWorkItemsAndBytes)
+ .setBytes(activeWorkItemsAndBytes)
+ .build());
+ }
+
+ @Before
+ public void setUp() {
+ inProcessChannel =
+ grpcCleanup.register(
+ InProcessChannelBuilder.forName("WindmillStreamSenderTest").directExecutor().build());
+ grpcCleanup.register(inProcessChannel);
+ stub = CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel);
+ }
+
+ @After
+ public void cleanUp() {
+ inProcessChannel.shutdownNow();
+ }
+
+ @Test
+ public void testDistributeBudget_doesNothingWhenPassedInStreamsEmpty() {
+ createBudgetDistributor(1L)
+ .distributeBudget(
+ ImmutableList.of(), GetWorkBudget.builder().setItems(10L).setBytes(10L).build());
+ }
+
+ @Test
+ public void testDistributeBudget_doesNothingWithNoBudget() {
+ WindmillStreamSender windmillStreamSender =
+ spy(createWindmillStreamSender(GetWorkBudget.noBudget()));
+ createBudgetDistributor(1L)
+ .distributeBudget(ImmutableList.of(windmillStreamSender), GetWorkBudget.noBudget());
+ verifyNoInteractions(windmillStreamSender);
+ }
+
+ @Test
+ public void testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighNoActiveWork() {
+ WindmillStreamSender windmillStreamSender =
+ spy(
+ createWindmillStreamSender(
+ GetWorkBudget.builder().setItems(10L).setBytes(10L).build()));
+ createBudgetDistributor(0L)
+ .distributeBudget(
+ ImmutableList.of(windmillStreamSender),
+ GetWorkBudget.builder().setItems(10L).setBytes(10L).build());
+
+ verify(windmillStreamSender, never()).adjustBudget(anyLong(), anyLong());
+ }
+
+ @Test
+ public void
+ testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighWithActiveWork() {
+ WindmillStreamSender windmillStreamSender =
+ spy(createWindmillStreamSender(GetWorkBudget.builder().setItems(5L).setBytes(5L).build()));
+ createBudgetDistributor(10L)
+ .distributeBudget(
+ ImmutableList.of(windmillStreamSender),
+ GetWorkBudget.builder().setItems(20L).setBytes(20L).build());
+
+ verify(windmillStreamSender, never()).adjustBudget(anyLong(), anyLong());
+ }
+
+ @Test
+ public void
+ testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithNoActiveWork() {
+ GetWorkBudget streamRemainingBudget =
+ GetWorkBudget.builder().setItems(1L).setBytes(10L).build();
+ GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
+ WindmillStreamSender windmillStreamSender =
+ spy(createWindmillStreamSender(streamRemainingBudget));
+ createBudgetDistributor(0L)
+ .distributeBudget(ImmutableList.of(windmillStreamSender), totalGetWorkBudget);
+
+ verify(windmillStreamSender, times(1))
+ .adjustBudget(
+ eq(totalGetWorkBudget.items() - streamRemainingBudget.items()),
+ eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes()));
+ }
+
+ @Test
+ public void
+ testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithActiveWork() {
+ GetWorkBudget streamRemainingBudget =
+ GetWorkBudget.builder().setItems(1L).setBytes(10L).build();
+ GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
+ long activeWorkItemsAndBytes = 2L;
+ WindmillStreamSender windmillStreamSender =
+ spy(createWindmillStreamSender(streamRemainingBudget));
+ createBudgetDistributor(activeWorkItemsAndBytes)
+ .distributeBudget(ImmutableList.of(windmillStreamSender), totalGetWorkBudget);
+
+ verify(windmillStreamSender, times(1))
+ .adjustBudget(
+ eq(
+ totalGetWorkBudget.items()
+ - streamRemainingBudget.items()
+ - activeWorkItemsAndBytes),
+ eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes()));
+ }
+
+ @Test
+ public void testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowNoActiveWork() {
+ GetWorkBudget streamRemainingBudget =
+ GetWorkBudget.builder().setItems(10L).setBytes(1L).build();
+ GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
+ WindmillStreamSender windmillStreamSender =
+ spy(createWindmillStreamSender(streamRemainingBudget));
+ createBudgetDistributor(0L)
+ .distributeBudget(ImmutableList.of(windmillStreamSender), totalGetWorkBudget);
+
+ verify(windmillStreamSender, times(1))
+ .adjustBudget(
+ eq(totalGetWorkBudget.items() - streamRemainingBudget.items()),
+ eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes()));
+ }
+
+ @Test
+ public void
+ testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowWithActiveWork() {
+ GetWorkBudget streamRemainingBudget =
+ GetWorkBudget.builder().setItems(10L).setBytes(1L).build();
+ GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
+ long activeWorkItemsAndBytes = 2L;
+
+ WindmillStreamSender windmillStreamSender =
+ spy(createWindmillStreamSender(streamRemainingBudget));
+ createBudgetDistributor(activeWorkItemsAndBytes)
+ .distributeBudget(ImmutableList.of(windmillStreamSender), totalGetWorkBudget);
+
+ verify(windmillStreamSender, times(1))
+ .adjustBudget(
+ eq(totalGetWorkBudget.items() - streamRemainingBudget.items()),
+ eq(
+ totalGetWorkBudget.bytes()
+ - streamRemainingBudget.bytes()
+ - activeWorkItemsAndBytes));
+ }
+
+ @Test
+ public void testDistributeBudget_distributesBudgetEvenlyIfPossible() {
+ long totalItemsAndBytes = 10L;
+ List<WindmillStreamSender> streams = new ArrayList<>();
+ for (int i = 0; i < totalItemsAndBytes; i++) {
+ streams.add(spy(createWindmillStreamSender(GetWorkBudget.noBudget())));
+ }
+ createBudgetDistributor(0L)
+ .distributeBudget(
+ ImmutableList.copyOf(streams),
+ GetWorkBudget.builder()
+ .setItems(totalItemsAndBytes)
+ .setBytes(totalItemsAndBytes)
+ .build());
+
+ long itemsAndBytesPerStream = totalItemsAndBytes / streams.size();
+ streams.forEach(
+ stream ->
+ verify(stream, times(1))
+ .adjustBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream)));
+ }
+
+ @Test
+ public void testDistributeBudget_distributesFairlyWhenNotEven() {
+ long totalItemsAndBytes = 10L;
+ List<WindmillStreamSender> streams = new ArrayList<>();
+ for (int i = 0; i < 3; i++) {
+ streams.add(spy(createWindmillStreamSender(GetWorkBudget.noBudget())));
+ }
+ createBudgetDistributor(0L)
+ .distributeBudget(
+ ImmutableList.copyOf(streams),
+ GetWorkBudget.builder()
+ .setItems(totalItemsAndBytes)
+ .setBytes(totalItemsAndBytes)
+ .build());
+
+ long itemsAndBytesPerStream = (long) Math.ceil(totalItemsAndBytes / (streams.size() * 1.0));
+ streams.forEach(
+ stream ->
+ verify(stream, times(1))
+ .adjustBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream)));
+ }
+
+ private WindmillStreamSender createWindmillStreamSender(GetWorkBudget getWorkBudget) {
+ return WindmillStreamSender.create(
+ stub,
+ Windmill.GetWorkRequest.newBuilder()
+ .setClientId(1L)
+ .setJobId("job")
+ .setProjectId("project")
+ .build(),
+ getWorkBudget,
+ GrpcWindmillStreamFactory.of(
+ JobHeader.newBuilder()
+ .setJobId("job")
+ .setProjectId("project")
+ .setWorkerId("worker")
+ .build())
+ .build(),
+ (computation,
+ inputDataWatermark,
+ synchronizedProcessingTime,
+ workItem,
+ ackQueuedWorkItem,
+ getWorkStreamLatencies) -> {});
+ }
+}
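The expected values in these tests are consistent with a simple rule: give each stream a target of ceil((total - active work) / number of streams), top up a stream only when its remaining items or bytes fall below some fraction of that target, and top it up by target minus remaining, floored at zero. The minimal sketch below encodes that rule with a threshold of one half. That threshold is an assumption chosen to match the test cases (any fraction above 1/8 and at most 1/2 fits them), not a value read from Beam's `EvenGetWorkBudgetDistributor`; `EvenBudgetSketch` and its nested `Budget` class are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical sketch of even budget distribution; not Beam's EvenGetWorkBudgetDistributor. */
final class EvenBudgetSketch {
  /** Item/byte budget floored at zero; stands in for GetWorkBudget. */
  static final class Budget {
    final long items;
    final long bytes;

    Budget(long items, long bytes) {
      this.items = Math.max(0, items);
      this.bytes = Math.max(0, bytes);
    }

    @Override
    public String toString() {
      return "(" + items + ", " + bytes + ")";
    }
  }

  /** Ceiling division for a non-negative numerator and a positive divisor. */
  static long divideAndRoundUp(long numerator, int divisor) {
    return (numerator + divisor - 1) / divisor;
  }

  /**
   * Returns one entry per stream: the top-up passed to that stream, or null if it is left alone.
   * Assumed rule: target = ceil((total - active) / streams); a stream is topped up only when its
   * remaining items or bytes fall below half of the target, by (target - remaining) floored at 0.
   */
  static List<Budget> adjustments(List<Budget> remaining, Budget total, Budget active) {
    List<Budget> result = new ArrayList<>();
    if (remaining.isEmpty()) {
      return result; // No streams, nothing to distribute.
    }
    Budget target =
        new Budget(
            divideAndRoundUp(Math.max(0, total.items - active.items), remaining.size()),
            divideAndRoundUp(Math.max(0, total.bytes - active.bytes), remaining.size()));
    for (Budget stream : remaining) {
      boolean tooLow = stream.items < target.items * 0.5 || stream.bytes < target.bytes * 0.5;
      result.add(
          tooLow ? new Budget(target.items - stream.items, target.bytes - stream.bytes) : null);
    }
    return result;
  }

  public static void main(String[] args) {
    List<Budget> threeEmptyStreams =
        Arrays.asList(new Budget(0, 0), new Budget(0, 0), new Budget(0, 0));
    // Each stream's target is ceil(10 / 3) = 4.
    System.out.println(adjustments(threeEmptyStreams, new Budget(10, 10), new Budget(0, 0)));
  }
}
```

Because the target is rounded up, the per-stream targets can add up to more than the total: three empty streams sharing 10 items each receive 4, for 12 in total, which is what `testDistributeBudget_distributesFairlyWhenNotEven` asserts.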
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresherTest.java
new file mode 100644
index 00000000000..fd85410cc91
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresherTest.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.work.budget;
+
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mockito;
+
+@RunWith(JUnit4.class)
+public class GetWorkBudgetRefresherTest {
+ private static final int WAIT_BUFFER = 10;
+ private final Runnable redistributeBudget = Mockito.mock(Runnable.class);
+
+ private GetWorkBudgetRefresher createBudgetRefresher() {
+ return createBudgetRefresher(false);
+ }
+
+ private GetWorkBudgetRefresher createBudgetRefresher(Boolean isBudgetRefreshPaused) {
+ return new GetWorkBudgetRefresher(() -> isBudgetRefreshPaused, redistributeBudget);
+ }
+
+ @Test
+ public void testStop_successfullyTerminates() throws InterruptedException {
+ GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher();
+ budgetRefresher.start();
+ budgetRefresher.stop();
+ budgetRefresher.requestBudgetRefresh();
+ Thread.sleep(WAIT_BUFFER);
+ verifyNoInteractions(redistributeBudget);
+ }
+
+ @Test
+ public void testRequestBudgetRefresh_triggersBudgetRefresh() throws InterruptedException {
+ GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher();
+ budgetRefresher.start();
+ budgetRefresher.requestBudgetRefresh();
+ // Wait a bit for redistributeBudget to run.
+ Thread.sleep(WAIT_BUFFER);
+ verify(redistributeBudget, times(1)).run();
+ }
+
+ @Test
+ public void testScheduledBudgetRefresh() throws InterruptedException {
+ GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher();
+ budgetRefresher.start();
+ Thread.sleep(GetWorkBudgetRefresher.SCHEDULED_BUDGET_REFRESH_MILLIS + WAIT_BUFFER);
+ verify(redistributeBudget, times(1)).run();
+ }
+
+ @Test
+ public void testTriggeredAndScheduledBudgetRefresh_concurrent() throws InterruptedException {
+ GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher();
+ budgetRefresher.start();
+ Thread budgetRefreshTriggerThread = new Thread(budgetRefresher::requestBudgetRefresh);
+ budgetRefreshTriggerThread.start();
+ Thread.sleep(GetWorkBudgetRefresher.SCHEDULED_BUDGET_REFRESH_MILLIS + WAIT_BUFFER);
+ budgetRefreshTriggerThread.join();
+
+ // Wait a bit for redistributeBudget to run.
+ Thread.sleep(WAIT_BUFFER);
+ verify(redistributeBudget, times(2)).run();
+ }
+
+ @Test
+ public void testTriggeredBudgetRefresh_doesNotRunWhenBudgetRefreshPaused()
+ throws InterruptedException {
+ GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher(true);
+ budgetRefresher.start();
+ budgetRefresher.requestBudgetRefresh();
+ Thread.sleep(WAIT_BUFFER);
+ verifyNoInteractions(redistributeBudget);
+ }
+
+ @Test
+ public void testScheduledBudgetRefresh_doesNotRunWhenBudgetRefreshPaused()
+ throws InterruptedException {
+ GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher(true);
+ budgetRefresher.start();
+ Thread.sleep(GetWorkBudgetRefresher.SCHEDULED_BUDGET_REFRESH_MILLIS + WAIT_BUFFER);
+ verifyNoInteractions(redistributeBudget);
+ }
+}
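`GetWorkBudgetRefresherTest` describes a refresher that runs the redistribution callback both on a fixed schedule (`SCHEDULED_BUDGET_REFRESH_MILLIS`) and on demand (`requestBudgetRefresh()`), skips it while refreshes are paused, and does nothing after `stop()`. A minimal sketch of that behavior with one worker thread and a `java.util.concurrent.Semaphore` follows. `BudgetRefresherSketch` is a hypothetical class, and the Semaphore-based wake-up is an assumed design, not necessarily how `GetWorkBudgetRefresher` is built.

```java
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.function.BooleanSupplier;

/** Hypothetical scheduled-plus-triggered refresher; not Beam's GetWorkBudgetRefresher. */
final class BudgetRefresherSketch {
  private final long periodMillis;
  private final BooleanSupplier isRefreshPaused;
  private final Runnable redistributeBudget;
  private final Semaphore refreshRequests = new Semaphore(0);
  private final Thread worker;
  private volatile boolean stopped = false;

  BudgetRefresherSketch(
      long periodMillis, BooleanSupplier isRefreshPaused, Runnable redistributeBudget) {
    this.periodMillis = periodMillis;
    this.isRefreshPaused = isRefreshPaused;
    this.redistributeBudget = redistributeBudget;
    this.worker = new Thread(this::refreshLoop, "budget-refresher");
    this.worker.setDaemon(true);
  }

  void start() {
    worker.start();
  }

  /** Wakes the worker early; ignored once stopped. */
  void requestBudgetRefresh() {
    if (!stopped) {
      refreshRequests.release();
    }
  }

  /** Signals the worker to exit at its next check and interrupts any pending wait. */
  void stop() {
    stopped = true;
    worker.interrupt();
  }

  private void refreshLoop() {
    while (!stopped) {
      try {
        // Returns early when a refresh is requested; otherwise times out for a scheduled refresh.
        refreshRequests.tryAcquire(periodMillis, TimeUnit.MILLISECONDS);
      } catch (InterruptedException e) {
        return; // stop() interrupted the wait.
      }
      refreshRequests.drainPermits(); // Collapse any extra pending requests into this run.
      if (!stopped && !isRefreshPaused.getAsBoolean()) {
        redistributeBudget.run();
      }
    }
  }

  public static void main(String[] args) throws InterruptedException {
    BudgetRefresherSketch refresher =
        new BudgetRefresherSketch(100, () -> false, () -> System.out.println("redistributing"));
    refresher.start();
    refresher.requestBudgetRefresh(); // Runs promptly instead of waiting for the 100 ms tick.
    Thread.sleep(250); // Also leaves time for scheduled refreshes to fire.
    refresher.stop();
  }
}
```

Waiting on `tryAcquire` with a timeout lets one loop serve both triggers: a released permit wakes the worker early, and a timeout acts as the scheduled refresh. Draining leftover permits means that a burst of requests arriving during one redistribution causes only a single follow-up run.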
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetTest.java
index 76d50839785..97789abaaa9 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetTest.java
@@ -18,7 +18,6 @@
package org.apache.beam.runners.dataflow.worker.windmill.work.budget;
import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertThrows;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -42,31 +41,10 @@ public class GetWorkBudgetTest {
}
@Test
- public void testAdd_doesNotAllowNegativeParameters() {
+ public void testApply_itemsAndBytesNeverBelowZero() {
GetWorkBudget getWorkBudget = GetWorkBudget.builder().setItems(1).setBytes(1).build();
- assertThrows(IllegalArgumentException.class, () -> getWorkBudget.add(-1, -1));
- }
-
- @Test
- public void testSubtract_itemsAndBytesNeverBelowZero() {
- GetWorkBudget getWorkBudget = GetWorkBudget.builder().setItems(1).setBytes(1).build();
- GetWorkBudget subtracted = getWorkBudget.subtract(10, 10);
- assertEquals(0, subtracted.items());
- assertEquals(0, subtracted.bytes());
- }
-
- @Test
- public void testSubtractGetWorkBudget_itemsAndBytesNeverBelowZero() {
- GetWorkBudget getWorkBudget = GetWorkBudget.builder().setItems(1).setBytes(1).build();
- GetWorkBudget subtracted =
- getWorkBudget.subtract(GetWorkBudget.builder().setItems(10).setBytes(10).build());
+ GetWorkBudget subtracted = getWorkBudget.apply(-10, -10);
assertEquals(0, subtracted.items());
assertEquals(0, subtracted.bytes());
}
-
- @Test
- public void testSubtract_doesNotAllowNegativeParameters() {
- GetWorkBudget getWorkBudget = GetWorkBudget.builder().setItems(1).setBytes(1).build();
- assertThrows(IllegalArgumentException.class, () -> getWorkBudget.subtract(-1, -1));
- }
}
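This change replaces `add` and `subtract` on `GetWorkBudget`, whose removed tests show they rejected negative arguments, with a single `apply(items, bytes)` that accepts signed deltas and, per `testApply_itemsAndBytesNeverBelowZero`, never lets items or bytes drop below zero. The minimal sketch below shows those semantics on a hypothetical `ClampedBudget` value class; it is not Beam's `GetWorkBudget`, and it deliberately ignores `long` overflow.

```java
/** Hypothetical sketch of a budget whose items and bytes never go below zero. */
final class ClampedBudget {
  private final long items;
  private final long bytes;

  private ClampedBudget(long items, long bytes) {
    this.items = items;
    this.bytes = bytes;
  }

  /** Creates a budget, flooring negative inputs at zero. */
  static ClampedBudget of(long items, long bytes) {
    return new ClampedBudget(Math.max(0, items), Math.max(0, bytes));
  }

  long items() {
    return items;
  }

  long bytes() {
    return bytes;
  }

  /**
   * Applies signed deltas: positive values grow the budget, negative values shrink it, and each
   * field is floored at zero instead of throwing. Overflow is not handled in this sketch.
   */
  ClampedBudget apply(long itemsDelta, long bytesDelta) {
    return of(items + itemsDelta, bytes + bytesDelta);
  }

  public static void main(String[] args) {
    ClampedBudget shrunk = of(1, 1).apply(-10, -10);
    System.out.println(shrunk.items() + " items, " + shrunk.bytes() + " bytes");
  }
}
```

Folding both directions into one signed-delta method means a caller can pass a difference such as target minus remaining without first checking its sign.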