This is an automated email from the ASF dual-hosted git repository.
scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 690a5a44605 Update windmill proto definition (#30046)
690a5a44605 is described below
commit 690a5a446054281f72fa4c141c8e90e296663c70
Author: martin trieu <[email protected]>
AuthorDate: Thu Feb 15 11:56:46 2024 -0800
Update windmill proto definition (#30046)
* make external and internal windmill proto defs identical
* override authority of grpc channel for direct path
* Do not inject WindmillServer from deps, create in StreamingDataflowWorker
---
.../dataflow/worker/StreamingDataflowWorker.java | 96 +++++++----
.../options/StreamingDataflowWorkerOptions.java | 33 ----
.../streaming/WorkHeartbeatResponseProcessor.java | 68 ++++++++
.../worker/windmill/WindmillEndpoints.java | 178 ++++++++++++---------
.../worker/windmill/WindmillServerStub.java | 6 -
.../worker/windmill/WindmillServiceAddress.java | 28 +++-
.../windmill/client/grpc/GrpcDispatcherClient.java | 178 +++++++++++++++------
.../windmill/client/grpc/GrpcWindmillServer.java | 110 +++++++------
.../client/grpc/GrpcWindmillStreamFactory.java | 17 +-
.../client/grpc/StreamingEngineClient.java | 23 +--
.../grpc/stubs/RemoteWindmillStubFactory.java | 76 +++++++++
.../client/grpc/stubs/WindmillChannelFactory.java | 40 ++++-
.../client/grpc/stubs/WindmillStubFactory.java | 62 +------
.../dataflow/worker/FakeWindmillServer.java | 19 ++-
.../worker/StreamingDataflowWorkerTest.java | 161 ++++++++-----------
.../grpc/GrpcGetWorkerMetadataStreamTest.java | 30 ++--
.../client/grpc/GrpcWindmillServerTest.java | 29 +++-
.../client/grpc/StreamingEngineClientTest.java | 48 ++++--
.../windmill/testing/FakeWindmillStubFactory.java | 47 ++++++
.../worker/windmill/src/main/proto/windmill.proto | 12 +-
.../windmill/src/main/proto/windmill_service.proto | 10 +-
21 files changed, 806 insertions(+), 465 deletions(-)
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index 06450e60fc0..825c3fb78c7 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -57,6 +57,7 @@ import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.servlet.http.HttpServletRequest;
@@ -96,6 +97,7 @@ import
org.apache.beam.runners.dataflow.worker.streaming.StageInfo;
import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue;
import org.apache.beam.runners.dataflow.worker.streaming.Work;
import org.apache.beam.runners.dataflow.worker.streaming.Work.State;
+import
org.apache.beam.runners.dataflow.worker.streaming.WorkHeartbeatResponseProcessor;
import org.apache.beam.runners.dataflow.worker.streaming.WorkId;
import
org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher;
import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
@@ -104,13 +106,16 @@ import
org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter
import
org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter;
import
org.apache.beam.runners.dataflow.worker.util.common.worker.ReadOperation;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
-import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
+import
org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillServer;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillStreamFactory;
import
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
import
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader;
import org.apache.beam.sdk.coders.Coder;
@@ -208,7 +213,7 @@ public class StreamingDataflowWorker {
private static final Random clientIdGenerator = new Random();
final WindmillStateCache stateCache;
// Maps from computation ids to per-computation state.
- private final ConcurrentMap<String, ComputationState> computationMap = new
ConcurrentHashMap<>();
+ private final ConcurrentMap<String, ComputationState> computationMap;
private final WeightedBoundedQueue<Commit> commitQueue =
WeightedBoundedQueue.create(
MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES,
commit.getSize()));
@@ -280,8 +285,7 @@ public class StreamingDataflowWorker {
// Periodic sender of debug information to the debug capture service.
private final DebugCapture.@Nullable Manager debugCaptureManager;
// Collection of ScheduledExecutorServices that are running periodic
functions.
- private ArrayList<ScheduledExecutorService> scheduledExecutors =
- new ArrayList<ScheduledExecutorService>();
+ private final ArrayList<ScheduledExecutorService> scheduledExecutors = new
ArrayList<>();
private int retryLocallyDelayMs = 10000;
// Periodically fires a global config request to dataflow service. Only used
when windmill service
// is enabled.
@@ -292,6 +296,9 @@ public class StreamingDataflowWorker {
@VisibleForTesting
StreamingDataflowWorker(
+ WindmillServerStub windmillServer,
+ long clientId,
+ ConcurrentMap<String, ComputationState> computationMap,
List<MapTask> mapTasks,
DataflowMapTaskExecutorFactory mapTaskExecutorFactory,
WorkUnitClient workUnitClient,
@@ -299,13 +306,13 @@ public class StreamingDataflowWorker {
boolean publishCounters,
HotKeyLogger hotKeyLogger,
Supplier<Instant> clock,
- Function<String, ScheduledExecutorService> executorSupplier)
- throws IOException {
+ Function<String, ScheduledExecutorService> executorSupplier) {
this.stateCache = new WindmillStateCache(options.getWorkerCacheMb());
this.readerCache =
new ReaderCache(
Duration.standardSeconds(options.getReaderCacheTimeoutSec()),
Executors.newCachedThreadPool());
+ this.computationMap = computationMap;
this.mapTaskExecutorFactory = mapTaskExecutorFactory;
this.workUnitClient = workUnitClient;
this.options = options;
@@ -429,8 +436,8 @@ public class StreamingDataflowWorker {
commitThreads = commitThreadsBuilder.build();
this.publishCounters = publishCounters;
- this.windmillServer = options.getWindmillServerStub();
-
this.windmillServer.setProcessHeartbeatResponses(this::handleHeartbeatResponses);
+ this.clientId = clientId;
+ this.windmillServer = windmillServer;
this.metricTrackingWindmillServer =
MetricTrackingWindmillServerStub.builder(windmillServer, memoryMonitor)
.setUseStreamingRequests(windmillServiceEnabled)
@@ -438,7 +445,6 @@ public class StreamingDataflowWorker {
.setNumGetDataStreams(options.getWindmillGetDataStreamCount())
.build();
this.sideInputStateFetcher = new
SideInputStateFetcher(metricTrackingWindmillServer, options);
- this.clientId = clientIdGenerator.nextLong();
for (MapTask mapTask : mapTasks) {
addComputation(mapTask.getSystemName(), mapTask, ImmutableMap.of());
@@ -456,6 +462,44 @@ public class StreamingDataflowWorker {
LOG.debug("maxWorkItemCommitBytes: {}", maxWorkItemCommitBytes);
}
+ private static WindmillServerStub createWindmillServerStub(
+ StreamingDataflowWorkerOptions options,
+ long clientId,
+ Consumer<List<Windmill.ComputationHeartbeatResponse>>
processHeartbeatResponses) {
+ if (options.getWindmillServiceEndpoint() != null
+ || options.isEnableStreamingEngine()
+ || options.getLocalWindmillHostport().startsWith("grpc:")) {
+ try {
+ Duration maxBackoff =
+ !options.isEnableStreamingEngine() &&
options.getLocalWindmillHostport() != null
+ ? GrpcWindmillServer.LOCALHOST_MAX_BACKOFF
+ : GrpcWindmillServer.MAX_BACKOFF;
+ GrpcWindmillStreamFactory windmillStreamFactory =
+ GrpcWindmillStreamFactory.of(
+ JobHeader.newBuilder()
+ .setJobId(options.getJobId())
+ .setProjectId(options.getProject())
+ .setWorkerId(options.getWorkerId())
+ .setClientId(clientId)
+ .build())
+ .setWindmillMessagesBetweenIsReadyChecks(
+ options.getWindmillMessagesBetweenIsReadyChecks())
+ .setMaxBackOffSupplier(() -> maxBackoff)
+ .setLogEveryNStreamFailures(
+
options.getWindmillServiceStreamingLogEveryNStreamFailures())
+
.setStreamingRpcBatchLimit(options.getWindmillServiceStreamingRpcBatchLimit())
+ .build();
+ windmillStreamFactory.scheduleHealthChecks(
+ options.getWindmillServiceStreamingRpcHealthCheckPeriodMs());
+ return GrpcWindmillServer.create(options, windmillStreamFactory,
processHeartbeatResponses);
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to create GrpcWindmillServer: ", e);
+ }
+ } else {
+ return new
JniWindmillApplianceServer(options.getLocalWindmillHostport());
+ }
+ }
+
/** Returns whether an exception was caused by a {@link OutOfMemoryError}. */
private static boolean isOutOfMemoryError(Throwable t) {
while (t != null) {
@@ -509,10 +553,17 @@ public class StreamingDataflowWorker {
worker.start();
}
- public static StreamingDataflowWorker
fromOptions(StreamingDataflowWorkerOptions options)
- throws IOException {
-
+ public static StreamingDataflowWorker
fromOptions(StreamingDataflowWorkerOptions options) {
+ ConcurrentMap<String, ComputationState> computationMap = new
ConcurrentHashMap<>();
+ long clientId = clientIdGenerator.nextLong();
return new StreamingDataflowWorker(
+ createWindmillServerStub(
+ options,
+ clientId,
+ new WorkHeartbeatResponseProcessor(
+ computationId ->
Optional.ofNullable(computationMap.get(computationId)))),
+ clientId,
+ computationMap,
Collections.emptyList(),
IntrinsicMapTaskExecutorFactory.defaultFactory(),
new DataflowWorkUnitClient(options, LOG),
@@ -1626,7 +1677,6 @@ public class StreamingDataflowWorker {
@SuppressWarnings("FutureReturnValueIgnored")
private void schedulePeriodicGlobalConfigRequests() {
Preconditions.checkState(windmillServiceEnabled);
-
if (!windmillServer.isReady()) {
// Get the initial global configuration. This will initialize the
windmillServer stub.
while (true) {
@@ -1975,26 +2025,6 @@ public class StreamingDataflowWorker {
}
}
- public void handleHeartbeatResponses(List<ComputationHeartbeatResponse>
responses) {
- for (ComputationHeartbeatResponse computationHeartbeatResponse :
responses) {
- // Maps sharding key to (work token, cache token) for work that should
be marked failed.
- Multimap<Long, WorkId> failedWork = ArrayListMultimap.create();
- for (Windmill.HeartbeatResponse heartbeatResponse :
- computationHeartbeatResponse.getHeartbeatResponsesList()) {
- if (heartbeatResponse.getFailed()) {
- failedWork.put(
- heartbeatResponse.getShardingKey(),
- WorkId.builder()
- .setWorkToken(heartbeatResponse.getWorkToken())
- .setCacheToken(heartbeatResponse.getCacheToken())
- .build());
- }
- }
- ComputationState state =
computationMap.get(computationHeartbeatResponse.getComputationId());
- if (state != null) state.failWork(failedWork);
- }
- }
-
/**
* Sends a GetData request to Windmill for all sufficiently old active work.
*
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java
index a75a60af2ba..9431470a16f 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java
@@ -17,11 +17,7 @@
*/
package org.apache.beam.runners.dataflow.worker.options;
-import java.io.IOException;
import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions;
-import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
-import
org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer;
-import
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillServer;
import org.apache.beam.sdk.options.Default;
import org.apache.beam.sdk.options.DefaultValueFactory;
import org.apache.beam.sdk.options.Description;
@@ -36,12 +32,6 @@ import org.joda.time.Duration;
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
public interface StreamingDataflowWorkerOptions extends
DataflowWorkerHarnessOptions {
- @Description("Stub for communicating with Windmill.")
- @Default.InstanceFactory(WindmillServerStubFactory.class)
- WindmillServerStub getWindmillServerStub();
-
- void setWindmillServerStub(WindmillServerStub value);
-
@Description("Hostport of a co-located Windmill server.")
@Default.InstanceFactory(LocalWindmillHostportFactory.class)
String getLocalWindmillHostport();
@@ -168,29 +158,6 @@ public interface StreamingDataflowWorkerOptions extends
DataflowWorkerHarnessOpt
}
}
- /**
- * Factory for creating {@link WindmillServerStub} instances. If {@link
setLocalWindmillHostport}
- * is set, returns a stub to a local Windmill server, otherwise returns a
remote gRPC stub.
- */
- public static class WindmillServerStubFactory implements
DefaultValueFactory<WindmillServerStub> {
- @Override
- public WindmillServerStub create(PipelineOptions options) {
- StreamingDataflowWorkerOptions streamingOptions =
- options.as(StreamingDataflowWorkerOptions.class);
- if (streamingOptions.getWindmillServiceEndpoint() != null
- || streamingOptions.isEnableStreamingEngine()
- || streamingOptions.getLocalWindmillHostport().startsWith("grpc:")) {
- try {
- return GrpcWindmillServer.create(streamingOptions);
- } catch (IOException e) {
- throw new RuntimeException("Failed to create GrpcWindmillServer: ",
e);
- }
- } else {
- return new
JniWindmillApplianceServer(streamingOptions.getLocalWindmillHostport());
- }
- }
- }
-
/** Factory for setting value of WindmillServiceStreamingRpcBatchLimit based
on environment. */
public static class WindmillServiceStreamingRpcBatchLimitFactory
implements DefaultValueFactory<Integer> {
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkHeartbeatResponseProcessor.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkHeartbeatResponseProcessor.java
new file mode 100644
index 00000000000..341f434cefa
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkHeartbeatResponseProcessor.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatResponse;
+import org.apache.beam.sdk.annotations.Internal;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ArrayListMultimap;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap;
+
+/**
+ * Processes {@link ComputationHeartbeatResponse}(s). Marks {@link Work} that the Streaming
+ * Engine backend reports as failed so that it is dropped from streaming worker harness
+ * processing.
+ */
+@Internal
+public final class WorkHeartbeatResponseProcessor
+ implements Consumer<List<ComputationHeartbeatResponse>> {
+ /** Fetches a {@link ComputationState} for a computationId. */
+ private final Function<String, Optional<ComputationState>>
computationStateFetcher;
+
+ public WorkHeartbeatResponseProcessor(
+ /* Fetches a {@link ComputationState} for a String computationId. */
+ Function<String, Optional<ComputationState>> computationStateFetcher) {
+ this.computationStateFetcher = computationStateFetcher;
+ }
+
+ @Override
+ public void accept(List<ComputationHeartbeatResponse> responses) {
+ for (ComputationHeartbeatResponse computationHeartbeatResponse :
responses) {
+ // Maps sharding key to (work token, cache token) for work that should
be marked failed.
+ Multimap<Long, WorkId> failedWork = ArrayListMultimap.create();
+ for (HeartbeatResponse heartbeatResponse :
+ computationHeartbeatResponse.getHeartbeatResponsesList()) {
+ if (heartbeatResponse.getFailed()) {
+ failedWork.put(
+ heartbeatResponse.getShardingKey(),
+ WorkId.builder()
+ .setWorkToken(heartbeatResponse.getWorkToken())
+ .setCacheToken(heartbeatResponse.getCacheToken())
+ .build());
+ }
+ }
+
+ computationStateFetcher
+ .apply(computationHeartbeatResponse.getComputationId())
+ .ifPresent(state -> state.failWork(failedWork));
+ }
+ }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java
index 64b6e675ef5..d7ed83def43 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java
@@ -26,6 +26,7 @@ import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Map;
import java.util.Optional;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress.AuthenticatedGcpServiceAddress;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
@@ -40,23 +41,6 @@ import org.slf4j.LoggerFactory;
public abstract class WindmillEndpoints {
private static final Logger LOG =
LoggerFactory.getLogger(WindmillEndpoints.class);
- /**
- * Used by GetData GlobalDataRequest(s) to support Beam side inputs. Returns
a map where the key
- * is a global data tag and the value is the endpoint where the data
associated with the global
- * data tag resides.
- *
- * @see <a
href="https://beam.apache.org/documentation/programming-guide/#side-inputs">Beam
Side
- * Inputs</a>
- */
- public abstract ImmutableMap<String, Endpoint> globalDataEndpoints();
-
- /**
- * Used by GetWork/GetData/CommitWork calls to send, receive, and commit
work directly to/from
- * Windmill servers. Returns a list of endpoints used to communicate with
the corresponding
- * Windmill servers.
- */
- public abstract ImmutableList<Endpoint> windmillEndpoints();
-
public static WindmillEndpoints from(
Windmill.WorkerMetadataResponse workerMetadataResponseProto) {
ImmutableMap<String, WindmillEndpoints.Endpoint> globalDataServers =
@@ -64,11 +48,16 @@ public abstract class WindmillEndpoints {
.collect(
toImmutableMap(
Map.Entry::getKey, // global data key
- endpoint ->
WindmillEndpoints.Endpoint.from(endpoint.getValue())));
+ endpoint ->
+ WindmillEndpoints.Endpoint.from(
+ endpoint.getValue(),
+
workerMetadataResponseProto.getExternalEndpoint())));
ImmutableList<WindmillEndpoints.Endpoint> windmillServers =
workerMetadataResponseProto.getWorkEndpointsList().stream()
- .map(WindmillEndpoints.Endpoint::from)
+ .map(
+ endpointProto ->
+ Endpoint.from(endpointProto,
workerMetadataResponseProto.getExternalEndpoint()))
.collect(toImmutableList());
return WindmillEndpoints.builder()
@@ -81,6 +70,76 @@ public abstract class WindmillEndpoints {
return new AutoValue_WindmillEndpoints.Builder();
}
+ private static Optional<WindmillServiceAddress> parseDirectEndpoint(
+ Windmill.WorkerMetadataResponse.Endpoint endpointProto, String
authenticatingService) {
+ Optional<WindmillServiceAddress> directEndpointIpV6Address =
+ tryParseDirectEndpointIntoIpV6Address(endpointProto)
+ .map(address ->
AuthenticatedGcpServiceAddress.create(authenticatingService, address))
+ .map(WindmillServiceAddress::create);
+
+ return directEndpointIpV6Address.isPresent()
+ ? directEndpointIpV6Address
+ : tryParseEndpointIntoHostAndPort(endpointProto.getDirectEndpoint())
+ .map(WindmillServiceAddress::create);
+ }
+
+ private static Optional<HostAndPort> tryParseEndpointIntoHostAndPort(String
directEndpoint) {
+ try {
+ return Optional.of(HostAndPort.fromString(directEndpoint));
+ } catch (IllegalArgumentException e) {
+ LOG.warn("{} cannot be parsed into a gcpServiceAddress", directEndpoint);
+ return Optional.empty();
+ }
+ }
+
+ private static Optional<HostAndPort> tryParseDirectEndpointIntoIpV6Address(
+ Windmill.WorkerMetadataResponse.Endpoint endpointProto) {
+ if (!endpointProto.hasDirectEndpoint()) {
+ return Optional.empty();
+ }
+
+ InetAddress directEndpointAddress;
+ try {
+ directEndpointAddress =
Inet6Address.getByName(endpointProto.getDirectEndpoint());
+ } catch (UnknownHostException e) {
+ LOG.warn(
+ "Error occurred trying to parse direct_endpoint={} into IPv6
address. Exception={}",
+ endpointProto.getDirectEndpoint(),
+ e.toString());
+ return Optional.empty();
+ }
+
+ // Inet6Address.getByName returns either an IPv4 or an IPv6 address depending on the format
+ // of the direct_endpoint string.
+ if (!(directEndpointAddress instanceof Inet6Address)) {
+ LOG.warn(
+ "{} is not an IPv6 address. Direct endpoints are expected to be in
IPv6 format.",
+ endpointProto.getDirectEndpoint());
+ return Optional.empty();
+ }
+
+ return Optional.of(
+ HostAndPort.fromParts(
+ directEndpointAddress.getHostAddress(), (int)
endpointProto.getPort()));
+ }
+
+ /**
+ * Used by GetData GlobalDataRequest(s) to support Beam side inputs. Returns
a map where the key
+ * is a global data tag and the value is the endpoint where the data
associated with the global
+ * data tag resides.
+ *
+ * @see <a
href="https://beam.apache.org/documentation/programming-guide/#side-inputs">Beam
Side
+ * Inputs</a>
+ */
+ public abstract ImmutableMap<String, Endpoint> globalDataEndpoints();
+
+ /**
+ * Used by GetWork/GetData/CommitWork calls to send, receive, and commit
work directly to/from
+ * Windmill servers. Returns a list of endpoints used to communicate with
the corresponding
+ * Windmill servers.
+ */
+ public abstract ImmutableList<Endpoint> windmillEndpoints();
+
/**
* Representation of an endpoint in {@link
Windmill.WorkerMetadataResponse.Endpoint} proto with
* the worker_token field, and direct_endpoint field parsed into a {@link
WindmillServiceAddress}
@@ -90,31 +149,21 @@ public abstract class WindmillEndpoints {
*/
@AutoValue
public abstract static class Endpoint {
- /**
- * {@link WindmillServiceAddress} representation of {@link
- * Windmill.WorkerMetadataResponse.Endpoint#getDirectEndpoint()}. The
proto's direct_endpoint
- * string can be converted to either {@link Inet6Address} or {@link
HostAndPort}.
- */
- public abstract Optional<WindmillServiceAddress> directEndpoint();
-
- /**
- * Corresponds to {@link
Windmill.WorkerMetadataResponse.Endpoint#getWorkerToken()} in the
- * windmill.proto file.
- */
- public abstract Optional<String> workerToken();
-
public static Endpoint.Builder builder() {
return new AutoValue_WindmillEndpoints_Endpoint.Builder();
}
- public static Endpoint from(Windmill.WorkerMetadataResponse.Endpoint
endpointProto) {
+ public static Endpoint from(
+ Windmill.WorkerMetadataResponse.Endpoint endpointProto, String
authenticatingService) {
Endpoint.Builder endpointBuilder = Endpoint.builder();
- if (endpointProto.hasDirectEndpoint() &&
!endpointProto.getDirectEndpoint().isEmpty()) {
- parseDirectEndpoint(endpointProto.getDirectEndpoint())
+
+ if (!endpointProto.getDirectEndpoint().isEmpty()) {
+ parseDirectEndpoint(endpointProto, authenticatingService)
.ifPresent(endpointBuilder::setDirectEndpoint);
}
- if (endpointProto.hasWorkerToken() &&
!endpointProto.getWorkerToken().isEmpty()) {
- endpointBuilder.setWorkerToken(endpointProto.getWorkerToken());
+
+ if (!endpointProto.getBackendWorkerToken().isEmpty()) {
+ endpointBuilder.setWorkerToken(endpointProto.getBackendWorkerToken());
}
Endpoint endpoint = endpointBuilder.build();
@@ -130,6 +179,19 @@ public abstract class WindmillEndpoints {
return endpoint;
}
+ /**
+ * {@link WindmillServiceAddress} representation of {@link
+ * Windmill.WorkerMetadataResponse.Endpoint#getDirectEndpoint()}. The
proto's direct_endpoint
+ * string is parsed into an {@link AuthenticatedGcpServiceAddress} when it resolves to an IPv6
+ * address, and otherwise into a {@link HostAndPort}.
+ */
+ public abstract Optional<WindmillServiceAddress> directEndpoint();
+
+ /**
+ * Corresponds to {@link
Windmill.WorkerMetadataResponse.Endpoint#getBackendWorkerToken()}
+ * in the windmill.proto file.
+ */
+ public abstract Optional<String> workerToken();
+
@AutoValue.Builder
public abstract static class Builder {
public abstract Builder setDirectEndpoint(WindmillServiceAddress
directEndpoint);
@@ -176,46 +238,4 @@ public abstract class WindmillEndpoints {
public abstract WindmillEndpoints build();
}
-
- private static Optional<WindmillServiceAddress> parseDirectEndpoint(String
directEndpoint) {
- Optional<WindmillServiceAddress> directEndpointIpV6Address =
-
tryParseDirectEndpointIntoIpV6Address(directEndpoint).map(WindmillServiceAddress::create);
-
- return directEndpointIpV6Address.isPresent()
- ? directEndpointIpV6Address
- :
tryParseEndpointIntoHostAndPort(directEndpoint).map(WindmillServiceAddress::create);
- }
-
- private static Optional<HostAndPort> tryParseEndpointIntoHostAndPort(String
directEndpoint) {
- try {
- return Optional.of(HostAndPort.fromString(directEndpoint));
- } catch (IllegalArgumentException e) {
- LOG.warn("{} cannot be parsed into a gcpServiceAddress", directEndpoint);
- return Optional.empty();
- }
- }
-
- private static Optional<Inet6Address> tryParseDirectEndpointIntoIpV6Address(
- String directEndpoint) {
- InetAddress directEndpointAddress = null;
- try {
- directEndpointAddress = Inet6Address.getByName(directEndpoint);
- } catch (UnknownHostException e) {
- LOG.warn(
- "Error occurred trying to parse direct_endpoint={} into IPv6
address. Exception={}",
- directEndpoint,
- e.toString());
- }
-
- // Inet6Address.getByAddress returns either an IPv4 or an IPv6 address
depending on the format
- // of the direct_endpoint string.
- if (!(directEndpointAddress instanceof Inet6Address)) {
- LOG.warn(
- "{} is not an IPv6 address. Direct endpoints are expected to be in
IPv6 format.",
- directEndpoint);
- return Optional.empty();
- }
-
- return Optional.ofNullable((Inet6Address) directEndpointAddress);
- }
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
index 25581bee208..c327e68d7e9 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
@@ -19,11 +19,8 @@ package org.apache.beam.runners.dataflow.worker.windmill;
import java.io.IOException;
import java.io.PrintWriter;
-import java.util.List;
import java.util.Set;
-import java.util.function.Consumer;
import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
-import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
@@ -82,9 +79,6 @@ public abstract class WindmillServerStub implements
StatusDataProvider {
@Override
public void appendSummaryHtml(PrintWriter writer) {}
- public void setProcessHeartbeatResponses(
- Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses)
{}
-
/** Generic Exception type for implementors to use to represent errors while
making RPCs. */
public static final class RpcException extends RuntimeException {
public RpcException(Throwable cause) {
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java
index 3ebda8fab8e..90f93b07267 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java
@@ -18,6 +18,7 @@
package org.apache.beam.runners.dataflow.worker.windmill;
import com.google.auto.value.AutoOneOf;
+import com.google.auto.value.AutoValue;
import java.net.Inet6Address;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
@@ -38,8 +39,33 @@ public abstract class WindmillServiceAddress {
public abstract HostAndPort gcpServiceAddress();
+ public abstract AuthenticatedGcpServiceAddress
authenticatedGcpServiceAddress();
+
+ public static WindmillServiceAddress create(
+ AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress) {
+ return AutoOneOf_WindmillServiceAddress.authenticatedGcpServiceAddress(
+ authenticatedGcpServiceAddress);
+ }
+
public enum Kind {
IPV6,
- GCP_SERVICE_ADDRESS
+ GCP_SERVICE_ADDRESS,
+ // TODO(m-trieu): Use for direct connections when ALTS is enabled.
+ AUTHENTICATED_GCP_SERVICE_ADDRESS
+ }
+
+ @AutoValue
+ public abstract static class AuthenticatedGcpServiceAddress {
+
+ public static AuthenticatedGcpServiceAddress create(
+ String authenticatingService, HostAndPort gcpServiceAddress) {
+ // HostAndPort supports IpV6.
+ return new
AutoValue_WindmillServiceAddress_AuthenticatedGcpServiceAddress(
+ authenticatingService, gcpServiceAddress);
+ }
+
+ public abstract String authenticatingService();
+
+ public abstract HostAndPort gcpServiceAddress();
}
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java
index ef9156f9c05..aa15e0a5e1a 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java
@@ -20,19 +20,22 @@ package
org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
import static
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory.LOCALHOST;
import static
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory.localhostChannel;
-import java.util.ArrayList;
-import java.util.HashSet;
+import com.google.auto.value.AutoValue;
import java.util.List;
import java.util.Random;
import java.util.Set;
+import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc.CloudWindmillMetadataServiceV1Alpha1Stub;
import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress;
import
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
import org.slf4j.Logger;
@@ -44,93 +47,178 @@ class GrpcDispatcherClient {
private static final Logger LOG =
LoggerFactory.getLogger(GrpcDispatcherClient.class);
private final WindmillStubFactory windmillStubFactory;
- @GuardedBy("this")
- private final List<CloudWindmillServiceV1Alpha1Stub> dispatcherStubs;
-
- @GuardedBy("this")
- private final Set<HostAndPort> dispatcherEndpoints;
+ /**
+ * Current dispatcher endpoints and stubs used to communicate with Windmill
Dispatcher.
+ *
+ * @implNote Reads are lock free, writes are synchronized.
+ */
+ private final AtomicReference<DispatcherStubs> dispatcherStubs;
@GuardedBy("this")
private final Random rand;
private GrpcDispatcherClient(
WindmillStubFactory windmillStubFactory,
- List<CloudWindmillServiceV1Alpha1Stub> dispatcherStubs,
- Set<HostAndPort> dispatcherEndpoints,
+ DispatcherStubs initialDispatcherStubs,
Random rand) {
this.windmillStubFactory = windmillStubFactory;
- this.dispatcherStubs = dispatcherStubs;
- this.dispatcherEndpoints = dispatcherEndpoints;
this.rand = rand;
+ this.dispatcherStubs = new AtomicReference<>(initialDispatcherStubs);
}
static GrpcDispatcherClient create(WindmillStubFactory windmillStubFactory) {
- return new GrpcDispatcherClient(
- windmillStubFactory, new ArrayList<>(), new HashSet<>(), new Random());
+ return new GrpcDispatcherClient(windmillStubFactory,
DispatcherStubs.empty(), new Random());
}
@VisibleForTesting
static GrpcDispatcherClient forTesting(
WindmillStubFactory windmillGrpcStubFactory,
- List<CloudWindmillServiceV1Alpha1Stub> dispatcherStubs,
+ List<CloudWindmillServiceV1Alpha1Stub> windmillServiceStubs,
+ List<CloudWindmillMetadataServiceV1Alpha1Stub>
windmillMetadataServiceStubs,
Set<HostAndPort> dispatcherEndpoints) {
- Preconditions.checkArgument(dispatcherEndpoints.size() ==
dispatcherStubs.size());
+ Preconditions.checkArgument(
+ dispatcherEndpoints.size() == windmillServiceStubs.size()
+ && windmillServiceStubs.size() ==
windmillMetadataServiceStubs.size());
return new GrpcDispatcherClient(
- windmillGrpcStubFactory, dispatcherStubs, dispatcherEndpoints, new
Random());
+ windmillGrpcStubFactory,
+ DispatcherStubs.create(
+ dispatcherEndpoints, windmillServiceStubs,
windmillMetadataServiceStubs),
+ new Random());
+ }
+
+ CloudWindmillServiceV1Alpha1Stub getWindmillServiceStub() {
+ ImmutableList<CloudWindmillServiceV1Alpha1Stub> windmillServiceStubs =
+ dispatcherStubs.get().windmillServiceStubs();
+ Preconditions.checkState(
+ !windmillServiceStubs.isEmpty(), "windmillServiceEndpoint has not been
set");
+
+ return (windmillServiceStubs.size() == 1
+ ? windmillServiceStubs.get(0)
+ : randomlySelectNextStub(windmillServiceStubs));
}
- synchronized CloudWindmillServiceV1Alpha1Stub getDispatcherStub() {
+ CloudWindmillMetadataServiceV1Alpha1Stub getWindmillMetadataServiceStub() {
+ ImmutableList<CloudWindmillMetadataServiceV1Alpha1Stub>
windmillMetadataServiceStubs =
+ dispatcherStubs.get().windmillMetadataServiceStubs();
Preconditions.checkState(
- !dispatcherStubs.isEmpty(), "windmillServiceEndpoint has not been
set");
+ !windmillMetadataServiceStubs.isEmpty(), "windmillServiceEndpoint has
not been set");
+
+ return (windmillMetadataServiceStubs.size() == 1
+ ? windmillMetadataServiceStubs.get(0)
+ : randomlySelectNextStub(windmillMetadataServiceStubs));
+ }
- return (dispatcherStubs.size() == 1
- ? dispatcherStubs.get(0)
- : dispatcherStubs.get(rand.nextInt(dispatcherStubs.size())));
+ private synchronized <T> T randomlySelectNextStub(List<T> stubs) {
+ return stubs.get(rand.nextInt(stubs.size()));
}
- synchronized boolean isReady() {
- return !dispatcherStubs.isEmpty();
+ /**
+ * Returns whether the {@link DispatcherStubs} have been set. Once initially
set, {@link
+ * #dispatcherStubs} will always have a value as empty updates will trigger
an {@link
+ * IllegalStateException}.
+ */
+ boolean hasInitializedEndpoints() {
+ return dispatcherStubs.get().hasInitializedEndpoints();
}
synchronized void consumeWindmillDispatcherEndpoints(
ImmutableSet<HostAndPort> dispatcherEndpoints) {
+ ImmutableSet<HostAndPort> currentDispatcherEndpoints =
+ dispatcherStubs.get().dispatcherEndpoints();
Preconditions.checkArgument(
dispatcherEndpoints != null && !dispatcherEndpoints.isEmpty(),
"Cannot set dispatcher endpoints to nothing.");
- if (this.dispatcherEndpoints.equals(dispatcherEndpoints)) {
+ if (currentDispatcherEndpoints.equals(dispatcherEndpoints)) {
// The endpoints are equal don't recreate the stubs.
return;
}
LOG.info("Creating a new windmill stub, endpoints: {}",
dispatcherEndpoints);
- if (!this.dispatcherEndpoints.isEmpty()) {
- LOG.info("Previous windmill stub endpoints: {}",
this.dispatcherEndpoints);
+ if (!currentDispatcherEndpoints.isEmpty()) {
+ LOG.info("Previous windmill stub endpoints: {}",
currentDispatcherEndpoints);
}
- resetDispatcherEndpoints(dispatcherEndpoints);
+ LOG.info("Initializing Streaming Engine GRPC client for endpoints: {}",
dispatcherEndpoints);
+ dispatcherStubs.set(DispatcherStubs.create(dispatcherEndpoints,
windmillStubFactory));
}
- private synchronized void resetDispatcherEndpoints(
- ImmutableSet<HostAndPort> newDispatcherEndpoints) {
- LOG.info("Initializing Streaming Engine GRPC client for endpoints: {}",
newDispatcherEndpoints);
- this.dispatcherStubs.clear();
- this.dispatcherEndpoints.clear();
- this.dispatcherEndpoints.addAll(newDispatcherEndpoints);
+ /**
+ * Endpoints and gRPC stubs used to communicate with the Windmill
Dispatcher. {@link
+ * #dispatcherEndpoints()}, {@link #windmillServiceStubs()}, and {@link
+ * #windmillMetadataServiceStubs()} collections should all be of the same
size.
+ */
+ @AutoValue
+ abstract static class DispatcherStubs {
- dispatcherEndpoints.stream()
- .map(this::createDispatcherStubForWindmillService)
- .forEach(dispatcherStubs::add);
- }
+ private static DispatcherStubs empty() {
+ return create(ImmutableSet.of(), ImmutableList.of(), ImmutableList.of());
+ }
- private CloudWindmillServiceV1Alpha1Stub
createDispatcherStubForWindmillService(
- HostAndPort endpoint) {
- if (LOCALHOST.equals(endpoint.getHost())) {
- return
CloudWindmillServiceV1Alpha1Grpc.newStub(localhostChannel(endpoint.getPort()));
+ private static DispatcherStubs create(
+ Set<HostAndPort> endpoints,
+ List<CloudWindmillServiceV1Alpha1Stub> windmillServiceStubs,
+ List<CloudWindmillMetadataServiceV1Alpha1Stub>
windmillMetadataServiceStubs) {
+ Preconditions.checkState(
+ endpoints.size() == windmillServiceStubs.size()
+ && windmillServiceStubs.size() ==
windmillMetadataServiceStubs.size(),
+ "Dispatcher should have the same number of endpoints and stubs");
+ return new AutoValue_GrpcDispatcherClient_DispatcherStubs(
+ ImmutableSet.copyOf(endpoints),
+ ImmutableList.copyOf(windmillServiceStubs),
+ ImmutableList.copyOf(windmillMetadataServiceStubs));
}
- // Use an in-process stub if testing.
- return windmillStubFactory.getKind() == WindmillStubFactory.Kind.IN_PROCESS
- ? windmillStubFactory.inProcess().get()
- :
windmillStubFactory.remote().apply(WindmillServiceAddress.create(endpoint));
+ private static DispatcherStubs create(
+ ImmutableSet<HostAndPort> newDispatcherEndpoints, WindmillStubFactory
windmillStubFactory) {
+ ImmutableList.Builder<CloudWindmillServiceV1Alpha1Stub>
windmillServiceStubs =
+ ImmutableList.builder();
+ ImmutableList.Builder<CloudWindmillMetadataServiceV1Alpha1Stub>
windmillMetadataServiceStubs =
+ ImmutableList.builder();
+
+ for (HostAndPort endpoint : newDispatcherEndpoints) {
+ windmillServiceStubs.add(createWindmillServiceStub(endpoint,
windmillStubFactory));
+ windmillMetadataServiceStubs.add(
+ createWindmillMetadataServiceStub(endpoint, windmillStubFactory));
+ }
+
+ return new AutoValue_GrpcDispatcherClient_DispatcherStubs(
+ newDispatcherEndpoints,
+ windmillServiceStubs.build(),
+ windmillMetadataServiceStubs.build());
+ }
+
+ private static CloudWindmillServiceV1Alpha1Stub createWindmillServiceStub(
+ HostAndPort endpoint, WindmillStubFactory windmillStubFactory) {
+ if (LOCALHOST.equals(endpoint.getHost())) {
+ return
CloudWindmillServiceV1Alpha1Grpc.newStub(localhostChannel(endpoint.getPort()));
+ }
+
+ return
windmillStubFactory.createWindmillServiceStub(WindmillServiceAddress.create(endpoint));
+ }
+
+ private static CloudWindmillMetadataServiceV1Alpha1Stub
createWindmillMetadataServiceStub(
+ HostAndPort endpoint, WindmillStubFactory windmillStubFactory) {
+ if (LOCALHOST.equals(endpoint.getHost())) {
+ return CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(
+ localhostChannel(endpoint.getPort()));
+ }
+
+ return windmillStubFactory.createWindmillMetadataServiceStub(
+ WindmillServiceAddress.create(endpoint));
+ }
+
+ private int size() {
+ return dispatcherEndpoints().size();
+ }
+
+ private boolean hasInitializedEndpoints() {
+ return size() > 0;
+ }
+
+ abstract ImmutableSet<HostAndPort> dispatcherEndpoints();
+
+ abstract ImmutableList<CloudWindmillServiceV1Alpha1Stub>
windmillServiceStubs();
+
+ abstract ImmutableList<CloudWindmillMetadataServiceV1Alpha1Stub>
windmillMetadataServiceStubs();
}
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java
index fbed81c1153..858aeb15985 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java
@@ -33,6 +33,8 @@ import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.beam.runners.dataflow.DataflowRunner;
import
org.apache.beam.runners.dataflow.worker.options.StreamingDataflowWorkerOptions;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc.CloudWindmillMetadataServiceV1Alpha1Stub;
import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
@@ -53,6 +55,7 @@ import
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.RemoteWindmillStubFactory;
import
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory;
import
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.StreamingEngineThrottleTimers;
import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver;
@@ -84,14 +87,14 @@ import org.slf4j.LoggerFactory;
})
@SuppressWarnings("nullness") //
TODO(https://github.com/apache/beam/issues/20497
public final class GrpcWindmillServer extends WindmillServerStub {
+ public static final Duration LOCALHOST_MAX_BACKOFF = Duration.millis(500);
+ public static final Duration MAX_BACKOFF = Duration.standardSeconds(30);
+ private static final Duration MIN_BACKOFF = Duration.millis(1);
private static final Logger LOG =
LoggerFactory.getLogger(GrpcWindmillServer.class);
private static final int DEFAULT_LOG_EVERY_N_FAILURES = 20;
- private static final Duration MIN_BACKOFF = Duration.millis(1);
- private static final Duration MAX_BACKOFF = Duration.standardSeconds(30);
private static final int NO_HEALTH_CHECK = -1;
private static final String GRPC_LOCALHOST = "grpc:localhost";
- private final GrpcWindmillStreamFactory windmillStreamFactory;
private final GrpcDispatcherClient dispatcherClient;
private final StreamingDataflowWorkerOptions options;
private final StreamingEngineThrottleTimers throttleTimers;
@@ -100,44 +103,26 @@ public final class GrpcWindmillServer extends
WindmillServerStub {
// If true, then active work refreshes will be sent as KeyedGetDataRequests.
Otherwise, use the
// newer ComputationHeartbeatRequests.
private final boolean sendKeyedGetDataRequests;
- private Consumer<List<ComputationHeartbeatResponse>>
processHeartbeatResponses;
+ private final Consumer<List<ComputationHeartbeatResponse>>
processHeartbeatResponses;
+ private final GrpcWindmillStreamFactory windmillStreamFactory;
private GrpcWindmillServer(
- StreamingDataflowWorkerOptions options, GrpcDispatcherClient
grpcDispatcherClient) {
+ StreamingDataflowWorkerOptions options,
+ GrpcWindmillStreamFactory grpcWindmillStreamFactory,
+ GrpcDispatcherClient grpcDispatcherClient,
+ Consumer<List<Windmill.ComputationHeartbeatResponse>>
processHeartbeatResponses) {
this.options = options;
this.throttleTimers = StreamingEngineThrottleTimers.create();
this.maxBackoff = MAX_BACKOFF;
- this.windmillStreamFactory =
- GrpcWindmillStreamFactory.of(
- JobHeader.newBuilder()
- .setJobId(options.getJobId())
- .setProjectId(options.getProject())
- .setWorkerId(options.getWorkerId())
- .build())
- .setWindmillMessagesBetweenIsReadyChecks(
- options.getWindmillMessagesBetweenIsReadyChecks())
- .setMaxBackOffSupplier(() -> maxBackoff)
- .setLogEveryNStreamFailures(
- options.getWindmillServiceStreamingLogEveryNStreamFailures())
-
.setStreamingRpcBatchLimit(options.getWindmillServiceStreamingRpcBatchLimit())
- .build();
- windmillStreamFactory.scheduleHealthChecks(
- options.getWindmillServiceStreamingRpcHealthCheckPeriodMs());
-
this.dispatcherClient = grpcDispatcherClient;
this.syncApplianceStub = null;
this.sendKeyedGetDataRequests =
!options.isEnableStreamingEngine()
|| !DataflowRunner.hasExperiment(
options, "streaming_engine_send_new_heartbeat_requests");
- this.processHeartbeatResponses = (responses) -> {};
- }
-
- @Override
- public void setProcessHeartbeatResponses(
- Consumer<List<Windmill.ComputationHeartbeatResponse>>
processHeartbeatResponses) {
this.processHeartbeatResponses = processHeartbeatResponses;
- };
+ this.windmillStreamFactory = grpcWindmillStreamFactory;
+ }
private static StreamingDataflowWorkerOptions testOptions(
boolean enableStreamingEngine, List<String> additionalExperiments) {
@@ -162,17 +147,22 @@ public final class GrpcWindmillServer extends
WindmillServerStub {
}
/** Create new instance of {@link GrpcWindmillServer}. */
- public static GrpcWindmillServer create(StreamingDataflowWorkerOptions
workerOptions)
+ public static GrpcWindmillServer create(
+ StreamingDataflowWorkerOptions workerOptions,
+ GrpcWindmillStreamFactory grpcWindmillStreamFactory,
+ Consumer<List<Windmill.ComputationHeartbeatResponse>>
processHeartbeatResponses)
throws IOException {
GrpcWindmillServer grpcWindmillServer =
new GrpcWindmillServer(
workerOptions,
+ grpcWindmillStreamFactory,
GrpcDispatcherClient.create(
- WindmillStubFactory.remoteStubFactory(
+ new RemoteWindmillStubFactory(
workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(),
workerOptions.getGcpCredential(),
- workerOptions.getUseWindmillIsolatedChannels())));
+ workerOptions.getUseWindmillIsolatedChannels())),
+ processHeartbeatResponses);
if (workerOptions.getWindmillServiceEndpoint() != null) {
grpcWindmillServer.configureWindmillServiceEndpoints();
} else if (!workerOptions.isEnableStreamingEngine()
@@ -184,32 +174,62 @@ public final class GrpcWindmillServer extends
WindmillServerStub {
}
@VisibleForTesting
- static GrpcWindmillServer newTestInstance(String name, List<String>
experiments) {
+ static GrpcWindmillServer newTestInstance(
+ String name,
+ List<String> experiments,
+ long clientId,
+ WindmillStubFactory windmillStubFactory) {
ManagedChannel inProcessChannel = inProcessChannel(name);
CloudWindmillServiceV1Alpha1Stub stub =
CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel);
- List<CloudWindmillServiceV1Alpha1Stub> dispatcherStubs =
Lists.newArrayList(stub);
+ CloudWindmillMetadataServiceV1Alpha1Stub metadataStub =
+ CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(inProcessChannel);
+ List<CloudWindmillServiceV1Alpha1Stub> windmillServiceStubs =
Lists.newArrayList(stub);
+ List<CloudWindmillMetadataServiceV1Alpha1Stub>
windmillMetadataServiceStubs =
+ Lists.newArrayList(metadataStub);
+
Set<HostAndPort> dispatcherEndpoints =
Sets.newHashSet(HostAndPort.fromHost(name));
GrpcDispatcherClient dispatcherClient =
GrpcDispatcherClient.forTesting(
- WindmillStubFactory.inProcessStubFactory(name, unused ->
inProcessChannel),
- dispatcherStubs,
+ windmillStubFactory,
+ windmillServiceStubs,
+ windmillMetadataServiceStubs,
dispatcherEndpoints);
- return new GrpcWindmillServer(
- testOptions(/* enableStreamingEngine= */ true, experiments),
dispatcherClient);
+
+ StreamingDataflowWorkerOptions testOptions =
+ testOptions(/* enableStreamingEngine= */ true, experiments);
+ GrpcWindmillStreamFactory windmillStreamFactory =
+ GrpcWindmillStreamFactory.of(createJobHeader(testOptions,
clientId)).build();
+ windmillStreamFactory.scheduleHealthChecks(
+ testOptions.getWindmillServiceStreamingRpcHealthCheckPeriodMs());
+ return new GrpcWindmillServer(testOptions, windmillStreamFactory,
dispatcherClient, noop -> {});
}
@VisibleForTesting
- static GrpcWindmillServer newApplianceTestInstance(Channel channel) {
+ static GrpcWindmillServer newApplianceTestInstance(
+ Channel channel, WindmillStubFactory windmillStubFactory) {
+ StreamingDataflowWorkerOptions options =
+ testOptions(/* enableStreamingEngine= */ false, new ArrayList<>());
GrpcWindmillServer testServer =
new GrpcWindmillServer(
- testOptions(/* enableStreamingEngine= */ false, new ArrayList<>()),
+ options,
+ GrpcWindmillStreamFactory.of(createJobHeader(options, 1)).build(),
// No-op, Appliance does not use Dispatcher to call Streaming
Engine.
-
GrpcDispatcherClient.create(WindmillStubFactory.inProcessStubFactory("test")));
+ GrpcDispatcherClient.create(windmillStubFactory),
+ noop -> {});
testServer.syncApplianceStub =
createWindmillApplianceStubWithDeadlineInterceptor(channel);
return testServer;
}
+ private static JobHeader createJobHeader(StreamingDataflowWorkerOptions
options, long clientId) {
+ return Windmill.JobHeader.newBuilder()
+ .setJobId(options.getJobId())
+ .setProjectId(options.getProject())
+ .setWorkerId(options.getWorkerId())
+ .setClientId(clientId)
+ .build();
+ }
+
private static WindmillApplianceGrpc.WindmillApplianceBlockingStub
createWindmillApplianceStubWithDeadlineInterceptor(Channel channel) {
return WindmillApplianceGrpc.newBlockingStub(channel)
@@ -249,7 +269,7 @@ public final class GrpcWindmillServer extends
WindmillServerStub {
@Override
public boolean isReady() {
- return dispatcherClient.isReady();
+ return dispatcherClient.hasInitializedEndpoints();
}
private synchronized void initializeLocalHost(int port) {
@@ -329,7 +349,7 @@ public final class GrpcWindmillServer extends
WindmillServerStub {
@Override
public GetWorkStream getWorkStream(GetWorkRequest request, WorkItemReceiver
receiver) {
return windmillStreamFactory.createGetWorkStream(
- dispatcherClient.getDispatcherStub(),
+ dispatcherClient.getWindmillServiceStub(),
GetWorkRequest.newBuilder(request)
.setJobId(options.getJobId())
.setProjectId(options.getProject())
@@ -342,7 +362,7 @@ public final class GrpcWindmillServer extends
WindmillServerStub {
@Override
public GetDataStream getDataStream() {
return windmillStreamFactory.createGetDataStream(
- dispatcherClient.getDispatcherStub(),
+ dispatcherClient.getWindmillServiceStub(),
throttleTimers.getDataThrottleTimer(),
sendKeyedGetDataRequests,
this.processHeartbeatResponses);
@@ -351,7 +371,7 @@ public final class GrpcWindmillServer extends
WindmillServerStub {
@Override
public CommitWorkStream commitWorkStream() {
return windmillStreamFactory.createCommitWorkStream(
- dispatcherClient.getDispatcherStub(),
throttleTimers.commitWorkThrottleTimer());
+ dispatcherClient.getWindmillServiceStub(),
throttleTimers.commitWorkThrottleTimer());
}
@Override
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
index 7dc43e791e3..8696c464a0f 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
@@ -32,6 +32,7 @@ import java.util.function.Consumer;
import java.util.function.Supplier;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc.CloudWindmillMetadataServiceV1Alpha1Stub;
import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
@@ -49,6 +50,7 @@ import
org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.sdk.util.FluentBackoff;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.AbstractStub;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
import org.joda.time.Duration;
import org.joda.time.Instant;
@@ -109,8 +111,7 @@ public class GrpcWindmillStreamFactory implements
StatusDataProvider {
.setStreamingRpcBatchLimit(DEFAULT_STREAMING_RPC_BATCH_LIMIT);
}
- private static CloudWindmillServiceV1Alpha1Stub withDeadline(
- CloudWindmillServiceV1Alpha1Stub stub) {
+ private static <T extends AbstractStub<T>> T withDefaultDeadline(T stub) {
// Deadlines are absolute points in time, so generate a new one everytime
this function is
// called.
return stub.withDeadlineAfter(
@@ -123,7 +124,7 @@ public class GrpcWindmillStreamFactory implements
StatusDataProvider {
ThrottleTimer getWorkThrottleTimer,
WorkItemReceiver processWorkItem) {
return GrpcGetWorkStream.create(
- responseObserver -> withDeadline(stub).getWorkStream(responseObserver),
+ responseObserver ->
withDefaultDeadline(stub).getWorkStream(responseObserver),
request,
grpcBackOff.get(),
newStreamObserverFactory(),
@@ -141,7 +142,7 @@ public class GrpcWindmillStreamFactory implements
StatusDataProvider {
Supplier<CommitWorkStream> commitWorkStream,
WorkItemProcessor workItemProcessor) {
return GrpcDirectGetWorkStream.create(
- responseObserver -> withDeadline(stub).getWorkStream(responseObserver),
+ responseObserver ->
withDefaultDeadline(stub).getWorkStream(responseObserver),
request,
grpcBackOff.get(),
newStreamObserverFactory(),
@@ -159,7 +160,7 @@ public class GrpcWindmillStreamFactory implements
StatusDataProvider {
boolean sendKeyedGetDataRequests,
Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses) {
return GrpcGetDataStream.create(
- responseObserver -> withDeadline(stub).getDataStream(responseObserver),
+ responseObserver ->
withDefaultDeadline(stub).getDataStream(responseObserver),
grpcBackOff.get(),
newStreamObserverFactory(),
streamRegistry,
@@ -180,7 +181,7 @@ public class GrpcWindmillStreamFactory implements
StatusDataProvider {
public CommitWorkStream createCommitWorkStream(
CloudWindmillServiceV1Alpha1Stub stub, ThrottleTimer
commitWorkThrottleTimer) {
return GrpcCommitWorkStream.create(
- responseObserver ->
withDeadline(stub).commitWorkStream(responseObserver),
+ responseObserver ->
withDefaultDeadline(stub).commitWorkStream(responseObserver),
grpcBackOff.get(),
newStreamObserverFactory(),
streamRegistry,
@@ -192,11 +193,11 @@ public class GrpcWindmillStreamFactory implements
StatusDataProvider {
}
public GetWorkerMetadataStream createGetWorkerMetadataStream(
- CloudWindmillServiceV1Alpha1Stub stub,
+ CloudWindmillMetadataServiceV1Alpha1Stub stub,
ThrottleTimer getWorkerMetadataThrottleTimer,
Consumer<WindmillEndpoints> onNewWindmillEndpoints) {
return GrpcGetWorkerMetadataStream.create(
- responseObserver ->
withDeadline(stub).getWorkerMetadataStream(responseObserver),
+ responseObserver ->
withDefaultDeadline(stub).getWorkerMetadata(responseObserver),
grpcBackOff.get(),
newStreamObserverFactory(),
streamRegistry,
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java
index 01783f6aa4d..80c957996ab 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java
@@ -131,7 +131,7 @@ public final class StreamingEngineClient {
Suppliers.memoize(
() ->
streamFactory.createGetWorkerMetadataStream(
- dispatcherClient.getDispatcherStub(),
+ dispatcherClient.getWindmillMetadataServiceStub(),
getWorkerMetadataThrottleTimer,
endpoints ->
// Run this on a separate thread than the grpc stream
thread.
@@ -267,7 +267,7 @@ public final class StreamingEngineClient {
getWorkBudgetRefresher.requestBudgetRefresh();
}
- public final ImmutableList<Long> getAndResetThrottleTimes() {
+ public ImmutableList<Long> getAndResetThrottleTimes() {
StreamingEngineConnectionState currentConnections = connections.get();
ImmutableList<Long> keyedWorkStreamThrottleTimes =
@@ -375,21 +375,10 @@ public final class StreamingEngineClient {
}
private CloudWindmillServiceV1Alpha1Stub createWindmillStub(Endpoint
endpoint) {
- switch (stubFactory.getKind()) {
- // This is only used in tests.
- case IN_PROCESS:
- return stubFactory.inProcess().get();
- // Create stub for direct_endpoint or just default to Dispatcher stub.
- case REMOTE:
- return endpoint
- .directEndpoint()
- .map(stubFactory.remote())
- .orElseGet(dispatcherClient::getDispatcherStub);
- // Should never be called, this switch statement is exhaustive.
- default:
- throw new UnsupportedOperationException(
- "Only remote or in-process stub factories are available.");
- }
+ return endpoint
+ .directEndpoint()
+ .map(stubFactory::createWindmillServiceStub)
+ .orElseGet(dispatcherClient::getWindmillServiceStub);
}
private static class StreamingEngineClientException extends
IllegalStateException {
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/RemoteWindmillStubFactory.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/RemoteWindmillStubFactory.java
new file mode 100644
index 00000000000..9978b74c7aa
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/RemoteWindmillStubFactory.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs;
+
+import static
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory.remoteChannel;
+
+import com.google.auth.Credentials;
+import java.util.function.Supplier;
+import javax.annotation.concurrent.ThreadSafe;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc.CloudWindmillMetadataServiceV1Alpha1Stub;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.auth.VendoredCredentialsAdapter;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.auth.MoreCallCredentials;
+
+/** Creates remote stubs to talk to Streaming Engine. */
+@Internal
+@ThreadSafe
+public final class RemoteWindmillStubFactory implements WindmillStubFactory {
+ private final int rpcChannelTimeoutSec;
+ private final Credentials gcpCredentials;
+ private final boolean useIsolatedChannels;
+
+ public RemoteWindmillStubFactory(
+ int rpcChannelTimeoutSec, Credentials gcpCredentials, boolean
useIsolatedChannels) {
+ this.rpcChannelTimeoutSec = rpcChannelTimeoutSec;
+ this.gcpCredentials = gcpCredentials;
+ this.useIsolatedChannels = useIsolatedChannels;
+ }
+
+ @Override
+ public CloudWindmillServiceV1Alpha1Stub createWindmillServiceStub(
+ WindmillServiceAddress serviceAddress) {
+ CloudWindmillServiceV1Alpha1Stub windmillServiceStub =
+
CloudWindmillServiceV1Alpha1Grpc.newStub(createChannel(serviceAddress));
+ return serviceAddress.getKind() !=
WindmillServiceAddress.Kind.AUTHENTICATED_GCP_SERVICE_ADDRESS
+ ? windmillServiceStub.withCallCredentials(
+ MoreCallCredentials.from(new
VendoredCredentialsAdapter(gcpCredentials)))
+ : windmillServiceStub;
+ }
+
+ @Override
+ public CloudWindmillMetadataServiceV1Alpha1Stub
createWindmillMetadataServiceStub(
+ WindmillServiceAddress serviceAddress) {
+ return
CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(createChannel(serviceAddress))
+ .withCallCredentials(
+ MoreCallCredentials.from(new
VendoredCredentialsAdapter(gcpCredentials)));
+ }
+
+ private ManagedChannel createChannel(WindmillServiceAddress serviceAddress) {
+ Supplier<ManagedChannel> channelFactory =
+ () -> remoteChannel(serviceAddress, rpcChannelTimeoutSec);
+ // IsolationChannel will create and manage separate RPC channels to the
same serviceAddress via
+ // calling the channelFactory, else just directly return the RPC channel.
+ return useIsolatedChannels ? IsolationChannel.create(channelFactory) :
channelFactory.get();
+ }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java
index 68c82c5907b..cf31436d364 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java
@@ -22,8 +22,11 @@ import java.net.InetSocketAddress;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLException;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress.AuthenticatedGcpServiceAddress;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Channel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ForwardingChannelBuilder2;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.alts.AltsChannelBuilder;
import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.netty.GrpcSslContexts;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.netty.NegotiationType;
@@ -57,12 +60,28 @@ public final class WindmillChannelFactory {
return remoteChannel(
windmillServiceAddress.gcpServiceAddress(),
windmillServiceRpcChannelTimeoutSec);
- // switch is exhaustive will never happen.
+ case AUTHENTICATED_GCP_SERVICE_ADDRESS:
+ return remoteDirectChannel(
+ windmillServiceAddress.authenticatedGcpServiceAddress(),
+ windmillServiceRpcChannelTimeoutSec);
+ // The switch is exhaustive, so the default case should never be reached.
default:
throw new UnsupportedOperationException(
- "Only IPV6 and GCP_SERVICE_ADDRESS are supported
WindmillServiceAddresses.");
+ "Only IPV6, GCP_SERVICE_ADDRESS, AUTHENTICATED_GCP_SERVICE_ADDRESS
are supported WindmillServiceAddresses.");
}
}
+ static ManagedChannel remoteDirectChannel(
+ AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress,
+ int windmillServiceRpcChannelTimeoutSec) {
+ return withDefaultChannelOptions(
+ AltsChannelBuilder.forAddress(
+
authenticatedGcpServiceAddress.gcpServiceAddress().getHost(),
+
authenticatedGcpServiceAddress.gcpServiceAddress().getPort())
+
.overrideAuthority(authenticatedGcpServiceAddress.authenticatingService()),
+ windmillServiceRpcChannelTimeoutSec)
+ .build();
+ }
+
public static ManagedChannel remoteChannel(
HostAndPort endpoint, int windmillServiceRpcChannelTimeoutSec) {
try {
@@ -100,6 +119,17 @@ public final class WindmillChannelFactory {
private static ManagedChannel createRemoteChannel(
NettyChannelBuilder channelBuilder, int
windmillServiceRpcChannelTimeoutSec)
throws SSLException {
+ return withDefaultChannelOptions(channelBuilder,
windmillServiceRpcChannelTimeoutSec)
+ .flowControlWindow(10 * 1024 * 1024)
+ .negotiationType(NegotiationType.TLS)
+ // Set ciphers(null) to not use GCM, which is disabled for Dataflow
+ // due to it being horribly slow.
+ .sslContext(GrpcSslContexts.forClient().ciphers(null).build())
+ .build();
+ }
+
+ private static <T extends ForwardingChannelBuilder2<T>> T
withDefaultChannelOptions(
+ T channelBuilder, int windmillServiceRpcChannelTimeoutSec) {
if (windmillServiceRpcChannelTimeoutSec > 0) {
channelBuilder
.keepAliveTime(windmillServiceRpcChannelTimeoutSec, TimeUnit.SECONDS)
@@ -108,14 +138,8 @@ public final class WindmillChannelFactory {
}
return channelBuilder
- .flowControlWindow(10 * 1024 * 1024)
.maxInboundMessageSize(Integer.MAX_VALUE)
- .maxInboundMetadataSize(1024 * 1024)
- .negotiationType(NegotiationType.TLS)
- // Set ciphers(null) to not use GCM, which is disabled for Dataflow
- // due to it being horribly slow.
- .sslContext(GrpcSslContexts.forClient().ciphers(null).build())
- .build();
+ .maxInboundMetadataSize(1024 * 1024);
}
public static class WindmillChannelCreationException extends
IllegalStateException {
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillStubFactory.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillStubFactory.java
index 7ad46a21c08..e5e523445a6 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillStubFactory.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillStubFactory.java
@@ -17,62 +17,16 @@
*/
package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs;
-import static
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory.remoteChannel;
-
-import com.google.auth.Credentials;
-import com.google.auto.value.AutoOneOf;
-import java.util.function.Function;
-import java.util.function.Supplier;
-import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc.CloudWindmillMetadataServiceV1Alpha1Stub;
import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress;
-import
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.auth.VendoredCredentialsAdapter;
-import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
-import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.auth.MoreCallCredentials;
-
-/**
- * Used to create stubs to talk to Streaming Engine. Stubs are either
in-process for testing, or
- * remote.
- */
-@AutoOneOf(WindmillStubFactory.Kind.class)
-public abstract class WindmillStubFactory {
-
- public static WindmillStubFactory inProcessStubFactory(
- String testName, Function<String, ManagedChannel> channelFactory) {
- return AutoOneOf_WindmillStubFactory.inProcess(
- () ->
CloudWindmillServiceV1Alpha1Grpc.newStub(channelFactory.apply(testName)));
- }
-
- public static WindmillStubFactory inProcessStubFactory(String testName) {
- return AutoOneOf_WindmillStubFactory.inProcess(
- () ->
- CloudWindmillServiceV1Alpha1Grpc.newStub(
- WindmillChannelFactory.inProcessChannel(testName)));
- }
-
- public static WindmillStubFactory remoteStubFactory(
- int rpcChannelTimeoutSec, Credentials gcpCredentials, boolean
useIsolatedChannels) {
- return AutoOneOf_WindmillStubFactory.remote(
- directEndpoint -> {
- Supplier<ManagedChannel> channelSupplier =
- () -> remoteChannel(directEndpoint, rpcChannelTimeoutSec);
- return CloudWindmillServiceV1Alpha1Grpc.newStub(
- useIsolatedChannels
- ? IsolationChannel.create(channelSupplier)
- : channelSupplier.get())
- .withCallCredentials(
- MoreCallCredentials.from(new
VendoredCredentialsAdapter(gcpCredentials)));
- });
- }
-
- public abstract Kind getKind();
-
- public abstract Supplier<CloudWindmillServiceV1Alpha1Stub> inProcess();
+import org.apache.beam.sdk.annotations.Internal;
- public abstract Function<WindmillServiceAddress,
CloudWindmillServiceV1Alpha1Stub> remote();
+/** Used to create stubs to talk to Streaming Engine. */
+@Internal
+public interface WindmillStubFactory {
+ CloudWindmillServiceV1Alpha1Stub
createWindmillServiceStub(WindmillServiceAddress serviceAddress);
- public enum Kind {
- IN_PROCESS,
- REMOTE
- }
+ CloudWindmillMetadataServiceV1Alpha1Stub createWindmillMetadataServiceStub(
+ WindmillServiceAddress serviceAddress);
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
index 2cfec6d3139..069fcac07c8 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
@@ -32,6 +32,7 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
@@ -42,6 +43,8 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Function;
+import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
+import
org.apache.beam.runners.dataflow.worker.streaming.WorkHeartbeatResponseProcessor;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkResponse;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationCommitWorkRequest;
@@ -70,7 +73,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** An in-memory Windmill server that offers provided work and data. */
-class FakeWindmillServer extends WindmillServerStub {
+final class FakeWindmillServer extends WindmillServerStub {
private static final Logger LOG =
LoggerFactory.getLogger(FakeWindmillServer.class);
private final ResponseQueue<Windmill.GetWorkRequest,
Windmill.GetWorkResponse> workToOffer;
private final ResponseQueue<GetDataRequest, GetDataResponse> dataToOffer;
@@ -86,9 +89,11 @@ class FakeWindmillServer extends WindmillServerStub {
private final List<Windmill.GetDataRequest> getDataRequests = new
ArrayList<>();
private boolean isReady = true;
private boolean dropStreamingCommits = false;
- private Consumer<List<Windmill.ComputationHeartbeatResponse>>
processHeartbeatResponses;
+ private final Consumer<List<Windmill.ComputationHeartbeatResponse>>
processHeartbeatResponses;
- public FakeWindmillServer(ErrorCollector errorCollector) {
+ public FakeWindmillServer(
+ ErrorCollector errorCollector,
+ Function<String, Optional<ComputationState>> computationStateFetcher) {
workToOffer =
new ResponseQueue<Windmill.GetWorkRequest, Windmill.GetWorkResponse>()
.returnByDefault(Windmill.GetWorkResponse.getDefaultInstance());
@@ -106,13 +111,7 @@ class FakeWindmillServer extends WindmillServerStub {
this.errorCollector = errorCollector;
statsReceived = new ArrayList<>();
droppedStreamingCommits = new ConcurrentHashMap<>();
- processHeartbeatResponses = (responses) -> {};
- }
-
- @Override
- public void setProcessHeartbeatResponses(
- Consumer<List<Windmill.ComputationHeartbeatResponse>>
processHeartbeatResponses) {
- this.processHeartbeatResponses = processHeartbeatResponses;
+ this.processHeartbeatResponses = new
WorkHeartbeatResponseProcessor(computationStateFetcher);
}
public void setDropStreamingCommits(boolean dropStreamingCommits) {
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index abd2cbbac6e..df806fcb978 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -70,8 +70,10 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
+import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
@@ -179,6 +181,7 @@ import org.hamcrest.Matchers;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Assert;
+import org.junit.Before;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
@@ -261,6 +264,7 @@ public class StreamingDataflowWorkerTest {
return idGenerator.getAndIncrement();
}
};
+
@Rule public transient Timeout globalTimeout = Timeout.seconds(600);
@Rule public BlockingFn blockingFn = new BlockingFn();
@Rule public TestRule restoreMDC = new RestoreDataflowLoggingMDC();
@@ -268,6 +272,10 @@ public class StreamingDataflowWorkerTest {
WorkUnitClient mockWorkUnitClient = mock(WorkUnitClient.class);
HotKeyLogger hotKeyLogger = mock(HotKeyLogger.class);
+ private final ConcurrentMap<String, ComputationState> computationMap = new
ConcurrentHashMap<>();
+ private final FakeWindmillServer server =
+ new FakeWindmillServer(errorCollector, id ->
Optional.ofNullable(computationMap.get(id)));
+
public StreamingDataflowWorkerTest(Boolean streamingEngine) {
this.streamingEngine = streamingEngine;
}
@@ -286,6 +294,12 @@ public class StreamingDataflowWorkerTest {
return null;
}
+ @Before
+ public void setUp() {
+ computationMap.clear();
+ server.clearCommitsReceived();
+ }
+
static Work createMockWork(long workToken) {
return createMockWork(workToken, work -> {});
}
@@ -760,8 +774,7 @@ public class StreamingDataflowWorkerTest {
return output.toByteString();
}
- private StreamingDataflowWorkerOptions createTestingPipelineOptions(
- FakeWindmillServer server, String... args) {
+ private StreamingDataflowWorkerOptions
createTestingPipelineOptions(String... args) {
List<String> argsList = Lists.newArrayList(args);
if (streamingEngine) {
argsList.add("--experiments=enable_streaming_engine");
@@ -771,8 +784,9 @@ public class StreamingDataflowWorkerTest {
.as(StreamingDataflowWorkerOptions.class);
options.setAppName("StreamingWorkerHarnessTest");
options.setJobId("test_job_id");
+ options.setProject("test_project");
+ options.setWorkerId("test_worker");
options.setStreaming(true);
- options.setWindmillServerStub(server);
options.setActiveWorkRefreshPeriodMillis(0);
return options;
}
@@ -782,10 +796,12 @@ public class StreamingDataflowWorkerTest {
StreamingDataflowWorkerOptions options,
boolean publishCounters,
Supplier<Instant> clock,
- Function<String, ScheduledExecutorService> executorSupplier)
- throws Exception {
+ Function<String, ScheduledExecutorService> executorSupplier) {
StreamingDataflowWorker worker =
new StreamingDataflowWorker(
+ server,
+ new Random().nextLong(),
+ computationMap,
Collections.singletonList(defaultMapTask(instructions)),
IntrinsicMapTaskExecutorFactory.defaultFactory(),
mockWorkUnitClient,
@@ -802,8 +818,7 @@ public class StreamingDataflowWorkerTest {
private StreamingDataflowWorker makeWorker(
List<ParallelInstruction> instructions,
StreamingDataflowWorkerOptions options,
- boolean publishCounters)
- throws Exception {
+ boolean publishCounters) {
return makeWorker(
instructions,
options,
@@ -819,9 +834,8 @@ public class StreamingDataflowWorkerTest {
makeSourceInstruction(StringUtf8Coder.of()),
makeSinkInstruction(StringUtf8Coder.of(), 0));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
- StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
- StreamingDataflowWorker worker = makeWorker(instructions, options, true /*
publishCounters */);
+ StreamingDataflowWorker worker =
+ makeWorker(instructions, createTestingPipelineOptions(), true /*
publishCounters */);
worker.start();
final int numIters = 2000;
@@ -848,9 +862,7 @@ public class StreamingDataflowWorkerTest {
makeSourceInstruction(StringUtf8Coder.of()),
makeSinkInstruction(StringUtf8Coder.of(), 0));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
server.setIsReady(false);
-
StreamingConfigTask streamingConfig = new StreamingConfigTask();
streamingConfig.setStreamingComputationConfigs(
ImmutableList.of(makeDefaultStreamingComputationConfig(instructions)));
@@ -859,9 +871,9 @@ public class StreamingDataflowWorkerTest {
workItem.setStreamingConfigTask(streamingConfig);
when(mockWorkUnitClient.getGlobalStreamingConfigWorkItem()).thenReturn(Optional.of(workItem));
- StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
+ StreamingDataflowWorkerOptions options = createTestingPipelineOptions();
options.setWindmillServiceCommitThreads(numCommitThreads);
StreamingDataflowWorker worker = makeWorker(instructions, options, true /*
publishCounters */);
worker.start();
final int numIters = 2000;
@@ -900,7 +911,6 @@ public class StreamingDataflowWorkerTest {
makeSourceInstruction(KvCoder.of(StringUtf8Coder.of(),
StringUtf8Coder.of())),
makeSinkInstruction(KvCoder.of(StringUtf8Coder.of(),
StringUtf8Coder.of()), 0));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
server.setIsReady(false);
StreamingConfigTask streamingConfig = new StreamingConfigTask();
@@ -911,9 +921,11 @@ public class StreamingDataflowWorkerTest {
workItem.setStreamingConfigTask(streamingConfig);
when(mockWorkUnitClient.getGlobalStreamingConfigWorkItem()).thenReturn(Optional.of(workItem));
- StreamingDataflowWorkerOptions options =
- createTestingPipelineOptions(server, "--hotKeyLoggingEnabled=true");
- StreamingDataflowWorker worker = makeWorker(instructions, options, true /*
publishCounters */);
+ StreamingDataflowWorker worker =
+ makeWorker(
+ instructions,
+ createTestingPipelineOptions("--hotKeyLoggingEnabled=true"),
+ true /* publishCounters */);
worker.start();
final int numIters = 2000;
@@ -938,7 +950,6 @@ public class StreamingDataflowWorkerTest {
makeSourceInstruction(KvCoder.of(StringUtf8Coder.of(),
StringUtf8Coder.of())),
makeSinkInstruction(KvCoder.of(StringUtf8Coder.of(),
StringUtf8Coder.of()), 0));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
server.setIsReady(false);
StreamingConfigTask streamingConfig = new StreamingConfigTask();
@@ -949,8 +960,8 @@ public class StreamingDataflowWorkerTest {
workItem.setStreamingConfigTask(streamingConfig);
when(mockWorkUnitClient.getGlobalStreamingConfigWorkItem()).thenReturn(Optional.of(workItem));
- StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
- StreamingDataflowWorker worker = makeWorker(instructions, options, true /*
publishCounters */);
+ StreamingDataflowWorker worker =
+ makeWorker(instructions, createTestingPipelineOptions(), true /*
publishCounters */);
worker.start();
final int numIters = 2000;
@@ -975,10 +986,8 @@ public class StreamingDataflowWorkerTest {
makeDoFnInstruction(blockingFn, 0, StringUtf8Coder.of()),
makeSinkInstruction(StringUtf8Coder.of(), 0));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
- StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
-
- StreamingDataflowWorker worker = makeWorker(instructions, options, true /*
publishCounters */);
+ StreamingDataflowWorker worker =
+ makeWorker(instructions, createTestingPipelineOptions(), true /*
publishCounters */);
worker.start();
for (int i = 0; i < numIters; ++i) {
@@ -1099,8 +1108,7 @@ public class StreamingDataflowWorkerTest {
makeDoFnInstruction(blockingFn, 0, StringUtf8Coder.of()),
makeSinkInstruction(StringUtf8Coder.of(), 0));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
- StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
+ StreamingDataflowWorkerOptions options = createTestingPipelineOptions();
options.setNumberOfWorkerHarnessThreads(expectedNumberOfThreads);
StreamingDataflowWorker worker = makeWorker(instructions, options, true /*
publishCounters */);
@@ -1143,13 +1151,12 @@ public class StreamingDataflowWorkerTest {
makeDoFnInstruction(new KeyTokenInvalidFn(), 0, kvCoder),
makeSinkInstruction(kvCoder, 1));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
server
.whenGetWorkCalled()
.thenReturn(makeInput(0, 0, DEFAULT_KEY_STRING, DEFAULT_SHARDING_KEY));
StreamingDataflowWorker worker =
- makeWorker(instructions, createTestingPipelineOptions(server), true /*
publishCounters */);
+ makeWorker(instructions, createTestingPipelineOptions(), true /*
publishCounters */);
worker.start();
server.waitForEmptyWorkQueue();
@@ -1177,11 +1184,10 @@ public class StreamingDataflowWorkerTest {
makeDoFnInstruction(new LargeCommitFn(), 0, kvCoder),
makeSinkInstruction(kvCoder, 1));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
server.setExpectedExceptionCount(1);
StreamingDataflowWorker worker =
- makeWorker(instructions, createTestingPipelineOptions(server), true /*
publishCounters */);
+ makeWorker(instructions, createTestingPipelineOptions(), true /*
publishCounters */);
worker.setMaxWorkItemCommitBytes(1000);
worker.start();
@@ -1245,7 +1251,6 @@ public class StreamingDataflowWorkerTest {
makeDoFnInstruction(new ChangeKeysFn(), 0, kvCoder),
makeSinkInstruction(kvCoder, 1));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
for (int i = 0; i < 2; i++) {
server
.whenGetWorkCalled()
@@ -1261,7 +1266,7 @@ public class StreamingDataflowWorkerTest {
}
StreamingDataflowWorker worker =
- makeWorker(instructions, createTestingPipelineOptions(server), true /*
publishCounters */);
+ makeWorker(instructions, createTestingPipelineOptions(), true /*
publishCounters */);
worker.start();
Map<Long, Windmill.WorkItemCommitRequest> result =
server.waitForAndGetCommits(4);
@@ -1302,7 +1307,6 @@ public class StreamingDataflowWorkerTest {
makeDoFnInstruction(new TestExceptionFn(), 0,
StringUtf8Coder.of()),
makeSinkInstruction(StringUtf8Coder.of(), 1));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
server.setExpectedExceptionCount(1);
String keyString = keyStringForIndex(0);
server
@@ -1337,7 +1341,7 @@ public class StreamingDataflowWorkerTest {
Collections.singletonList(DEFAULT_WINDOW))));
StreamingDataflowWorker worker =
- makeWorker(instructions, createTestingPipelineOptions(server), true /*
publishCounters */);
+ makeWorker(instructions, createTestingPipelineOptions(), true /*
publishCounters */);
worker.start();
server.waitForEmptyWorkQueue();
@@ -1433,8 +1437,6 @@ public class StreamingDataflowWorkerTest {
addWindowsInstruction,
makeSinkInstruction(StringUtf8Coder.of(), 1));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
-
int timestamp1 = 0;
int timestamp2 = 1000000;
@@ -1444,7 +1446,7 @@ public class StreamingDataflowWorkerTest {
.thenReturn(makeInput(timestamp2, timestamp2));
StreamingDataflowWorker worker =
- makeWorker(instructions, createTestingPipelineOptions(server), false
/* publishCounters */);
+ makeWorker(instructions, createTestingPipelineOptions(), false /*
publishCounters */);
worker.start();
Map<Long, Windmill.WorkItemCommitRequest> result =
server.waitForAndGetCommits(2);
@@ -1561,10 +1563,8 @@ public class StreamingDataflowWorkerTest {
mergeWindowsInstruction,
makeSinkInstruction(groupedCoder, 1));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
-
StreamingDataflowWorker worker =
- makeWorker(instructions, createTestingPipelineOptions(server), false
/* publishCounters */);
+ makeWorker(instructions, createTestingPipelineOptions(), false /*
publishCounters */);
Map<String, String> nameMap = new HashMap<>();
nameMap.put("MergeWindowsStep", "MergeWindows");
worker.addStateNameMappings(nameMap);
@@ -1850,10 +1850,8 @@ public class StreamingDataflowWorkerTest {
makeDoFnInstruction(new PassthroughDoFn(), 1, groupedCoder),
makeSinkInstruction(groupedCoder, 2));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
-
StreamingDataflowWorker worker =
- makeWorker(instructions, createTestingPipelineOptions(server), false
/* publishCounters */);
+ makeWorker(instructions, createTestingPipelineOptions(), false /*
publishCounters */);
Map<String, String> nameMap = new HashMap<>();
nameMap.put("MergeWindowsStep", "MergeWindows");
worker.addStateNameMappings(nameMap);
@@ -2148,10 +2146,8 @@ public class StreamingDataflowWorkerTest {
mergeWindowsInstruction,
makeSinkInstruction(groupedCoder, 1));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
-
StreamingDataflowWorker worker =
- makeWorker(instructions, createTestingPipelineOptions(server), false
/* publishCounters */);
+ makeWorker(instructions, createTestingPipelineOptions(), false /*
publishCounters */);
Map<String, String> nameMap = new HashMap<>();
nameMap.put("MergeWindowsStep", "MergeWindows");
worker.addStateNameMappings(nameMap);
@@ -2172,8 +2168,7 @@ public class StreamingDataflowWorkerTest {
}
@Test
- public void testMergeSessionWindows() throws Exception {
- // Test a single late window.
+ public void testMergeSessionWindows_singleLateWindow() throws Exception {
runMergeSessionsActions(
Collections.singletonList(
new Action(
@@ -2183,7 +2178,10 @@ public class StreamingDataflowWorkerTest {
buildHold("/gAAAAAAAAAsK/+uhold", -1, true),
buildHold("/gAAAAAAAAAsK/+uextra", -1, true))
.withTimers(buildWatermarkTimer("/s/gAAAAAAAAAsK/+0",
3600010))));
+ }
+ @Test
+ public void testMergeSessionWindows() throws Exception {
// Test the behavior with an:
// - on time window that is triggered due to watermark advancement
// - a late window that is triggered immediately due to count
@@ -2298,11 +2296,10 @@ public class StreamingDataflowWorkerTest {
List<Integer> finalizeTracker = Lists.newArrayList();
TestCountingSource.setFinalizeTracker(finalizeTracker);
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
StreamingDataflowWorker worker =
makeWorker(
makeUnboundedSourcePipeline(),
- createTestingPipelineOptions(server),
+ createTestingPipelineOptions(),
false /* publishCounters */);
worker.start();
@@ -2465,11 +2462,10 @@ public class StreamingDataflowWorkerTest {
List<Integer> finalizeTracker = Lists.newArrayList();
TestCountingSource.setFinalizeTracker(finalizeTracker);
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
StreamingDataflowWorker worker =
makeWorker(
makeUnboundedSourcePipeline(),
- createTestingPipelineOptions(server),
+ createTestingPipelineOptions(),
true /* publishCounters */);
worker.start();
@@ -2579,8 +2575,7 @@ public class StreamingDataflowWorkerTest {
List<Integer> finalizeTracker = Lists.newArrayList();
TestCountingSource.setFinalizeTracker(finalizeTracker);
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
- StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
+ StreamingDataflowWorkerOptions options = createTestingPipelineOptions();
options.setWorkerCacheMb(0); // Disable state cache so it doesn't detect
retry.
StreamingDataflowWorker worker =
makeWorker(makeUnboundedSourcePipeline(), options, false /*
publishCounters */);
@@ -3094,10 +3089,10 @@ public class StreamingDataflowWorkerTest {
// 25. Read state as 42
// 26. Take counter reader checkpoint 2
// 27. CommitWork[2] (message 0:2, checkpoint 2)
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
+
server.setExpectedExceptionCount(2);
- DataflowPipelineOptions options = createTestingPipelineOptions(server);
+ DataflowPipelineOptions options = createTestingPipelineOptions();
options.setNumWorkers(1);
DataflowPipelineDebugOptions debugOptions =
options.as(DataflowPipelineDebugOptions.class);
debugOptions.setUnboundedReaderMaxElements(1);
@@ -3281,9 +3276,8 @@ public class StreamingDataflowWorkerTest {
makeDoFnInstruction(new FanoutFn(), 0, StringUtf8Coder.of()),
makeSinkInstruction(StringUtf8Coder.of(), 0));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
- StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
- StreamingDataflowWorker worker = makeWorker(instructions, options, true /*
publishCounters */);
+ StreamingDataflowWorker worker =
+ makeWorker(instructions, createTestingPipelineOptions(), true /*
publishCounters */);
worker.start();
server.whenGetWorkCalled().thenReturn(makeInput(0,
TimeUnit.MILLISECONDS.toMicros(0)));
@@ -3300,8 +3294,7 @@ public class StreamingDataflowWorkerTest {
makeDoFnInstruction(new SlowDoFn(), 0, StringUtf8Coder.of()),
makeSinkInstruction(StringUtf8Coder.of(), 0));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
- StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
+ StreamingDataflowWorkerOptions options = createTestingPipelineOptions();
options.setActiveWorkRefreshPeriodMillis(100);
StreamingDataflowWorker worker = makeWorker(instructions, options, true /*
publishCounters */);
worker.start();
@@ -3324,8 +3317,7 @@ public class StreamingDataflowWorkerTest {
makeDoFnInstruction(blockingFn, 0, StringUtf8Coder.of()),
makeSinkInstruction(StringUtf8Coder.of(), 0));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
- StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
+ StreamingDataflowWorkerOptions options = createTestingPipelineOptions();
options.setActiveWorkRefreshPeriodMillis(100);
StreamingDataflowWorker worker = makeWorker(instructions, options, true /*
publishCounters */);
worker.start();
@@ -3427,8 +3419,7 @@ public class StreamingDataflowWorkerTest {
new FakeSlowDoFn(clock, Duration.millis(1000)), 0,
StringUtf8Coder.of()),
makeSinkInstruction(StringUtf8Coder.of(), 0));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
- StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
+ StreamingDataflowWorkerOptions options = createTestingPipelineOptions();
options.setActiveWorkRefreshPeriodMillis(100);
// A single-threaded worker processes work sequentially, leaving a second
work item in state
// QUEUED until the first work item is committed.
@@ -3470,8 +3461,7 @@ public class StreamingDataflowWorkerTest {
new FakeSlowDoFn(clock, Duration.millis(1000)), 0,
StringUtf8Coder.of()),
makeSinkInstruction(StringUtf8Coder.of(), 0));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
- StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
+ StreamingDataflowWorkerOptions options = createTestingPipelineOptions();
options.setActiveWorkRefreshPeriodMillis(100);
StreamingDataflowWorker worker =
makeWorker(
@@ -3504,8 +3494,7 @@ public class StreamingDataflowWorkerTest {
makeDoFnInstruction(new ReadingDoFn(), 0, StringUtf8Coder.of()),
makeSinkInstruction(StringUtf8Coder.of(), 0));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
- StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
+ StreamingDataflowWorkerOptions options = createTestingPipelineOptions();
options.setActiveWorkRefreshPeriodMillis(100);
StreamingDataflowWorker worker =
makeWorker(
@@ -3545,7 +3534,7 @@ public class StreamingDataflowWorkerTest {
makeSinkInstruction(StringUtf8Coder.of(), 0));
// Inject latency on the fake clock when the server receives a CommitWork
call.
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
+
server
.whenCommitWorkCalled()
.answerByDefault(
@@ -3553,7 +3542,7 @@ public class StreamingDataflowWorkerTest {
clock.sleep(Duration.millis(1000));
return Windmill.CommitWorkResponse.getDefaultInstance();
});
- StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
+ StreamingDataflowWorkerOptions options = createTestingPipelineOptions();
options.setActiveWorkRefreshPeriodMillis(100);
StreamingDataflowWorker worker =
makeWorker(
@@ -3588,8 +3577,7 @@ public class StreamingDataflowWorkerTest {
new FakeSlowDoFn(clock, Duration.millis(dofnWaitTimeMs)), 0,
StringUtf8Coder.of()),
makeSinkInstruction(StringUtf8Coder.of(), 0));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
- StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
+ StreamingDataflowWorkerOptions options = createTestingPipelineOptions();
options.setActiveWorkRefreshPeriodMillis(100);
options.setNumberOfWorkerHarnessThreads(1);
StreamingDataflowWorker worker =
@@ -3641,8 +3629,7 @@ public class StreamingDataflowWorkerTest {
makeDoFnInstruction(new SlowDoFn(), 0, StringUtf8Coder.of()),
makeSinkInstruction(StringUtf8Coder.of(), 0));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
- StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
+ StreamingDataflowWorkerOptions options = createTestingPipelineOptions();
options.setActiveWorkRefreshPeriodMillis(100);
StreamingDataflowWorker worker = makeWorker(instructions, options, true /*
publishCounters */);
worker.start();
@@ -3678,8 +3665,7 @@ public class StreamingDataflowWorkerTest {
makeDoFnInstruction(new SlowDoFn(), 0, StringUtf8Coder.of()),
makeSinkInstruction(StringUtf8Coder.of(), 0));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
- StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
+ StreamingDataflowWorkerOptions options = createTestingPipelineOptions();
options.setActiveWorkRefreshPeriodMillis(10);
StreamingDataflowWorker worker = makeWorker(instructions, options, true /*
publishCounters */);
worker.start();
@@ -3715,12 +3701,11 @@ public class StreamingDataflowWorkerTest {
final int numMessagesInCustomSourceShard = 100000; // 100K input messages.
final int inflatedSizePerMessage = 10000; // x10k => 1GB total output size.
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
StreamingDataflowWorker worker =
makeWorker(
makeUnboundedSourcePipeline(
numMessagesInCustomSourceShard, new
InflateDoFn(inflatedSizePerMessage)),
- createTestingPipelineOptions(server),
+ createTestingPipelineOptions(),
false /* publishCounters */);
worker.start();
@@ -3802,9 +3787,8 @@ public class StreamingDataflowWorkerTest {
1,
GlobalWindow.Coder.INSTANCE));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
StreamingDataflowWorker worker =
- makeWorker(instructions, createTestingPipelineOptions(server), true /*
publishCounters */);
+ makeWorker(instructions, createTestingPipelineOptions(), true /*
publishCounters */);
worker.start();
// Test new key.
@@ -3870,8 +3854,7 @@ public class StreamingDataflowWorkerTest {
makeSourceInstruction(StringUtf8Coder.of()),
makeSinkInstruction(StringUtf8Coder.of(), 0));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
- StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
+ StreamingDataflowWorkerOptions options = createTestingPipelineOptions();
options.setStuckCommitDurationMillis(2000);
StreamingDataflowWorker worker = makeWorker(instructions, options, true /*
publishCounters */);
worker.start();
@@ -3904,14 +3887,12 @@ public class StreamingDataflowWorkerTest {
removeDynamicFields(result.get(1L)));
}
- private void runNumCommitThreadsTest(int configNumCommitThreads, int
expectedNumCommitThreads)
- throws Exception {
+ private void runNumCommitThreadsTest(int configNumCommitThreads, int
expectedNumCommitThreads) {
List<ParallelInstruction> instructions =
Arrays.asList(
makeSourceInstruction(StringUtf8Coder.of()),
makeSinkInstruction(StringUtf8Coder.of(), 0));
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
- StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
+ StreamingDataflowWorkerOptions options = createTestingPipelineOptions();
options.setWindmillServiceCommitThreads(configNumCommitThreads);
StreamingDataflowWorker worker = makeWorker(instructions, options, true /*
publishCounters */);
worker.start();
@@ -3920,7 +3901,7 @@ public class StreamingDataflowWorkerTest {
}
@Test
- public void testDefaultNumCommitThreads() throws Exception {
+ public void testDefaultNumCommitThreads() {
if (streamingEngine) {
runNumCommitThreadsTest(1, 1);
runNumCommitThreadsTest(2, 2);
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java
index 5d17795b28f..515beba0c88 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java
@@ -33,7 +33,7 @@ import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
-import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse;
@@ -62,14 +62,14 @@ import org.mockito.Mockito;
@RunWith(JUnit4.class)
public class GrpcGetWorkerMetadataStreamTest {
- @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
private static final String IPV6_ADDRESS_1 =
"2001:db8:0000:bac5:0000:0000:fed0:81a2";
private static final String IPV6_ADDRESS_2 =
"2001:db8:0000:bac5:0000:0000:fed0:82a3";
+ private static final String AUTHENTICATING_SERVICE = "test.googleapis.com";
private static final List<WorkerMetadataResponse.Endpoint>
DIRECT_PATH_ENDPOINTS =
Lists.newArrayList(
WorkerMetadataResponse.Endpoint.newBuilder()
.setDirectEndpoint(IPV6_ADDRESS_1)
- .setWorkerToken("worker_token")
+ .setBackendWorkerToken("worker_token")
.build());
private static final Map<String, WorkerMetadataResponse.Endpoint>
GLOBAL_DATA_ENDPOINTS =
Maps.newHashMap();
@@ -83,6 +83,7 @@ public class GrpcGetWorkerMetadataStreamTest {
@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
private final MutableHandlerRegistry serviceRegistry = new
MutableHandlerRegistry();
private final Set<AbstractWindmillStream<?, ?>> streamRegistry = new
HashSet<>();
+ @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
private ManagedChannel inProcessChannel;
private GrpcGetWorkerMetadataStream stream;
@@ -93,8 +94,8 @@ public class GrpcGetWorkerMetadataStreamTest {
serviceRegistry.addService(getWorkerMetadataTestStub);
return GrpcGetWorkerMetadataStream.create(
responseObserver ->
- CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)
- .getWorkerMetadataStream(responseObserver),
+ CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(inProcessChannel)
+ .getWorkerMetadata(responseObserver),
FluentBackoff.DEFAULT.backoff(),
StreamObserverFactory.direct(DEFAULT_STREAM_RPC_DEADLINE_SECONDS * 2,
1),
streamRegistry,
@@ -123,7 +124,7 @@ public class GrpcGetWorkerMetadataStreamTest {
"global_data",
WorkerMetadataResponse.Endpoint.newBuilder()
.setDirectEndpoint(IPV6_ADDRESS_1)
- .setWorkerToken("worker_token")
+ .setBackendWorkerToken("worker_token")
.build());
}
@@ -139,6 +140,7 @@ public class GrpcGetWorkerMetadataStreamTest {
.setMetadataVersion(1)
.addAllWorkEndpoints(DIRECT_PATH_ENDPOINTS)
.putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS)
+ .setExternalEndpoint(AUTHENTICATING_SERVICE)
.build();
TestWindmillEndpointsConsumer testWindmillEndpointsConsumer =
new TestWindmillEndpointsConsumer();
@@ -153,7 +155,9 @@ public class GrpcGetWorkerMetadataStreamTest {
assertThat(testWindmillEndpointsConsumer.windmillEndpoints)
.containsExactlyElementsIn(
DIRECT_PATH_ENDPOINTS.stream()
- .map(WindmillEndpoints.Endpoint::from)
+ .map(
+ endpointProto ->
+ WindmillEndpoints.Endpoint.from(endpointProto,
AUTHENTICATING_SERVICE))
.collect(Collectors.toList()));
}
@@ -164,6 +168,7 @@ public class GrpcGetWorkerMetadataStreamTest {
.setMetadataVersion(1)
.addAllWorkEndpoints(DIRECT_PATH_ENDPOINTS)
.putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS)
+ .setExternalEndpoint(AUTHENTICATING_SERVICE)
.build();
TestWindmillEndpointsConsumer testWindmillEndpointsConsumer =
Mockito.spy(new TestWindmillEndpointsConsumer());
@@ -187,6 +192,7 @@ public class GrpcGetWorkerMetadataStreamTest {
.setMetadataVersion(initialResponse.getMetadataVersion() + 1)
.addAllWorkEndpoints(newDirectPathEndpoints)
.putAllGlobalDataEndpoints(newGlobalDataEndpoints)
+ .setExternalEndpoint(AUTHENTICATING_SERVICE)
.build();
testStub.injectWorkerMetadata(newWorkMetadataResponse);
@@ -196,7 +202,9 @@ public class GrpcGetWorkerMetadataStreamTest {
assertThat(testWindmillEndpointsConsumer.windmillEndpoints)
.containsExactlyElementsIn(
newDirectPathEndpoints.stream()
- .map(WindmillEndpoints.Endpoint::from)
+ .map(
+ endpointProto ->
+ WindmillEndpoints.Endpoint.from(endpointProto,
AUTHENTICATING_SERVICE))
.collect(Collectors.toList()));
}
@@ -207,6 +215,7 @@ public class GrpcGetWorkerMetadataStreamTest {
.setMetadataVersion(2)
.addAllWorkEndpoints(DIRECT_PATH_ENDPOINTS)
.putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS)
+ .setExternalEndpoint(AUTHENTICATING_SERVICE)
.build();
TestWindmillEndpointsConsumer testWindmillEndpointsConsumer =
@@ -268,7 +277,8 @@ public class GrpcGetWorkerMetadataStreamTest {
}
private static class GetWorkerMetadataTestStub
- extends
CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase {
+ extends CloudWindmillMetadataServiceV1Alpha1Grpc
+ .CloudWindmillMetadataServiceV1Alpha1ImplBase {
private final TestGetWorkMetadataRequestObserver requestObserver;
private @Nullable StreamObserver<WorkerMetadataResponse> responseObserver;
@@ -277,7 +287,7 @@ public class GrpcGetWorkerMetadataStreamTest {
}
@Override
- public StreamObserver<WorkerMetadataRequest> getWorkerMetadataStream(
+ public StreamObserver<WorkerMetadataRequest> getWorkerMetadata(
StreamObserver<WorkerMetadataResponse> responseObserver) {
if (this.responseObserver == null) {
this.responseObserver = responseObserver;
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
index 15610462e01..37dc7eff917 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
@@ -73,6 +73,8 @@ import
org.apache.beam.runners.dataflow.worker.windmill.WindmillApplianceGrpc;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory;
+import
org.apache.beam.runners.dataflow.worker.windmill.testing.FakeWindmillStubFactory;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.CallOptions;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Channel;
@@ -87,6 +89,7 @@ import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.StatusRuntimeException;
import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.hamcrest.Matchers;
@@ -109,10 +112,13 @@ import org.slf4j.LoggerFactory;
})
public class GrpcWindmillServerTest {
@Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+ @Rule public GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+ @Rule public ErrorCollector errorCollector = new ErrorCollector();
+
private static final Logger LOG =
LoggerFactory.getLogger(GrpcWindmillServerTest.class);
private static final int STREAM_CHUNK_SIZE = 2 << 20;
+ private final long clientId = 10L;
private final MutableHandlerRegistry serviceRegistry = new
MutableHandlerRegistry();
- @Rule public ErrorCollector errorCollector = new ErrorCollector();
private Server server;
private GrpcWindmillServer client;
private int remainingErrors = 20;
@@ -128,7 +134,13 @@ public class GrpcWindmillServerTest {
.build()
.start();
- this.client = GrpcWindmillServer.newTestInstance(name, new ArrayList<>());
+ this.client =
+ GrpcWindmillServer.newTestInstance(
+ name,
+ new ArrayList<>(),
+ clientId,
+ new FakeWindmillStubFactory(
+ () ->
grpcCleanup.register(WindmillChannelFactory.inProcessChannel(name))));
}
@After
@@ -197,7 +209,9 @@ public class GrpcWindmillServerTest {
.build(),
testInterceptor);
- this.client =
GrpcWindmillServer.newApplianceTestInstance(inprocessChannel);
+ this.client =
+ GrpcWindmillServer.newApplianceTestInstance(
+ inprocessChannel, new FakeWindmillStubFactory(() ->
inprocessChannel));
Windmill.GetWorkResponse response1 =
client.getWork(GetWorkRequest.getDefaultInstance());
Windmill.GetWorkResponse response2 =
client.getWork(GetWorkRequest.getDefaultInstance());
@@ -346,6 +360,7 @@ public class GrpcWindmillServerTest {
.setJobId("job")
.setProjectId("project")
.setWorkerId("worker")
+ .setClientId(clientId)
.build()));
sawHeader = true;
} else {
@@ -555,6 +570,7 @@ public class GrpcWindmillServerTest {
.setJobId("job")
.setProjectId("project")
.setWorkerId("worker")
+ .setClientId(clientId)
.build()));
sawHeader = true;
LOG.info("Received header");
@@ -839,6 +855,7 @@ public class GrpcWindmillServerTest {
.setJobId("job")
.setProjectId("project")
.setWorkerId("worker")
+ .setClientId(clientId)
.build()));
sawHeader = true;
} else {
@@ -921,7 +938,10 @@ public class GrpcWindmillServerTest {
this.client =
GrpcWindmillServer.newTestInstance(
"TestServer",
-
Collections.singletonList("streaming_engine_send_new_heartbeat_requests"));
+
Collections.singletonList("streaming_engine_send_new_heartbeat_requests"),
+ clientId,
+ new FakeWindmillStubFactory(
+ () -> WindmillChannelFactory.inProcessChannel("TestServer")));
// This server records the heartbeats observed but doesn't respond.
final List<ComputationHeartbeatRequest> receivedHeartbeats = new
ArrayList<>();
@@ -945,6 +965,7 @@ public class GrpcWindmillServerTest {
.setJobId("job")
.setProjectId("project")
.setWorkerId("worker")
+ .setClientId(clientId)
.build()));
sawHeader = true;
} else {
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java
index 4831726c49e..f755f033338 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java
@@ -37,7 +37,7 @@ import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
-import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest;
@@ -46,6 +46,7 @@ import
org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress;
import
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory;
import
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory;
+import
org.apache.beam.runners.dataflow.worker.windmill.testing.FakeWindmillStubFactory;
import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor;
import
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
import
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor;
@@ -73,7 +74,6 @@ import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public class StreamingEngineClientTest {
- @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
private static final WindmillServiceAddress DEFAULT_WINDMILL_SERVICE_ADDRESS
=
WindmillServiceAddress.create(HostAndPort.fromParts(WindmillChannelFactory.LOCALHOST,
443));
private static final ImmutableMap<String, WorkerMetadataResponse.Endpoint>
DEFAULT =
@@ -95,24 +95,25 @@ public class StreamingEngineClientTest {
@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
private final Set<ManagedChannel> channels = new HashSet<>();
private final MutableHandlerRegistry serviceRegistry = new
MutableHandlerRegistry();
-
private final GrpcWindmillStreamFactory streamFactory =
spy(GrpcWindmillStreamFactory.of(JOB_HEADER).build());
private final WindmillStubFactory stubFactory =
- WindmillStubFactory.inProcessStubFactory(
- "StreamingEngineClientTest",
- name -> {
+ new FakeWindmillStubFactory(
+ () -> {
ManagedChannel channel =
-
grpcCleanup.register(WindmillChannelFactory.inProcessChannel(name));
+ grpcCleanup.register(
+
WindmillChannelFactory.inProcessChannel("StreamingEngineClientTest"));
channels.add(channel);
return channel;
});
private final GrpcDispatcherClient dispatcherClient =
- GrpcDispatcherClient.forTesting(stubFactory, new ArrayList<>(), new
HashSet<>());
+ GrpcDispatcherClient.forTesting(
+ stubFactory, new ArrayList<>(), new ArrayList<>(), new HashSet<>());
private final GetWorkBudgetDistributor getWorkBudgetDistributor =
spy(new TestGetWorkBudgetDistributor());
private final AtomicReference<StreamingEngineConnectionState> connections =
new AtomicReference<>(StreamingEngineConnectionState.EMPTY);
+ @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
private Server fakeStreamingEngineServer;
private CountDownLatch getWorkerMetadataReady;
private GetWorkerMetadataTestStub fakeGetWorkerMetadataStub;
@@ -140,7 +141,7 @@ public class StreamingEngineClientTest {
}
private static WorkerMetadataResponse.Endpoint
metadataResponseEndpoint(String workerToken) {
- return
WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken).build();
+ return
WorkerMetadataResponse.Endpoint.newBuilder().setBackendWorkerToken(workerToken).build();
}
@Before
@@ -269,16 +270,22 @@ public class StreamingEngineClientTest {
WorkerMetadataResponse.newBuilder()
.setMetadataVersion(1)
.addWorkEndpoints(
-
WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken).build())
+ WorkerMetadataResponse.Endpoint.newBuilder()
+ .setBackendWorkerToken(workerToken)
+ .build())
.addWorkEndpoints(
-
WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken2).build())
+ WorkerMetadataResponse.Endpoint.newBuilder()
+ .setBackendWorkerToken(workerToken2)
+ .build())
.putAllGlobalDataEndpoints(DEFAULT)
.build();
WorkerMetadataResponse secondWorkerMetadata =
WorkerMetadataResponse.newBuilder()
.setMetadataVersion(2)
.addWorkEndpoints(
-
WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken3).build())
+ WorkerMetadataResponse.Endpoint.newBuilder()
+ .setBackendWorkerToken(workerToken3)
+ .build())
.putAllGlobalDataEndpoints(DEFAULT)
.build();
@@ -315,21 +322,27 @@ public class StreamingEngineClientTest {
WorkerMetadataResponse.newBuilder()
.setMetadataVersion(1)
.addWorkEndpoints(
-
WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken).build())
+ WorkerMetadataResponse.Endpoint.newBuilder()
+ .setBackendWorkerToken(workerToken)
+ .build())
.putAllGlobalDataEndpoints(DEFAULT)
.build();
WorkerMetadataResponse secondWorkerMetadata =
WorkerMetadataResponse.newBuilder()
.setMetadataVersion(2)
.addWorkEndpoints(
-
WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken2).build())
+ WorkerMetadataResponse.Endpoint.newBuilder()
+ .setBackendWorkerToken(workerToken2)
+ .build())
.putAllGlobalDataEndpoints(DEFAULT)
.build();
WorkerMetadataResponse thirdWorkerMetadata =
WorkerMetadataResponse.newBuilder()
.setMetadataVersion(3)
.addWorkEndpoints(
-
WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken3).build())
+ WorkerMetadataResponse.Endpoint.newBuilder()
+ .setBackendWorkerToken(workerToken3)
+ .build())
.putAllGlobalDataEndpoints(DEFAULT)
.build();
@@ -362,7 +375,8 @@ public class StreamingEngineClientTest {
}
private static class GetWorkerMetadataTestStub
- extends
CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase {
+ extends CloudWindmillMetadataServiceV1Alpha1Grpc
+ .CloudWindmillMetadataServiceV1Alpha1ImplBase {
private static final WorkerMetadataResponse CLOSE_ALL_STREAMS =
WorkerMetadataResponse.newBuilder().setMetadataVersion(100).build();
private final CountDownLatch ready;
@@ -373,7 +387,7 @@ public class StreamingEngineClientTest {
}
@Override
- public StreamObserver<WorkerMetadataRequest> getWorkerMetadataStream(
+ public StreamObserver<WorkerMetadataRequest> getWorkerMetadata(
StreamObserver<WorkerMetadataResponse> responseObserver) {
if (this.responseObserver == null) {
ready.countDown();
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/testing/FakeWindmillStubFactory.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/testing/FakeWindmillStubFactory.java
new file mode 100644
index 00000000000..3dd40e5d5c7
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/testing/FakeWindmillStubFactory.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.testing;
+
+import java.util.function.Supplier;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Channel;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+
+@VisibleForTesting
+public final class FakeWindmillStubFactory implements WindmillStubFactory {
+ private final Supplier<Channel> channelFactory;
+
+ public FakeWindmillStubFactory(Supplier<Channel> channelFactory) {
+ this.channelFactory = channelFactory;
+ }
+
+ @Override
+ public CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub
+ createWindmillServiceStub(WindmillServiceAddress serviceAddress) {
+ return CloudWindmillServiceV1Alpha1Grpc.newStub(channelFactory.get());
+ }
+
+ @Override
+ public
CloudWindmillMetadataServiceV1Alpha1Grpc.CloudWindmillMetadataServiceV1Alpha1Stub
+ createWindmillMetadataServiceStub(WindmillServiceAddress serviceAddress)
{
+ return
CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(channelFactory.get());
+ }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
index 0c824ca301b..4677ff9dcc9 100644
---
a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
+++
b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
@@ -847,6 +847,11 @@ message JobHeader {
optional string project_id = 2;
// Worker id is meant for logging only. Do not rely on it for other
decisions.
optional string worker_id = 3;
+ optional fixed64 client_id = 4;
+ optional string region_id = 5;
+ // Used by the user worker to communicate to a specific windmill worker. This
+ // is initially passed to the user worker via GetWorkerMetadata.
+ optional string backend_worker_token = 6;
}
message StreamingCommitRequestChunk {
@@ -902,14 +907,19 @@ message WorkerMetadataResponse {
message Endpoint {
// IPv6 address of a streaming engine windmill worker.
optional string direct_endpoint = 1;
- optional string worker_token = 2;
+ optional string backend_worker_token = 2;
+ optional int64 port = 3;
}
+
repeated Endpoint work_endpoints = 2;
// Maps from GlobalData tag to the endpoint that should be used for GetData
// calls to retrieve that global data.
map<string, Endpoint> global_data_endpoints = 3;
+ // Used to set gRPC authority.
+ optional string external_endpoint = 5;
+
reserved 4;
}
diff --git
a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill_service.proto
b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill_service.proto
index d9183e54e0d..101bae170db 100644
---
a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill_service.proto
+++
b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill_service.proto
@@ -33,10 +33,6 @@ service CloudWindmillServiceV1Alpha1 {
rpc GetWorkStream(stream .windmill.StreamingGetWorkRequest)
returns (stream .windmill.StreamingGetWorkResponseChunk);
- // Gets worker metadata. Response is a stream.
- rpc GetWorkerMetadataStream(stream .windmill.WorkerMetadataRequest)
- returns (stream .windmill.WorkerMetadataResponse);
-
// Gets data from Windmill.
rpc GetData(.windmill.GetDataRequest) returns(.windmill.GetDataResponse);
@@ -52,3 +48,9 @@ service CloudWindmillServiceV1Alpha1 {
rpc CommitWorkStream(stream .windmill.StreamingCommitWorkRequest)
returns (stream .windmill.StreamingCommitResponse);
}
+
+service CloudWindmillMetadataServiceV1Alpha1 {
+ // Gets worker metadata. Response is a stream.
+ rpc GetWorkerMetadata(stream.windmill.WorkerMetadataRequest)
+ returns (stream.windmill.WorkerMetadataResponse);
+}