scwhittle commented on code in PR #28835:
URL: https://github.com/apache/beam/pull/28835#discussion_r1372181561


##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/StreamingEngineClient.java:
##########
@@ -0,0 +1,477 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill;
+
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+
+import com.google.auto.value.AutoValue;
+import com.google.errorprone.annotations.CheckReturnValue;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Random;
+import java.util.concurrent.BlockingDeque;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetDataStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkerMetadataStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.grpcclient.ThrottleTimer;
+import 
org.apache.beam.runners.dataflow.worker.windmill.util.WindmillGrpcStubFactory;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Client for StreamingEngine. Given a {@link GetWorkBudget}, divides the 
budget and starts the
+ * {@link 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream}(s).
+ */
+@CheckReturnValue
+@ThreadSafe
+public class StreamingEngineClient {
+  @VisibleForTesting static final int SCHEDULED_BUDGET_REFRESH_MILLIS = 100;
+  private static final Logger LOG = 
LoggerFactory.getLogger(StreamingEngineClient.class);
+  private static final String BUDGET_REFRESH_THREAD = "BudgetRefreshThread";
+  private static final String CONSUMER_WORKER_METADATA_THREAD = 
"ConsumeWorkerMetadataThread";
+
+  private final AtomicBoolean started;
+  private final JobHeader jobHeader;
+  private final GetWorkBudget totalGetWorkBudget;
+  private final StreamingEngineStreamFactory streamingEngineStreamFactory;
+  private final WorkItemReceiver workItemReceiver;
+  private final WindmillGrpcStubFactory windmillGrpcStubFactory;
+  private final GetWorkBudgetDistributor getWorkBudgetDistributor;
+  private final DispatcherClient dispatcherClient;
+  private final AtomicBoolean isBudgetRefreshPaused;
+  /** Writes are guarded by synchronization, reads are lock free. */
+  private final AtomicReference<StreamEngineConnectionState> connections;
+
+  /**
+   * Used to implement publish/subscribe behavior for triggering budget 
refreshes/redistribution.
+   * Subscriber {@link #budgetRefreshExecutor} will either redistribute the 
budget if a value has
+   * been triggered or more than {@link #SCHEDULED_BUDGET_REFRESH_MILLIS} time 
has passed.
+   */
+  private final BlockingDeque<TimeStampedTriggeredBudgetRefresh> 
budgetRefreshTrigger;
+
+  private final ScheduledExecutorService budgetRefreshExecutor;
+  private final AtomicReference<Instant> lastBudgetRefresh;
+  private final ThrottleTimer getWorkerMetadataThrottleTimer;
+  private final CountDownLatch getWorkerMetadataReady;
+  private final ExecutorService consumeWorkerMetadataExecutor;
+  private final long clientId;
+  /**
+   * Reference to {@link GetWorkerMetadataStream} that is lazily initialized, 
with its initial value
+   * being null.
+   */
+  private volatile @Nullable GetWorkerMetadataStream getWorkerMetadataStream;
+
+  private StreamingEngineClient(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      AtomicReference<StreamEngineConnectionState> connections,
+      StreamingEngineStreamFactory streamingEngineStreamFactory,
+      WorkItemReceiver workItemReceiver,
+      WindmillGrpcStubFactory windmillGrpcStubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      DispatcherClient dispatcherClient) {
+    this.jobHeader = jobHeader;
+    this.totalGetWorkBudget = totalGetWorkBudget;
+    this.started = new AtomicBoolean();
+    this.streamingEngineStreamFactory = streamingEngineStreamFactory;
+    this.workItemReceiver = workItemReceiver;
+    this.connections = connections;
+    this.windmillGrpcStubFactory = windmillGrpcStubFactory;
+    this.getWorkBudgetDistributor = getWorkBudgetDistributor;
+    this.dispatcherClient = dispatcherClient;
+    this.isBudgetRefreshPaused = new AtomicBoolean(false);
+    this.budgetRefreshTrigger = new LinkedBlockingDeque<>();
+    this.getWorkerMetadataThrottleTimer = new ThrottleTimer();
+    this.budgetRefreshExecutor = 
createSingleThreadedExecutor(BUDGET_REFRESH_THREAD);
+    this.consumeWorkerMetadataExecutor =
+        createSingleThreadedExecutor(CONSUMER_WORKER_METADATA_THREAD);
+    this.getWorkerMetadataStream = null;
+    this.getWorkerMetadataReady = new CountDownLatch(1);
+    this.clientId = new Random().nextLong();
+    this.lastBudgetRefresh = new AtomicReference<>(Instant.EPOCH);
+  }
+
+  /**
+   * Creates an instance of {@link StreamingEngineClient} and starts the {@link
+   * GetWorkerMetadataStream} with an RPC to the StreamingEngine backend. 
{@link
+   * GetWorkerMetadataStream} will populate {@link #connections} when a 
response is received. Calls
+   * to {@link #startAndCacheStreams()} will block until {@link #connections} 
are populated.
+   */
+  public static StreamingEngineClient create(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      StreamingEngineStreamFactory streamingEngineStreamFactory,
+      WorkItemReceiver workItemReceiver,
+      WindmillGrpcStubFactory windmillGrpcStubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      DispatcherClient dispatcherClient) {
+    StreamingEngineClient streamingEngineClient =
+        new StreamingEngineClient(
+            jobHeader,
+            totalGetWorkBudget,
+            new AtomicReference<>(StreamEngineConnectionState.EMPTY),
+            streamingEngineStreamFactory,
+            workItemReceiver,
+            windmillGrpcStubFactory,
+            getWorkBudgetDistributor,
+            dispatcherClient);
+    streamingEngineClient.startGetWorkerMetadataStream();
+    return streamingEngineClient;
+  }
+
+  @VisibleForTesting
+  static StreamingEngineClient forTesting(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      AtomicReference<StreamEngineConnectionState> connections,
+      StreamingEngineStreamFactory streamingEngineStreamFactory,
+      WorkItemReceiver workItemReceiver,
+      WindmillGrpcStubFactory windmillGrpcStubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      DispatcherClient dispatcherClient) {
+    StreamingEngineClient streamingEngineClient =
+        new StreamingEngineClient(
+            jobHeader,
+            totalGetWorkBudget,
+            connections,
+            streamingEngineStreamFactory,
+            workItemReceiver,
+            windmillGrpcStubFactory,
+            getWorkBudgetDistributor,
+            dispatcherClient);
+    streamingEngineClient.startGetWorkerMetadataStream();
+    return streamingEngineClient;
+  }
+
+  private static ScheduledExecutorService createSingleThreadedExecutor(String 
threadName) {
+    return Executors.newSingleThreadScheduledExecutor(
+        new ThreadFactoryBuilder()
+            .setNameFormat(threadName)
+            .setUncaughtExceptionHandler(
+                (t, e) ->
+                    LOG.error(
+                        "{} failed due to uncaught exception during execution. 
", t.getName(), e))
+            .build());
+  }
+
+  /**
+   * Starts the streams with the {@link #connections} values. Does nothing if 
this has already been
+   * called.
+   *
+   * @throws IllegalArgumentException if trying to start before {@link 
#connections} are set with
+   *     {@link GetWorkerMetadataStream}.
+   */
+  public void startAndCacheStreams() {
+    // Do nothing if we have already initialized the initial streams.
+    if (!started.compareAndSet(false, true)) {
+      return;
+    }
+    waitForFirstStreamingEngineEndpoints();
+    StreamEngineConnectionState currentConnectionsState = connections.get();
+    Preconditions.checkState(
+        !StreamEngineConnectionState.EMPTY.equals(currentConnectionsState),
+        "Cannot start streams without connections.");
+    LOG.info("Starting initial GetWorkStreams with connections={}", 
currentConnectionsState);
+    ImmutableCollection<WindmillStreamSender> windmillStreamSenders =
+        currentConnectionsState.windmillStreams().values();
+    getWorkBudgetDistributor.distributeBudget(
+        currentConnectionsState.windmillStreams().values(), 
totalGetWorkBudget);
+    lastBudgetRefresh.compareAndSet(Instant.EPOCH, Instant.now());
+    windmillStreamSenders.forEach(WindmillStreamSender::startStreams);
+    startBudgetRefreshThreads();
+  }
+
+  private void waitForFirstStreamingEngineEndpoints() {
+    try {
+      getWorkerMetadataReady.await();
+    } catch (InterruptedException e) {
+      throw new StreamingEngineClientException(
+          "Error occurred waiting for StreamingEngine backend endpoints.", e);
+    }
+  }
+
+  /**
+   * {@link java.util.function.Consumer<WindmillEndpoints>} used to update 
{@link #connections} on
+   * new backend worker metadata.
+   */
+  private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints 
newWindmillEndpoints) {
+    isBudgetRefreshPaused.set(true);
+    LOG.info("Consuming new windmill endpoints: {}", newWindmillEndpoints);
+    ImmutableMap<Endpoint, WindmillConnection> newWindmillConnections =
+        
createNewWindmillConnections(ImmutableSet.copyOf(newWindmillEndpoints.windmillEndpoints()));
+    ImmutableMap<WindmillConnection, WindmillStreamSender> newWindmillStreams =
+        
closeStaleStreamsAndCreateNewStreams(ImmutableSet.copyOf(newWindmillConnections.values()));
+    ImmutableMap<Endpoint, Supplier<GetDataStream>> newGlobalDataStreams =
+        createNewGlobalDataStreams(
+            
ImmutableSet.copyOf(newWindmillEndpoints.globalDataEndpoints().values()));
+
+    StreamEngineConnectionState newConnectionsState =
+        StreamEngineConnectionState.builder()
+            .setWindmillConnections(newWindmillConnections)
+            .setWindmillStreams(newWindmillStreams)
+            .setGlobalDataEndpoints(newWindmillEndpoints.globalDataEndpoints())
+            .setGlobalDataStreams(newGlobalDataStreams)
+            .build();
+
+    LOG.info(
+        "Setting new connections: {}. Previous connections: {}.",
+        newConnectionsState,
+        connections.get());
+    connections.set(newConnectionsState);
+    isBudgetRefreshPaused.set(false);
+
+    // On first worker metadata. Trigger
+    if (getWorkerMetadataReady.getCount() > 0) {
+      getWorkerMetadataReady.countDown();
+    } else {
+      
requestBudgetRefresh(TimeStampedTriggeredBudgetRefresh.Event.NEW_ENDPOINTS);
+    }
+  }
+
+  public ImmutableList<Long> getAndResetThrottleTimes() {
+    StreamEngineConnectionState currentConnections = connections.get();
+
+    ImmutableList<Long> keyedWorkStreamThrottleTimes =
+        currentConnections.windmillStreams().values().stream()
+            .map(WindmillStreamSender::getAndResetThrottleTime)
+            .collect(toImmutableList());
+
+    return ImmutableList.<Long>builder()
+        .add(getWorkerMetadataThrottleTimer.getAndResetThrottleTime())
+        .addAll(keyedWorkStreamThrottleTimes)
+        .build();
+  }
+
+  /** Starts {@link GetWorkerMetadataStream}. */
+  @SuppressWarnings({
+    "FutureReturnValueIgnored", // ignoring Future returned from 
Executor.submit()
+    "nullness" // Uninitialized value of getWorkerMetadataStream is null.
+  })
+  private void startGetWorkerMetadataStream() {
+    // We only want to set and start this value once.
+    if (getWorkerMetadataStream == null) {
+      synchronized (this) {
+        if (getWorkerMetadataStream == null) {
+          getWorkerMetadataStream =
+              streamingEngineStreamFactory.createGetWorkerMetadataStream(
+                  dispatcherClient.getDispatcherStub(),
+                  getWorkerMetadataThrottleTimer,
+                  endpoints ->
+                      consumeWorkerMetadataExecutor.submit(
+                          () -> consumeWindmillWorkerEndpoints(endpoints)));
+        }
+      }
+    }
+  }
+
+  @SuppressWarnings("FutureReturnValueIgnored")
+  private void startBudgetRefreshThreads() {
+    budgetRefreshExecutor.scheduleAtFixedRate(
+        this::refreshBudget,
+        SCHEDULED_BUDGET_REFRESH_MILLIS,
+        SCHEDULED_BUDGET_REFRESH_MILLIS,
+        TimeUnit.MILLISECONDS);
+  }
+
+  private void refreshBudget() {

Review Comment:
   It seems like if we wanted to trigger a budget refresh, this would not 
notice until the next poll. Can we instead just have a thread blocking on 
the queue with a timeout? Then it would react immediately to a trigger, and 
otherwise redistribute once per timeout period.
   
   This still has the issue, though, that anything adding to the queue will 
block if the background thread is busy redistributing budget. It would be nice 
if triggering were non-blocking and multiple trigger signals were collapsed 
into one. This is what I was wondering whether you could do with the Monitor 
class, unless there is some other suitable class. You might be able to use 
Beam's AdvancingPhaser class. To trigger, call arrive(); the monitoring loop 
would be something like:
   
   int phase = 0;
   while (true) {
     phase = phaser.awaitAdvanceInterruptibly(phase, 100, milliseconds);
     if (phase < 0) break; // phaser shutdown
     // either timed out or was triggered, run rebudgeting
   }
   
   
   See 
[awaitAdvanceInterruptibly](https://docs.oracle.com/javase/8/docs/api/java/util/concurrent/Phaser.html#awaitAdvanceInterruptibly-int-)
   
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to