m-trieu commented on code in PR #28835:
URL: https://github.com/apache/beam/pull/28835#discussion_r1370988458


##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/StreamingEngineClient.java:
##########
@@ -0,0 +1,477 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill;
+
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+
+import com.google.auto.value.AutoValue;
+import com.google.errorprone.annotations.CheckReturnValue;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Random;
+import java.util.concurrent.BlockingDeque;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair;
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetDataStream;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkerMetadataStream;
+import org.apache.beam.runners.dataflow.worker.windmill.grpcclient.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.util.WindmillGrpcStubFactory;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Client for StreamingEngine. Given a {@link GetWorkBudget}, divides the budget and starts the
+ * {@link org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream}(s).
+ */
+@CheckReturnValue
+@ThreadSafe
+public class StreamingEngineClient {
+  @VisibleForTesting static final int SCHEDULED_BUDGET_REFRESH_MILLIS = 100;
+  private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineClient.class);
+  private static final String BUDGET_REFRESH_THREAD = "BudgetRefreshThread";
+  private static final String CONSUMER_WORKER_METADATA_THREAD = "ConsumeWorkerMetadataThread";
+
+  private final AtomicBoolean started;
+  private final JobHeader jobHeader;
+  private final GetWorkBudget totalGetWorkBudget;
+  private final StreamingEngineStreamFactory streamingEngineStreamFactory;
+  private final WorkItemReceiver workItemReceiver;
+  private final WindmillGrpcStubFactory windmillGrpcStubFactory;
+  private final GetWorkBudgetDistributor getWorkBudgetDistributor;
+  private final DispatcherClient dispatcherClient;
+  private final AtomicBoolean isBudgetRefreshPaused;
+  /** Writes are guarded by synchronization, reads are lock free. */
+  private final AtomicReference<StreamEngineConnectionState> connections;
+
+  /**
+   * Used to implement publish/subscribe behavior for triggering budget refreshes/redistribution.
+   * Subscriber {@link #budgetRefreshExecutor} redistributes the budget either when a refresh has
+   * been triggered or when more than {@link #SCHEDULED_BUDGET_REFRESH_MILLIS} milliseconds have
+   * passed since the last refresh.
+   */
+  private final BlockingDeque<TimeStampedTriggeredBudgetRefresh> budgetRefreshTrigger;
+
+  private final ScheduledExecutorService budgetRefreshExecutor;
+  private final AtomicReference<Instant> lastBudgetRefresh;
+  private final ThrottleTimer getWorkerMetadataThrottleTimer;
+  private final CountDownLatch getWorkerMetadataReady;
+  private final ExecutorService consumeWorkerMetadataExecutor;
+  private final long clientId;
+  /**
+   * Reference to {@link GetWorkerMetadataStream} that is lazily initialized, with its initial
+   * value being null.
+   */
+  private volatile @Nullable GetWorkerMetadataStream getWorkerMetadataStream;
+
+  private StreamingEngineClient(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      AtomicReference<StreamEngineConnectionState> connections,
+      StreamingEngineStreamFactory streamingEngineStreamFactory,
+      WorkItemReceiver workItemReceiver,
+      WindmillGrpcStubFactory windmillGrpcStubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      DispatcherClient dispatcherClient) {
+    this.jobHeader = jobHeader;
+    this.totalGetWorkBudget = totalGetWorkBudget;
+    this.started = new AtomicBoolean();
+    this.streamingEngineStreamFactory = streamingEngineStreamFactory;
+    this.workItemReceiver = workItemReceiver;
+    this.connections = connections;
+    this.windmillGrpcStubFactory = windmillGrpcStubFactory;
+    this.getWorkBudgetDistributor = getWorkBudgetDistributor;
+    this.dispatcherClient = dispatcherClient;
+    this.isBudgetRefreshPaused = new AtomicBoolean(false);
+    this.budgetRefreshTrigger = new LinkedBlockingDeque<>();
+    this.getWorkerMetadataThrottleTimer = new ThrottleTimer();
+    this.budgetRefreshExecutor = createSingleThreadedExecutor(BUDGET_REFRESH_THREAD);
+    this.consumeWorkerMetadataExecutor =
+        createSingleThreadedExecutor(CONSUMER_WORKER_METADATA_THREAD);
+    this.getWorkerMetadataStream = null;
+    this.getWorkerMetadataReady = new CountDownLatch(1);
+    this.clientId = new Random().nextLong();
+    this.lastBudgetRefresh = new AtomicReference<>(Instant.EPOCH);
+  }
+
+  /**
+   * Creates an instance of {@link StreamingEngineClient} and starts the {@link
+   * GetWorkerMetadataStream} with an RPC to the StreamingEngine backend. {@link
+   * GetWorkerMetadataStream} will populate {@link #connections} when a response is received. Calls
+   * to {@link #startAndCacheStreams()} will block until {@link #connections} are populated.
+   */
+  public static StreamingEngineClient create(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      StreamingEngineStreamFactory streamingEngineStreamFactory,
+      WorkItemReceiver workItemReceiver,
+      WindmillGrpcStubFactory windmillGrpcStubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      DispatcherClient dispatcherClient) {
+    StreamingEngineClient streamingEngineClient =
+        new StreamingEngineClient(
+            jobHeader,
+            totalGetWorkBudget,
+            new AtomicReference<>(StreamEngineConnectionState.EMPTY),
+            streamingEngineStreamFactory,
+            workItemReceiver,
+            windmillGrpcStubFactory,
+            getWorkBudgetDistributor,
+            dispatcherClient);
+    streamingEngineClient.startGetWorkerMetadataStream();
+    return streamingEngineClient;
+  }
+
+  @VisibleForTesting
+  static StreamingEngineClient forTesting(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      AtomicReference<StreamEngineConnectionState> connections,
+      StreamingEngineStreamFactory streamingEngineStreamFactory,
+      WorkItemReceiver workItemReceiver,
+      WindmillGrpcStubFactory windmillGrpcStubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      DispatcherClient dispatcherClient) {
+    StreamingEngineClient streamingEngineClient =
+        new StreamingEngineClient(
+            jobHeader,
+            totalGetWorkBudget,
+            connections,
+            streamingEngineStreamFactory,
+            workItemReceiver,
+            windmillGrpcStubFactory,
+            getWorkBudgetDistributor,
+            dispatcherClient);
+    streamingEngineClient.startGetWorkerMetadataStream();
+    return streamingEngineClient;
+  }
+
+  private static ScheduledExecutorService createSingleThreadedExecutor(String threadName) {
+    return Executors.newSingleThreadScheduledExecutor(
+        new ThreadFactoryBuilder()
+            .setNameFormat(threadName)
+            .setUncaughtExceptionHandler(
+                (t, e) ->
+                    LOG.error(
+                        "{} failed due to uncaught exception during execution.", t.getName(), e))
+            .build());
+  }
+
+  /**
+   * Starts the streams with the {@link #connections} values. Does nothing if this has already been
+   * called.
+   *
+   * @throws IllegalStateException if trying to start before {@link #connections} are set by
+   *     {@link GetWorkerMetadataStream}.
+   */
+  public void startAndCacheStreams() {
+    // Do nothing if we have already initialized the initial streams.
+    if (!started.compareAndSet(false, true)) {
+      return;
+    }
+    waitForFirstStreamingEngineEndpoints();
+    StreamEngineConnectionState currentConnectionsState = connections.get();
+    Preconditions.checkState(
+        !StreamEngineConnectionState.EMPTY.equals(currentConnectionsState),
+        "Cannot start streams without connections.");
+    LOG.info("Starting initial GetWorkStreams with connections={}", 
currentConnectionsState);
+    ImmutableCollection<WindmillStreamSender> windmillStreamSenders =
+        currentConnectionsState.windmillStreams().values();
+    getWorkBudgetDistributor.distributeBudget(
+        currentConnectionsState.windmillStreams().values(), totalGetWorkBudget);
+    lastBudgetRefresh.compareAndSet(Instant.EPOCH, Instant.now());
+    windmillStreamSenders.forEach(WindmillStreamSender::startStreams);
+    startBudgetRefreshThreads();
+  }
+
+  private void waitForFirstStreamingEngineEndpoints() {
+    try {
+      getWorkerMetadataReady.await();
+    } catch (InterruptedException e) {
+      throw new StreamingEngineClientException(
+          "Error occurred waiting for StreamingEngine backend endpoints.", e);
+    }
+  }
+
+  /**
+   * {@link java.util.function.Consumer<WindmillEndpoints>} used to update {@link #connections} on
+   * new backend worker metadata.
+   */
+  private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWindmillEndpoints) {
+    isBudgetRefreshPaused.set(true);
+    LOG.info("Consuming new windmill endpoints: {}", newWindmillEndpoints);
+    ImmutableMap<Endpoint, WindmillConnection> newWindmillConnections =
+        
createNewWindmillConnections(ImmutableSet.copyOf(newWindmillEndpoints.windmillEndpoints()));
+    ImmutableMap<WindmillConnection, WindmillStreamSender> newWindmillStreams =
+        
closeStaleStreamsAndCreateNewStreams(ImmutableSet.copyOf(newWindmillConnections.values()));
+    ImmutableMap<Endpoint, Supplier<GetDataStream>> newGlobalDataStreams =
+        createNewGlobalDataStreams(
+            
ImmutableSet.copyOf(newWindmillEndpoints.globalDataEndpoints().values()));
+
+    StreamEngineConnectionState newConnectionsState =
+        StreamEngineConnectionState.builder()
+            .setWindmillConnections(newWindmillConnections)
+            .setWindmillStreams(newWindmillStreams)
+            .setGlobalDataEndpoints(newWindmillEndpoints.globalDataEndpoints())
+            .setGlobalDataStreams(newGlobalDataStreams)
+            .build();
+
+    LOG.info(
+        "Setting new connections: {}. Previous connections: {}.",
+        newConnectionsState,
+        connections.get());
+    connections.set(newConnectionsState);
+    isBudgetRefreshPaused.set(false);
+
+    // On the first worker metadata response, unblock startAndCacheStreams(); otherwise trigger a
+    // budget refresh for the new endpoints.
+    if (getWorkerMetadataReady.getCount() > 0) {
+      getWorkerMetadataReady.countDown();
+    } else {
+      requestBudgetRefresh(TimeStampedTriggeredBudgetRefresh.Event.NEW_ENDPOINTS);
+    }
+  }
+
+  public ImmutableList<Long> getAndResetThrottleTimes() {
+    StreamEngineConnectionState currentConnections = connections.get();
+
+    ImmutableList<Long> keyedWorkStreamThrottleTimes =
+        currentConnections.windmillStreams().values().stream()
+            .map(WindmillStreamSender::getAndResetThrottleTime)
+            .collect(toImmutableList());
+
+    return ImmutableList.<Long>builder()
+        .add(getWorkerMetadataThrottleTimer.getAndResetThrottleTime())
+        .addAll(keyedWorkStreamThrottleTimes)
+        .build();
+  }
+
+  /** Starts {@link GetWorkerMetadataStream}. */
+  @SuppressWarnings({
+    "FutureReturnValueIgnored", // ignoring Future returned from 
Executor.submit()
+    "nullness" // Uninitialized value of getWorkerMetadataStream is null.
+  })
+  private void startGetWorkerMetadataStream() {
+    // We only want to set and start this value once.
+    if (getWorkerMetadataStream == null) {
+      synchronized (this) {
+        if (getWorkerMetadataStream == null) {
+          getWorkerMetadataStream =
+              streamingEngineStreamFactory.createGetWorkerMetadataStream(
+                  dispatcherClient.getDispatcherStub(),
+                  getWorkerMetadataThrottleTimer,
+                  endpoints ->
+                      consumeWorkerMetadataExecutor.submit(
+                          () -> consumeWindmillWorkerEndpoints(endpoints)));
+        }
+      }
+    }
+  }
+
+  @SuppressWarnings("FutureReturnValueIgnored")
+  private void startBudgetRefreshThreads() {
+    budgetRefreshExecutor.scheduleAtFixedRate(
+        this::refreshBudget,
+        SCHEDULED_BUDGET_REFRESH_MILLIS,
+        SCHEDULED_BUDGET_REFRESH_MILLIS,
+        TimeUnit.MILLISECONDS);
+  }
+
+  private void refreshBudget() {

Review Comment:
   How about this @scwhittle?
   Changed to 1 thread that gets scheduled at a fixed interval with an initial delay.
   
   It polls a BlockingDeque for a triggered budget refresh without blocking; if a trigger happened after the most recent refresh, it redistributes the budget.
   
   Otherwise, if enough time has passed and the budget has not been redistributed, it redistributes the budget (this acts as the scheduled fallback).
   
   If this looks good to you, I will go ahead and modify the tests :)
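   
   Roughly, the scheduled task I have in mind looks like the sketch below. It is only a sketch of the idea: helper names such as `timestamp()` on `TimeStampedTriggeredBudgetRefresh` are illustrative, not the final API.
   
   ```java
   private void refreshBudget() {
     if (isBudgetRefreshPaused.get()) {
       // New worker metadata is currently being consumed; skip this round.
       return;
     }
     Instant lastRefresh = lastBudgetRefresh.get();
     // Non-blocking poll; returns null if no refresh has been triggered.
     TimeStampedTriggeredBudgetRefresh trigger = budgetRefreshTrigger.poll();
     boolean triggeredSinceLastRefresh =
         trigger != null && trigger.timestamp().isAfter(lastRefresh);
     boolean scheduledRefreshDue =
         Instant.now().isAfter(lastRefresh.plus(SCHEDULED_BUDGET_REFRESH_MILLIS));
     if (triggeredSinceLastRefresh || scheduledRefreshDue) {
       getWorkBudgetDistributor.distributeBudget(
           connections.get().windmillStreams().values(), totalGetWorkBudget);
       lastBudgetRefresh.set(Instant.now());
     }
   }
   ```
   
   That way triggered refreshes and the fixed-rate schedule share one code path, and checking against `lastBudgetRefresh` keeps a trigger that raced with a just-completed refresh from causing an immediate second redistribution.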



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
