scwhittle commented on code in PR #32775:
URL: https://github.com/apache/beam/pull/32775#discussion_r1804316412


##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java:
##########
@@ -280,29 +300,30 @@ private synchronized void 
consumeWindmillWorkerEndpoints(WindmillEndpoints newWi
   }
 
   /** Close the streams that are no longer valid asynchronously. */
-  @SuppressWarnings("FutureReturnValueIgnored")
-  private void closeStaleStreams(WindmillEndpoints newWindmillEndpoints) {
+  private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) {
     StreamingEngineBackends currentBackends = backends.get();
-    ImmutableMap<Endpoint, WindmillStreamSender> currentWindmillStreams =
-        currentBackends.windmillStreams();
-    currentWindmillStreams.entrySet().stream()
+    currentBackends.windmillStreams().entrySet().stream()
         .filter(
             connectionAndStream ->
                 
!newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey()))
         .forEach(
-            entry ->
-                CompletableFuture.runAsync(
-                    () -> closeStreamSender(entry.getKey(), entry.getValue()),
-                    windmillStreamManager));
+            entry -> {
+              CompletableFuture<Void> ignored =

Review Comment:
   Is this any different from just executing directly? If not, it seems simpler
to avoid the future:
   
   windmillStreamManager.execute(
     () -> closeStreamSender(entry.getKey(), entry.getValue()))



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java:
##########
@@ -445,10 +400,14 @@ private void waitForBudgetDistribution() throws 
InterruptedException {
       getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS);

Review Comment:
   Seems like this should return a value (the boolean result of await) so tests
can assert whether the wait actually completed rather than timing out?



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java:
##########
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcDirectGetWorkStreamTest {
+  private static final Windmill.JobHeader TEST_JOB_HEADER =
+      Windmill.JobHeader.newBuilder()
+          .setJobId("test_job")
+          .setWorkerId("test_worker")
+          .setProjectId("test_project")
+          .build();
+  private static final String FAKE_SERVER_NAME = "Fake server for 
GrpcDirectGetWorkStreamTest";
+  @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+  private final MutableHandlerRegistry serviceRegistry = new 
MutableHandlerRegistry();
+  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+  private ManagedChannel inProcessChannel;
+  private GrpcDirectGetWorkStream stream;
+
+  private static Windmill.StreamingGetWorkRequestExtension 
extension(GetWorkBudget budget) {
+    return Windmill.StreamingGetWorkRequestExtension.newBuilder()
+        .setMaxItems(budget.items())
+        .setMaxBytes(budget.bytes())
+        .build();
+  }
+
+  @Before
+  public void setUp() throws IOException {
+    Server server =
+        InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+            .fallbackHandlerRegistry(serviceRegistry)
+            .directExecutor()
+            .build()
+            .start();
+
+    inProcessChannel =
+        grpcCleanup.register(
+            
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+    grpcCleanup.register(server);
+    grpcCleanup.register(inProcessChannel);
+  }
+
+  @After
+  public void cleanUp() {
+    inProcessChannel.shutdownNow();
+    checkNotNull(stream).shutdown();
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(

Review Comment:
   This is a lot of variants of createGetWorkStream.
   
   Can we get rid of this one? It's easy for tests to call new ThrottleTimer() if
they don't care.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java:
##########
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcDirectGetWorkStreamTest {
+  private static final Windmill.JobHeader TEST_JOB_HEADER =
+      Windmill.JobHeader.newBuilder()
+          .setJobId("test_job")
+          .setWorkerId("test_worker")
+          .setProjectId("test_project")
+          .build();
+  private static final String FAKE_SERVER_NAME = "Fake server for 
GrpcDirectGetWorkStreamTest";
+  @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+  private final MutableHandlerRegistry serviceRegistry = new 
MutableHandlerRegistry();
+  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+  private ManagedChannel inProcessChannel;
+  private GrpcDirectGetWorkStream stream;
+
+  private static Windmill.StreamingGetWorkRequestExtension 
extension(GetWorkBudget budget) {
+    return Windmill.StreamingGetWorkRequestExtension.newBuilder()
+        .setMaxItems(budget.items())
+        .setMaxBytes(budget.bytes())
+        .build();
+  }
+
+  @Before
+  public void setUp() throws IOException {
+    Server server =
+        InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+            .fallbackHandlerRegistry(serviceRegistry)
+            .directExecutor()
+            .build()
+            .start();
+
+    inProcessChannel =
+        grpcCleanup.register(
+            
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+    grpcCleanup.register(server);
+    grpcCleanup.register(inProcessChannel);
+  }
+
+  @After
+  public void cleanUp() {
+    inProcessChannel.shutdownNow();
+    checkNotNull(stream).shutdown();
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) {
+    return createGetWorkStream(testStub, initialGetWorkBudget, new 
ThrottleTimer());
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      WorkItemScheduler workItemScheduler) {
+    return createGetWorkStream(
+        testStub, initialGetWorkBudget, new ThrottleTimer(), 
workItemScheduler);
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      ThrottleTimer throttleTimer) {
+    return createGetWorkStream(
+        testStub,
+        initialGetWorkBudget,
+        throttleTimer,
+        (workItem, watermarks, processingContext, getWorkStreamLatencies) -> 
{});
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      ThrottleTimer throttleTimer,
+      WorkItemScheduler workItemScheduler) {
+    serviceRegistry.addService(testStub);
+    return (GrpcDirectGetWorkStream)
+        GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+            .build()
+            .createDirectGetWorkStream(
+                WindmillConnection.builder()
+                    
.setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel))
+                    .build(),
+                Windmill.GetWorkRequest.newBuilder()
+                    .setClientId(TEST_JOB_HEADER.getClientId())
+                    .setJobId(TEST_JOB_HEADER.getJobId())
+                    .setProjectId(TEST_JOB_HEADER.getProjectId())
+                    .setWorkerId(TEST_JOB_HEADER.getWorkerId())
+                    .setMaxItems(initialGetWorkBudget.items())
+                    .setMaxBytes(initialGetWorkBudget.bytes())
+                    .build(),
+                throttleTimer,
+                mock(HeartbeatSender.class),
+                mock(GetDataClient.class),
+                mock(WorkCommitter.class),
+                workItemScheduler);
+  }
+
+  private Windmill.StreamingGetWorkResponseChunk 
createResponse(Windmill.WorkItem workItem) {
+    return Windmill.StreamingGetWorkResponseChunk.newBuilder()
+        .setStreamId(1L)
+        .setComputationMetadata(
+            Windmill.ComputationWorkItemMetadata.newBuilder()
+                .setComputationId("compId")
+                .setInputDataWatermark(1L)
+                .setDependentRealtimeInputWatermark(1L)
+                .build())
+        .setSerializedWorkItem(workItem.toByteString())
+        .setRemainingBytesForWorkItem(0)
+        .build();
+  }
+
+  @Test
+  public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget()
+      throws InterruptedException {
+    int expectedRequests = 2;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new 
TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new 
GetWorkStreamTestStub(requestObserver);
+    stream = createGetWorkStream(testStub, GetWorkBudget.noBudget());
+    GetWorkBudget newBudget = 
GetWorkBudget.builder().setItems(10).setBytes(10).build();
+    stream.setBudget(newBudget);
+
+    waitForRequests.await(5, TimeUnit.SECONDS);
+
+    // Header and extension.
+    assertThat(requestObserver.sent()).hasSize(expectedRequests);
+    assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension())

Review Comment:
   Would be better to verify the header as well.
   Ditto for the other tests.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java:
##########
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcDirectGetWorkStreamTest {
+  private static final Windmill.JobHeader TEST_JOB_HEADER =
+      Windmill.JobHeader.newBuilder()
+          .setJobId("test_job")
+          .setWorkerId("test_worker")
+          .setProjectId("test_project")
+          .build();
+  private static final String FAKE_SERVER_NAME = "Fake server for 
GrpcDirectGetWorkStreamTest";
+  @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+  private final MutableHandlerRegistry serviceRegistry = new 
MutableHandlerRegistry();
+  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+  private ManagedChannel inProcessChannel;
+  private GrpcDirectGetWorkStream stream;
+
+  private static Windmill.StreamingGetWorkRequestExtension 
extension(GetWorkBudget budget) {
+    return Windmill.StreamingGetWorkRequestExtension.newBuilder()
+        .setMaxItems(budget.items())
+        .setMaxBytes(budget.bytes())
+        .build();
+  }
+
+  @Before
+  public void setUp() throws IOException {
+    Server server =
+        InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+            .fallbackHandlerRegistry(serviceRegistry)
+            .directExecutor()
+            .build()
+            .start();
+
+    inProcessChannel =
+        grpcCleanup.register(
+            
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+    grpcCleanup.register(server);
+    grpcCleanup.register(inProcessChannel);
+  }
+
+  @After
+  public void cleanUp() {
+    inProcessChannel.shutdownNow();
+    checkNotNull(stream).shutdown();
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) {
+    return createGetWorkStream(testStub, initialGetWorkBudget, new 
ThrottleTimer());
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      WorkItemScheduler workItemScheduler) {
+    return createGetWorkStream(
+        testStub, initialGetWorkBudget, new ThrottleTimer(), 
workItemScheduler);
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      ThrottleTimer throttleTimer) {
+    return createGetWorkStream(
+        testStub,
+        initialGetWorkBudget,
+        throttleTimer,
+        (workItem, watermarks, processingContext, getWorkStreamLatencies) -> 
{});
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      ThrottleTimer throttleTimer,
+      WorkItemScheduler workItemScheduler) {
+    serviceRegistry.addService(testStub);
+    return (GrpcDirectGetWorkStream)
+        GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+            .build()
+            .createDirectGetWorkStream(
+                WindmillConnection.builder()
+                    
.setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel))
+                    .build(),
+                Windmill.GetWorkRequest.newBuilder()
+                    .setClientId(TEST_JOB_HEADER.getClientId())
+                    .setJobId(TEST_JOB_HEADER.getJobId())
+                    .setProjectId(TEST_JOB_HEADER.getProjectId())
+                    .setWorkerId(TEST_JOB_HEADER.getWorkerId())
+                    .setMaxItems(initialGetWorkBudget.items())
+                    .setMaxBytes(initialGetWorkBudget.bytes())
+                    .build(),
+                throttleTimer,
+                mock(HeartbeatSender.class),
+                mock(GetDataClient.class),
+                mock(WorkCommitter.class),
+                workItemScheduler);
+  }
+
+  private Windmill.StreamingGetWorkResponseChunk 
createResponse(Windmill.WorkItem workItem) {
+    return Windmill.StreamingGetWorkResponseChunk.newBuilder()
+        .setStreamId(1L)
+        .setComputationMetadata(
+            Windmill.ComputationWorkItemMetadata.newBuilder()
+                .setComputationId("compId")
+                .setInputDataWatermark(1L)
+                .setDependentRealtimeInputWatermark(1L)
+                .build())
+        .setSerializedWorkItem(workItem.toByteString())
+        .setRemainingBytesForWorkItem(0)
+        .build();
+  }
+
+  @Test
+  public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget()
+      throws InterruptedException {
+    int expectedRequests = 2;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new 
TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new 
GetWorkStreamTestStub(requestObserver);
+    stream = createGetWorkStream(testStub, GetWorkBudget.noBudget());
+    GetWorkBudget newBudget = 
GetWorkBudget.builder().setItems(10).setBytes(10).build();
+    stream.setBudget(newBudget);
+
+    waitForRequests.await(5, TimeUnit.SECONDS);
+
+    // Header and extension.
+    assertThat(requestObserver.sent()).hasSize(expectedRequests);
+    assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension())
+        .isEqualTo(extension(newBudget));
+  }
+
+  @Test
+  public void testSetBudget_computesAndSendsCorrectExtension_existingBudget()
+      throws InterruptedException {
+    int expectedRequests = 2;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new 
TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new 
GetWorkStreamTestStub(requestObserver);
+    GetWorkBudget initialBudget = 
GetWorkBudget.builder().setItems(10).setBytes(10).build();
+    stream = createGetWorkStream(testStub, initialBudget);
+    GetWorkBudget newBudget = 
GetWorkBudget.builder().setItems(100).setBytes(100).build();
+    stream.setBudget(newBudget);
+    GetWorkBudget diff = newBudget.subtract(initialBudget);
+
+    waitForRequests.await(5, TimeUnit.SECONDS);
+
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Header and extension.
+    assertThat(requests).hasSize(expectedRequests);
+    
assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff));
+  }
+
+  @Test
+  public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh()
+      throws InterruptedException {
+    int expectedRequests = 1;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new 
TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new 
GetWorkStreamTestStub(requestObserver);
+    stream =
+        createGetWorkStream(
+            testStub,
+            
GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build());
+    
stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build());
+
+    waitForRequests.await(5, TimeUnit.SECONDS);
+
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Assert that the extension was never sent, only the header.
+    assertThat(requests).hasSize(expectedRequests);
+    assertThat(Iterables.getOnlyElement(requests).getRequest())

Review Comment:
   This may never succeed, since getRequest() will return the default
object? Or perhaps it throws?
   
   In either case, you could just check the type of the oneof via
assertTrue(...hasRequest()) or something similar. But it could be better to
have a matcher that you can use here and in other tests that lets you verify
the budget within the initial request.
   
   Ditto for the other instanceof checks.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java:
##########
@@ -40,147 +38,39 @@ public class EvenGetWorkBudgetDistributorTest {
   @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
   @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
 
-  private static GetWorkBudgetDistributor 
createBudgetDistributor(GetWorkBudget activeWorkBudget) {
-    return GetWorkBudgetDistributors.distributeEvenly(() -> activeWorkBudget);
-  }
+  private static GetWorkBudgetSpender createGetWorkBudgetOwner() {
+    // Lambdas are final and cannot be spied.
+    return spy(
+        new GetWorkBudgetSpender() {
 
-  private static GetWorkBudgetDistributor createBudgetDistributor(long 
activeWorkItemsAndBytes) {
-    return createBudgetDistributor(
-        GetWorkBudget.builder()
-            .setItems(activeWorkItemsAndBytes)
-            .setBytes(activeWorkItemsAndBytes)
-            .build());
+          @Override
+          public void setBudget(long items, long bytes) {}
+        });
   }
 
   @Test
   public void testDistributeBudget_doesNothingWhenPassedInStreamsEmpty() {
-    createBudgetDistributor(1L)
+    GetWorkBudgetDistributors.distributeEvenly()
         .distributeBudget(
             ImmutableList.of(), 
GetWorkBudget.builder().setItems(10L).setBytes(10L).build());
   }
 
   @Test
   public void testDistributeBudget_doesNothingWithNoBudget() {
-    GetWorkBudgetSpender getWorkBudgetSpender =
-        
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget()));
-    createBudgetDistributor(1L)
+    GetWorkBudgetSpender getWorkBudgetSpender = 
spy(createGetWorkBudgetOwner());
+    GetWorkBudgetDistributors.distributeEvenly()
         .distributeBudget(ImmutableList.of(getWorkBudgetSpender), 
GetWorkBudget.noBudget());
     verifyNoInteractions(getWorkBudgetSpender);
   }
 
-  @Test
-  public void 
testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighNoActiveWork()
 {
-    GetWorkBudgetSpender getWorkBudgetSpender =
-        spy(
-            createGetWorkBudgetOwnerWithRemainingBudgetOf(
-                GetWorkBudget.builder().setItems(10L).setBytes(10L).build()));
-    createBudgetDistributor(0L)
-        .distributeBudget(
-            ImmutableList.of(getWorkBudgetSpender),
-            GetWorkBudget.builder().setItems(10L).setBytes(10L).build());
-
-    verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong());
-  }
-
-  @Test
-  public void
-      
testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighWithActiveWork()
 {
-    GetWorkBudgetSpender getWorkBudgetSpender =
-        spy(
-            createGetWorkBudgetOwnerWithRemainingBudgetOf(
-                GetWorkBudget.builder().setItems(5L).setBytes(5L).build()));
-    createBudgetDistributor(10L)
-        .distributeBudget(
-            ImmutableList.of(getWorkBudgetSpender),
-            GetWorkBudget.builder().setItems(20L).setBytes(20L).build());
-
-    verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong());
-  }
-
-  @Test
-  public void
-      
testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithNoActiveWork()
 {
-    GetWorkBudget streamRemainingBudget =
-        GetWorkBudget.builder().setItems(1L).setBytes(10L).build();
-    GetWorkBudget totalGetWorkBudget = 
GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
-    GetWorkBudgetSpender getWorkBudgetSpender =
-        
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget));
-    createBudgetDistributor(0L)
-        .distributeBudget(ImmutableList.of(getWorkBudgetSpender), 
totalGetWorkBudget);
-
-    verify(getWorkBudgetSpender, times(1))
-        .adjustBudget(
-            eq(totalGetWorkBudget.items() - streamRemainingBudget.items()),
-            eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes()));
-  }
-
-  @Test
-  public void
-      
testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithActiveWork()
 {
-    GetWorkBudget streamRemainingBudget =
-        GetWorkBudget.builder().setItems(1L).setBytes(10L).build();
-    GetWorkBudget totalGetWorkBudget = 
GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
-    long activeWorkItemsAndBytes = 2L;
-    GetWorkBudgetSpender getWorkBudgetSpender =
-        
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget));
-    createBudgetDistributor(activeWorkItemsAndBytes)
-        .distributeBudget(ImmutableList.of(getWorkBudgetSpender), 
totalGetWorkBudget);
-
-    verify(getWorkBudgetSpender, times(1))
-        .adjustBudget(
-            eq(
-                totalGetWorkBudget.items()
-                    - streamRemainingBudget.items()
-                    - activeWorkItemsAndBytes),
-            eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes()));
-  }
-
-  @Test
-  public void 
testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowNoActiveWork()
 {
-    GetWorkBudget streamRemainingBudget =
-        GetWorkBudget.builder().setItems(10L).setBytes(1L).build();
-    GetWorkBudget totalGetWorkBudget = 
GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
-    GetWorkBudgetSpender getWorkBudgetSpender =
-        
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget));
-    createBudgetDistributor(0L)
-        .distributeBudget(ImmutableList.of(getWorkBudgetSpender), 
totalGetWorkBudget);
-
-    verify(getWorkBudgetSpender, times(1))
-        .adjustBudget(
-            eq(totalGetWorkBudget.items() - streamRemainingBudget.items()),
-            eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes()));
-  }
-
-  @Test
-  public void
-      
testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowWithActiveWork()
 {
-    GetWorkBudget streamRemainingBudget =
-        GetWorkBudget.builder().setItems(10L).setBytes(1L).build();
-    GetWorkBudget totalGetWorkBudget = 
GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
-    long activeWorkItemsAndBytes = 2L;
-
-    GetWorkBudgetSpender getWorkBudgetSpender =
-        
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget));
-    createBudgetDistributor(activeWorkItemsAndBytes)
-        .distributeBudget(ImmutableList.of(getWorkBudgetSpender), 
totalGetWorkBudget);
-
-    verify(getWorkBudgetSpender, times(1))
-        .adjustBudget(
-            eq(totalGetWorkBudget.items() - streamRemainingBudget.items()),
-            eq(
-                totalGetWorkBudget.bytes()
-                    - streamRemainingBudget.bytes()
-                    - activeWorkItemsAndBytes));
-  }
-
   @Test
   public void testDistributeBudget_distributesBudgetEvenlyIfPossible() {
     long totalItemsAndBytes = 10L;

Review Comment:
   Would be better to use different values for items and bytes to confirm the
distributor doesn't mix them up internally.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java:
##########
@@ -61,38 +43,15 @@ public <T extends GetWorkBudgetSpender> void 
distributeBudget(
       return;
     }
 
-    Map<T, GetWorkBudget> desiredBudgets = computeDesiredBudgets(budgetOwners, 
getWorkBudget);
-
-    for (Entry<T, GetWorkBudget> streamAndDesiredBudget : 
desiredBudgets.entrySet()) {
-      GetWorkBudgetSpender getWorkBudgetSpender = 
streamAndDesiredBudget.getKey();
-      GetWorkBudget desired = streamAndDesiredBudget.getValue();
-      GetWorkBudget remaining = getWorkBudgetSpender.remainingBudget();
-      if (isBelowFiftyPercentOfTarget(remaining, desired)) {
-        GetWorkBudget adjustment = desired.subtract(remaining);
-        getWorkBudgetSpender.adjustBudget(adjustment);
-      }
-    }
+    GetWorkBudget budgetPerStream = computeDesiredBudgets(budgetSpenders, 
getWorkBudget);
+    budgetSpenders.forEach(getWorkBudgetSpender -> 
getWorkBudgetSpender.setBudget(budgetPerStream));
   }
 
-  private <T extends GetWorkBudgetSpender> ImmutableMap<T, GetWorkBudget> 
computeDesiredBudgets(
+  private <T extends GetWorkBudgetSpender> GetWorkBudget computeDesiredBudgets(

Review Comment:
   nit: maybe name it computeDesiredPerStreamBudget, or just inline it?
   
   The plural "budgets" makes it sound like it computes multiple budgets.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java:
##########
@@ -267,25 +277,100 @@ protected void startThrottleTimer() {
   }
 
   @Override
-  public void adjustBudget(long itemsDelta, long bytesDelta) {
-    GetWorkBudget adjustment =
-        nextBudgetAdjustment
-            // Get the current value, and reset the nextBudgetAdjustment. This 
will be set again
-            // when adjustBudget is called.
-            .getAndUpdate(unused -> GetWorkBudget.noBudget())
-            .apply(itemsDelta, bytesDelta);
-    sendRequestExtension(adjustment);
+  public void setBudget(long newItems, long newBytes) {
+    GetWorkBudget currentMaxGetWorkBudget =
+        maxGetWorkBudget.updateAndGet(
+            ignored -> 
GetWorkBudget.builder().setItems(newItems).setBytes(newBytes).build());
+    GetWorkBudget extension = 
budgetTracker.computeBudgetExtension(currentMaxGetWorkBudget);
+    maybeSendRequestExtension(extension);
   }
 
-  @Override
-  public GetWorkBudget remainingBudget() {
-    // Snapshot the current budgets.
-    GetWorkBudget currentPendingResponseBudget = pendingResponseBudget.get();
-    GetWorkBudget currentNextBudgetAdjustment = nextBudgetAdjustment.get();
-    GetWorkBudget currentInflightBudget = inFlightBudget.get();
-
-    return currentPendingResponseBudget
-        .apply(currentNextBudgetAdjustment)
-        .apply(currentInflightBudget);
+  /**
+   * Tracks sent and received GetWorkBudget and uses this information to 
generate request
+   * extensions.
+   */
+  @AutoValue
+  abstract static class GetWorkBudgetTracker {
+
+    private static GetWorkBudgetTracker create() {
+      return new AutoValue_GrpcDirectGetWorkStream_GetWorkBudgetTracker(
+          new AtomicLong(), new AtomicLong(), new AtomicLong(), new 
AtomicLong());
+    }
+
+    abstract AtomicLong itemsRequested();

Review Comment:
   Can the members be changed to plain longs/objects? The accessors would then 
just need to be synchronized as well.
   
   It seems like this would be simpler without AutoValue, since we don't need 
the accessors either.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java:
##########
@@ -231,7 +236,7 @@ public void appendSpecificHtml(PrintWriter writer) {
             + "total budget received: %s,"
             + "last sent request: %s. ",
         workItemAssemblers.size(),
-        maxGetWorkBudget.get(),
+        budgetTracker.maxGetWorkBudget().get(),

Review Comment:
   Could move the HTML generation into GetWorkBudgetTracker so we don't need 
all the accessors. If we change how the tracker works in the future, we might 
want to show more there too.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java:
##########
@@ -41,6 +41,14 @@
 public abstract class WindmillEndpoints {
   private static final Logger LOG = 
LoggerFactory.getLogger(WindmillEndpoints.class);
 
+  public static WindmillEndpoints none() {
+    return WindmillEndpoints.builder()
+        .setVersion(Long.MAX_VALUE)

Review Comment:
   Long.MIN_VALUE seems safer. Otherwise, if none() were somehow observed, the 
logic that ensures the version is increasing means we'd never process another 
endpoint set.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java:
##########
@@ -280,29 +300,30 @@ private synchronized void 
consumeWindmillWorkerEndpoints(WindmillEndpoints newWi
   }
 
   /** Close the streams that are no longer valid asynchronously. */
-  @SuppressWarnings("FutureReturnValueIgnored")
-  private void closeStaleStreams(WindmillEndpoints newWindmillEndpoints) {
+  private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) {
     StreamingEngineBackends currentBackends = backends.get();
-    ImmutableMap<Endpoint, WindmillStreamSender> currentWindmillStreams =
-        currentBackends.windmillStreams();
-    currentWindmillStreams.entrySet().stream()
+    currentBackends.windmillStreams().entrySet().stream()
         .filter(
             connectionAndStream ->
                 
!newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey()))
         .forEach(
-            entry ->
-                CompletableFuture.runAsync(
-                    () -> closeStreamSender(entry.getKey(), entry.getValue()),
-                    windmillStreamManager));
+            entry -> {
+              CompletableFuture<Void> ignored =
+                  CompletableFuture.runAsync(
+                      () -> closeStreamSender(entry.getKey(), 
entry.getValue()),
+                      windmillStreamManager);
+            });
 
     Set<Endpoint> newGlobalDataEndpoints =
         new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values());
     currentBackends.globalDataStreams().values().stream()
         .filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint()))
         .forEach(
-            sender ->
-                CompletableFuture.runAsync(
-                    () -> closeStreamSender(sender.endpoint(), sender), 
windmillStreamManager));
+            sender -> {
+              CompletableFuture<Void> ignored =

Review Comment:
   ditto



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java:
##########
@@ -192,17 +82,17 @@ public void 
testDistributeBudget_distributesBudgetEvenlyIfPossible() {
     streams.forEach(

Review Comment:
   Just skip the math in the test and inline the expected values?
   The math just duplicates what the impl does; if there is a bug in the impl, 
hard-coding the values at least provides a sanity check.
   
   Ditto below.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java:
##########
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcDirectGetWorkStreamTest {
+  private static final Windmill.JobHeader TEST_JOB_HEADER =
+      Windmill.JobHeader.newBuilder()
+          .setJobId("test_job")
+          .setWorkerId("test_worker")
+          .setProjectId("test_project")
+          .build();
+  private static final String FAKE_SERVER_NAME = "Fake server for 
GrpcDirectGetWorkStreamTest";
+  @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+  private final MutableHandlerRegistry serviceRegistry = new 
MutableHandlerRegistry();
+  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+  private ManagedChannel inProcessChannel;
+  private GrpcDirectGetWorkStream stream;
+
+  private static Windmill.StreamingGetWorkRequestExtension 
extension(GetWorkBudget budget) {
+    return Windmill.StreamingGetWorkRequestExtension.newBuilder()
+        .setMaxItems(budget.items())
+        .setMaxBytes(budget.bytes())
+        .build();
+  }
+
+  @Before
+  public void setUp() throws IOException {
+    Server server =
+        InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+            .fallbackHandlerRegistry(serviceRegistry)
+            .directExecutor()
+            .build()
+            .start();
+
+    inProcessChannel =
+        grpcCleanup.register(
+            
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+    grpcCleanup.register(server);
+    grpcCleanup.register(inProcessChannel);
+  }
+
+  @After
+  public void cleanUp() {
+    inProcessChannel.shutdownNow();
+    checkNotNull(stream).shutdown();
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) {
+    return createGetWorkStream(testStub, initialGetWorkBudget, new 
ThrottleTimer());
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      WorkItemScheduler workItemScheduler) {
+    return createGetWorkStream(
+        testStub, initialGetWorkBudget, new ThrottleTimer(), 
workItemScheduler);
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      ThrottleTimer throttleTimer) {
+    return createGetWorkStream(
+        testStub,
+        initialGetWorkBudget,
+        throttleTimer,
+        (workItem, watermarks, processingContext, getWorkStreamLatencies) -> 
{});
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      ThrottleTimer throttleTimer,
+      WorkItemScheduler workItemScheduler) {
+    serviceRegistry.addService(testStub);
+    return (GrpcDirectGetWorkStream)
+        GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+            .build()
+            .createDirectGetWorkStream(
+                WindmillConnection.builder()
+                    
.setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel))
+                    .build(),
+                Windmill.GetWorkRequest.newBuilder()
+                    .setClientId(TEST_JOB_HEADER.getClientId())
+                    .setJobId(TEST_JOB_HEADER.getJobId())
+                    .setProjectId(TEST_JOB_HEADER.getProjectId())
+                    .setWorkerId(TEST_JOB_HEADER.getWorkerId())
+                    .setMaxItems(initialGetWorkBudget.items())
+                    .setMaxBytes(initialGetWorkBudget.bytes())
+                    .build(),
+                throttleTimer,
+                mock(HeartbeatSender.class),
+                mock(GetDataClient.class),
+                mock(WorkCommitter.class),
+                workItemScheduler);
+  }
+
+  private Windmill.StreamingGetWorkResponseChunk 
createResponse(Windmill.WorkItem workItem) {
+    return Windmill.StreamingGetWorkResponseChunk.newBuilder()
+        .setStreamId(1L)
+        .setComputationMetadata(
+            Windmill.ComputationWorkItemMetadata.newBuilder()
+                .setComputationId("compId")
+                .setInputDataWatermark(1L)
+                .setDependentRealtimeInputWatermark(1L)
+                .build())
+        .setSerializedWorkItem(workItem.toByteString())
+        .setRemainingBytesForWorkItem(0)
+        .build();
+  }
+
+  @Test
+  public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget()
+      throws InterruptedException {
+    int expectedRequests = 2;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new 
TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new 
GetWorkStreamTestStub(requestObserver);
+    stream = createGetWorkStream(testStub, GetWorkBudget.noBudget());
+    GetWorkBudget newBudget = 
GetWorkBudget.builder().setItems(10).setBytes(10).build();
+    stream.setBudget(newBudget);
+
+    waitForRequests.await(5, TimeUnit.SECONDS);

Review Comment:
   Use assertTrue on awaits that should succeed. If an await times out, the 
remaining assertions will fail too, which may be confusing.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java:
##########
@@ -40,147 +38,39 @@ public class EvenGetWorkBudgetDistributorTest {
   @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
   @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
 
-  private static GetWorkBudgetDistributor 
createBudgetDistributor(GetWorkBudget activeWorkBudget) {
-    return GetWorkBudgetDistributors.distributeEvenly(() -> activeWorkBudget);
-  }
+  private static GetWorkBudgetSpender createGetWorkBudgetOwner() {
+    // Lambdas are final and cannot be spied.
+    return spy(
+        new GetWorkBudgetSpender() {
 
-  private static GetWorkBudgetDistributor createBudgetDistributor(long 
activeWorkItemsAndBytes) {
-    return createBudgetDistributor(
-        GetWorkBudget.builder()
-            .setItems(activeWorkItemsAndBytes)
-            .setBytes(activeWorkItemsAndBytes)
-            .build());
+          @Override
+          public void setBudget(long items, long bytes) {}
+        });
   }
 
   @Test
   public void testDistributeBudget_doesNothingWhenPassedInStreamsEmpty() {
-    createBudgetDistributor(1L)
+    GetWorkBudgetDistributors.distributeEvenly()
         .distributeBudget(
             ImmutableList.of(), 
GetWorkBudget.builder().setItems(10L).setBytes(10L).build());
   }
 
   @Test
   public void testDistributeBudget_doesNothingWithNoBudget() {
-    GetWorkBudgetSpender getWorkBudgetSpender =
-        
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget()));
-    createBudgetDistributor(1L)
+    GetWorkBudgetSpender getWorkBudgetSpender = 
spy(createGetWorkBudgetOwner());

Review Comment:
   Remove the spy here? It's already done in the helper method.
   
   Ditto below.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java:
##########
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcDirectGetWorkStreamTest {
+  private static final Windmill.JobHeader TEST_JOB_HEADER =
+      Windmill.JobHeader.newBuilder()
+          .setJobId("test_job")
+          .setWorkerId("test_worker")
+          .setProjectId("test_project")
+          .build();
+  private static final String FAKE_SERVER_NAME = "Fake server for 
GrpcDirectGetWorkStreamTest";
+  @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+  private final MutableHandlerRegistry serviceRegistry = new 
MutableHandlerRegistry();
+  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+  private ManagedChannel inProcessChannel;
+  private GrpcDirectGetWorkStream stream;
+
+  private static Windmill.StreamingGetWorkRequestExtension 
extension(GetWorkBudget budget) {
+    return Windmill.StreamingGetWorkRequestExtension.newBuilder()
+        .setMaxItems(budget.items())
+        .setMaxBytes(budget.bytes())
+        .build();
+  }
+
+  @Before
+  public void setUp() throws IOException {
+    Server server =
+        InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+            .fallbackHandlerRegistry(serviceRegistry)
+            .directExecutor()
+            .build()
+            .start();
+
+    inProcessChannel =
+        grpcCleanup.register(
+            
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+    grpcCleanup.register(server);
+    grpcCleanup.register(inProcessChannel);
+  }
+
+  @After
+  public void cleanUp() {
+    inProcessChannel.shutdownNow();
+    checkNotNull(stream).shutdown();
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) {
+    return createGetWorkStream(testStub, initialGetWorkBudget, new 
ThrottleTimer());
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(

Review Comment:
   Ditto; tests can just call the full 4-parameter overload and create a 
ThrottleTimer themselves.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java:
##########
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcDirectGetWorkStreamTest {
+  private static final Windmill.JobHeader TEST_JOB_HEADER =
+      Windmill.JobHeader.newBuilder()
+          .setJobId("test_job")
+          .setWorkerId("test_worker")
+          .setProjectId("test_project")
+          .build();
+  private static final String FAKE_SERVER_NAME = "Fake server for 
GrpcDirectGetWorkStreamTest";
+  @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+  private final MutableHandlerRegistry serviceRegistry = new 
MutableHandlerRegistry();
+  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+  private ManagedChannel inProcessChannel;
+  private GrpcDirectGetWorkStream stream;
+
+  private static Windmill.StreamingGetWorkRequestExtension 
extension(GetWorkBudget budget) {
+    return Windmill.StreamingGetWorkRequestExtension.newBuilder()
+        .setMaxItems(budget.items())
+        .setMaxBytes(budget.bytes())
+        .build();
+  }
+
+  @Before
+  public void setUp() throws IOException {
+    Server server =
+        InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+            .fallbackHandlerRegistry(serviceRegistry)
+            .directExecutor()
+            .build()
+            .start();
+
+    inProcessChannel =
+        grpcCleanup.register(
+            
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+    grpcCleanup.register(server);
+    grpcCleanup.register(inProcessChannel);
+  }
+
+  @After
+  public void cleanUp() {
+    inProcessChannel.shutdownNow();
+    checkNotNull(stream).shutdown();
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) {
+    return createGetWorkStream(testStub, initialGetWorkBudget, new 
ThrottleTimer());
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      WorkItemScheduler workItemScheduler) {
+    return createGetWorkStream(
+        testStub, initialGetWorkBudget, new ThrottleTimer(), 
workItemScheduler);
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      ThrottleTimer throttleTimer) {
+    return createGetWorkStream(
+        testStub,
+        initialGetWorkBudget,
+        throttleTimer,
+        (workItem, watermarks, processingContext, getWorkStreamLatencies) -> 
{});
+  }
+
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      ThrottleTimer throttleTimer,
+      WorkItemScheduler workItemScheduler) {
+    serviceRegistry.addService(testStub);
+    return (GrpcDirectGetWorkStream)
+        GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+            .build()
+            .createDirectGetWorkStream(
+                WindmillConnection.builder()
+                    
.setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel))
+                    .build(),
+                Windmill.GetWorkRequest.newBuilder()
+                    .setClientId(TEST_JOB_HEADER.getClientId())
+                    .setJobId(TEST_JOB_HEADER.getJobId())
+                    .setProjectId(TEST_JOB_HEADER.getProjectId())
+                    .setWorkerId(TEST_JOB_HEADER.getWorkerId())
+                    .setMaxItems(initialGetWorkBudget.items())
+                    .setMaxBytes(initialGetWorkBudget.bytes())
+                    .build(),
+                throttleTimer,
+                mock(HeartbeatSender.class),
+                mock(GetDataClient.class),
+                mock(WorkCommitter.class),
+                workItemScheduler);
+  }
+
+  private Windmill.StreamingGetWorkResponseChunk 
createResponse(Windmill.WorkItem workItem) {
+    return Windmill.StreamingGetWorkResponseChunk.newBuilder()
+        .setStreamId(1L)
+        .setComputationMetadata(
+            Windmill.ComputationWorkItemMetadata.newBuilder()
+                .setComputationId("compId")
+                .setInputDataWatermark(1L)
+                .setDependentRealtimeInputWatermark(1L)
+                .build())
+        .setSerializedWorkItem(workItem.toByteString())
+        .setRemainingBytesForWorkItem(0)
+        .build();
+  }
+
+  @Test
+  public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget()
+      throws InterruptedException {
+    int expectedRequests = 2;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new 
TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new 
GetWorkStreamTestStub(requestObserver);
+    stream = createGetWorkStream(testStub, GetWorkBudget.noBudget());
+    GetWorkBudget newBudget = 
GetWorkBudget.builder().setItems(10).setBytes(10).build();
+    stream.setBudget(newBudget);
+
+    waitForRequests.await(5, TimeUnit.SECONDS);
+
+    // Header and extension.
+    assertThat(requestObserver.sent()).hasSize(expectedRequests);
+    assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension())
+        .isEqualTo(extension(newBudget));
+  }
+
+  @Test
+  public void testSetBudget_computesAndSendsCorrectExtension_existingBudget()
+      throws InterruptedException {
+    int expectedRequests = 2;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build();
+    stream = createGetWorkStream(testStub, initialBudget);
+    GetWorkBudget newBudget = GetWorkBudget.builder().setItems(100).setBytes(100).build();
+    stream.setBudget(newBudget);
+    GetWorkBudget diff = newBudget.subtract(initialBudget);
+
+    // await() returns false on timeout; assert on it so a timeout fails the test here.
+    assertThat(waitForRequests.await(5, TimeUnit.SECONDS)).isTrue();
+
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Header and extension; the extension is expected to carry only newBudget - initialBudget.
+    assertThat(requests).hasSize(expectedRequests);
+    assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff));
+  }
+
+  @Test
+  public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh()
+      throws InterruptedException {
+    int expectedRequests = 1;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    stream =
+        createGetWorkStream(
+            testStub,
+            GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build());
+    stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build());
+
+    // Only the header is awaited; fail here if even that never arrived.
+    assertThat(waitForRequests.await(5, TimeUnit.SECONDS)).isTrue();
+
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Assert that the extension was never sent, only the header. Proto getters never return null,
+    // so getRequest() is always a GetWorkRequest; check field presence instead of the type.
+    assertThat(requests).hasSize(expectedRequests);
+    Windmill.StreamingGetWorkRequest onlyRequest = Iterables.getOnlyElement(requests);
+    assertThat(onlyRequest.hasRequest()).isTrue();
+    assertThat(onlyRequest.hasRequestExtension()).isFalse();
+  }
+
+  @Test
+  public void testSetBudget_doesNothingIfStreamShutdown() throws InterruptedException {
+    int expectedRequests = 1;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    stream = createGetWorkStream(testStub, GetWorkBudget.noBudget());
+    stream.shutdown();
+    stream.setBudget(
+        GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build());
+
+    // Only the header is awaited; fail here if even that never arrived.
+    assertThat(waitForRequests.await(5, TimeUnit.SECONDS)).isTrue();
+
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Assert that the extension was never sent, only the header. Proto getters never return null,
+    // so getRequest() is always a GetWorkRequest; check field presence instead of the type.
+    assertThat(requests).hasSize(expectedRequests);
+    Windmill.StreamingGetWorkRequest onlyRequest = Iterables.getOnlyElement(requests);
+    assertThat(onlyRequest.hasRequest()).isTrue();
+    assertThat(onlyRequest.hasRequestExtension()).isFalse();
+  }
+
+  @Test
+  public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws InterruptedException {
+    int expectedRequests = 2;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(1).setBytes(100).build();
+    Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>();
+    stream =
+        createGetWorkStream(
+            testStub,
+            initialBudget,
+            (work, watermarks, processingContext, getWorkStreamLatencies) ->
+                scheduledWorkItems.add(work));
+    Windmill.WorkItem workItem =
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.copyFromUtf8("somewhat_long_key"))
+            .setWorkToken(1L)
+            .setShardingKey(1L)
+            .setCacheToken(1L)
+            .build();
+
+    testStub.injectResponse(createResponse(workItem));
+
+    // await() returns false on timeout; assert on it so a timeout fails the test here.
+    assertThat(waitForRequests.await(5, TimeUnit.SECONDS)).isTrue();
+
+    assertThat(scheduledWorkItems).containsExactly(workItem);
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Header and extension. The expected extension is one item and the consumed work item's
+    // serialized size in bytes.
+    assertThat(requests).hasSize(expectedRequests);
+    assertThat(Iterables.getLast(requests).getRequestExtension())
+        .isEqualTo(
+            extension(
+                GetWorkBudget.builder()
+                    .setItems(1)
+                    .setBytes(workItem.getSerializedSize())
+                    .build()));
+  }
+
+  @Test
+  public void testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh()
+      throws InterruptedException {
+    int expectedRequests = 1;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>();
+    stream =
+        createGetWorkStream(
+            testStub,
+            GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(),
+            (work, watermarks, processingContext, getWorkStreamLatencies) ->
+                scheduledWorkItems.add(work));
+    Windmill.WorkItem workItem =
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.copyFromUtf8("somewhat_long_key"))
+            .setWorkToken(1L)
+            .setShardingKey(1L)
+            .setCacheToken(1L)
+            .build();
+
+    testStub.injectResponse(createResponse(workItem));
+
+    // Only the header is awaited; fail here if even that never arrived.
+    assertThat(waitForRequests.await(5, TimeUnit.SECONDS)).isTrue();
+
+    assertThat(scheduledWorkItems).containsExactly(workItem);
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+
+    // Assert that the extension was never sent, only the header. Proto getters never return null,
+    // so getRequest() is always a GetWorkRequest; check field presence instead of the type.
+    assertThat(requests).hasSize(expectedRequests);
+    Windmill.StreamingGetWorkRequest onlyRequest = Iterables.getOnlyElement(requests);
+    assertThat(onlyRequest.hasRequest()).isTrue();
+    assertThat(onlyRequest.hasRequestExtension()).isFalse();
+  }
+
+  @Test
+  public void testOnResponse_stopsThrottling() {
+    ThrottleTimer throttleTimer = new ThrottleTimer();
+    TestGetWorkRequestObserver requestObserver =
+        new TestGetWorkRequestObserver(new CountDownLatch(1));
+    GetWorkStreamTestStub testStub = new 
GetWorkStreamTestStub(requestObserver);
+    stream = createGetWorkStream(testStub, GetWorkBudget.noBudget(), 
throttleTimer);
+    stream.startThrottleTimer();

Review Comment:
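   Consider adding `assertTrue(throttleTimer.throttled());` here, before the response is injected, so the test verifies that throttling actually started.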



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java:
##########
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcDirectGetWorkStreamTest {
+  /** Job identity used to build the stream factory and the initial GetWorkRequest. */
+  private static final Windmill.JobHeader TEST_JOB_HEADER =
+      Windmill.JobHeader.newBuilder()
+          .setProjectId("test_project")
+          .setJobId("test_job")
+          .setWorkerId("test_worker")
+          .build();
+  /** Shared by the in-process server and channel so the channel connects to that server. */
+  private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest";
+  /** Releases the in-process server and channel registered in setUp. */
+  @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+  /** Fallback registry of the fake server; each test adds its stub via createGetWorkStream. */
+  private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
+  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+  private ManagedChannel inProcessChannel;
+  /** Stream under test; assigned by each test and shut down in cleanUp. */
+  private GrpcDirectGetWorkStream stream;
+
+  /** Builds a request extension whose max items and max bytes mirror {@code budget}. */
+  private static Windmill.StreamingGetWorkRequestExtension extension(GetWorkBudget budget) {
+    Windmill.StreamingGetWorkRequestExtension.Builder builder =
+        Windmill.StreamingGetWorkRequestExtension.newBuilder();
+    builder.setMaxItems(budget.items());
+    builder.setMaxBytes(budget.bytes());
+    return builder.build();
+  }
+
+  @Before
+  public void setUp() throws IOException {
+    // The server and channel share FAKE_SERVER_NAME so the in-process channel reaches this server.
+    Server server =
+        InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+            .fallbackHandlerRegistry(serviceRegistry)
+            .directExecutor()
+            .build()
+            .start();
+    // Register each resource with grpcCleanup exactly once, right after it is created, so it is
+    // released even if a later step of this method throws.
+    grpcCleanup.register(server);
+    inProcessChannel =
+        grpcCleanup.register(
+            InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+  }
+
+  @After
+  public void cleanUp() {
+    inProcessChannel.shutdownNow();
+    checkNotNull(stream).shutdown();
+  }
+
+  /** Creates a stream for {@code testStub} using a newly constructed {@link ThrottleTimer}. */
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) {
+    ThrottleTimer freshThrottleTimer = new ThrottleTimer();
+    return createGetWorkStream(testStub, initialGetWorkBudget, freshThrottleTimer);
+  }
+
+  /** Creates a stream that hands work items to {@code workItemScheduler}; new throttle timer. */
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      WorkItemScheduler workItemScheduler) {
+    ThrottleTimer freshThrottleTimer = new ThrottleTimer();
+    return createGetWorkStream(
+        testStub, initialGetWorkBudget, freshThrottleTimer, workItemScheduler);
+  }
+
+  /** Creates a stream using {@code throttleTimer} and a scheduler that ignores every work item. */
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      ThrottleTimer throttleTimer) {
+    WorkItemScheduler ignoreAllWork =
+        (workItem, watermarks, processingContext, getWorkStreamLatencies) -> {};
+    return createGetWorkStream(testStub, initialGetWorkBudget, throttleTimer, ignoreAllWork);
+  }
+
+  /**
+   * Registers {@code testStub} with the fake server and creates a direct GetWork stream over the
+   * in-process channel, whose initial GetWorkRequest asks for {@code initialGetWorkBudget}.
+   */
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      ThrottleTimer throttleTimer,
+      WorkItemScheduler workItemScheduler) {
+    serviceRegistry.addService(testStub);
+    Windmill.GetWorkRequest initialRequest =
+        Windmill.GetWorkRequest.newBuilder()
+            .setClientId(TEST_JOB_HEADER.getClientId())
+            .setJobId(TEST_JOB_HEADER.getJobId())
+            .setProjectId(TEST_JOB_HEADER.getProjectId())
+            .setWorkerId(TEST_JOB_HEADER.getWorkerId())
+            .setMaxItems(initialGetWorkBudget.items())
+            .setMaxBytes(initialGetWorkBudget.bytes())
+            .build();
+    return (GrpcDirectGetWorkStream)
+        GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+            .build()
+            .createDirectGetWorkStream(
+                WindmillConnection.builder()
+                    .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel))
+                    .build(),
+                initialRequest,
+                throttleTimer,
+                mock(HeartbeatSender.class),
+                mock(GetDataClient.class),
+                mock(WorkCommitter.class),
+                workItemScheduler);
+  }
+
+  /** Wraps {@code workItem} in one response chunk for "compId" with no remaining bytes. */
+  private Windmill.StreamingGetWorkResponseChunk createResponse(Windmill.WorkItem workItem) {
+    Windmill.ComputationWorkItemMetadata metadata =
+        Windmill.ComputationWorkItemMetadata.newBuilder()
+            .setComputationId("compId")
+            .setInputDataWatermark(1L)
+            .setDependentRealtimeInputWatermark(1L)
+            .build();
+    return Windmill.StreamingGetWorkResponseChunk.newBuilder()
+        .setStreamId(1L)
+        .setComputationMetadata(metadata)
+        .setSerializedWorkItem(workItem.toByteString())
+        .setRemainingBytesForWorkItem(0)
+        .build();
+  }
+
+  @Test
+  public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget()
+      throws InterruptedException {
+    int expectedRequests = 2;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    stream = createGetWorkStream(testStub, GetWorkBudget.noBudget());
+    GetWorkBudget newBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build();
+    stream.setBudget(newBudget);
+
+    // await() returns false on timeout; assert on it so a timeout fails the test here.
+    assertThat(waitForRequests.await(5, TimeUnit.SECONDS)).isTrue();
+
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Header and extension.
+    assertThat(requests).hasSize(expectedRequests);
+    assertThat(Iterables.getLast(requests).getRequestExtension())
+        .isEqualTo(extension(newBudget));
+  }
+
+  @Test
+  public void testSetBudget_computesAndSendsCorrectExtension_existingBudget()
+      throws InterruptedException {
+    int expectedRequests = 2;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build();
+    stream = createGetWorkStream(testStub, initialBudget);
+    GetWorkBudget newBudget = GetWorkBudget.builder().setItems(100).setBytes(100).build();
+    stream.setBudget(newBudget);
+    GetWorkBudget diff = newBudget.subtract(initialBudget);
+
+    // await() returns false on timeout; assert on it so a timeout fails the test here.
+    assertThat(waitForRequests.await(5, TimeUnit.SECONDS)).isTrue();
+
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Header and extension; the extension is expected to carry only newBudget - initialBudget.
+    assertThat(requests).hasSize(expectedRequests);
+    assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff));
+  }
+
+  @Test
+  public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh()
+      throws InterruptedException {
+    int expectedRequests = 1;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    stream =
+        createGetWorkStream(
+            testStub,
+            GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build());
+    stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build());
+
+    // Only the header is awaited; fail here if even that never arrived.
+    assertThat(waitForRequests.await(5, TimeUnit.SECONDS)).isTrue();
+
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Assert that the extension was never sent, only the header. Proto getters never return null,
+    // so getRequest() is always a GetWorkRequest; check field presence instead of the type.
+    assertThat(requests).hasSize(expectedRequests);
+    Windmill.StreamingGetWorkRequest onlyRequest = Iterables.getOnlyElement(requests);
+    assertThat(onlyRequest.hasRequest()).isTrue();
+    assertThat(onlyRequest.hasRequestExtension()).isFalse();
+  }
+
+  @Test
+  public void testSetBudget_doesNothingIfStreamShutdown() throws InterruptedException {
+    int expectedRequests = 1;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    stream = createGetWorkStream(testStub, GetWorkBudget.noBudget());
+    stream.shutdown();
+    stream.setBudget(
+        GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build());
+
+    // Only the header is awaited; fail here if even that never arrived.
+    assertThat(waitForRequests.await(5, TimeUnit.SECONDS)).isTrue();
+
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Assert that the extension was never sent, only the header. Proto getters never return null,
+    // so getRequest() is always a GetWorkRequest; check field presence instead of the type.
+    assertThat(requests).hasSize(expectedRequests);
+    Windmill.StreamingGetWorkRequest onlyRequest = Iterables.getOnlyElement(requests);
+    assertThat(onlyRequest.hasRequest()).isTrue();
+    assertThat(onlyRequest.hasRequestExtension()).isFalse();
+  }
+
+  @Test
+  public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws InterruptedException {
+    int expectedRequests = 2;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(1).setBytes(100).build();
+    Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>();
+    stream =
+        createGetWorkStream(
+            testStub,
+            initialBudget,
+            (work, watermarks, processingContext, getWorkStreamLatencies) ->
+                scheduledWorkItems.add(work));
+    Windmill.WorkItem workItem =
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.copyFromUtf8("somewhat_long_key"))
+            .setWorkToken(1L)
+            .setShardingKey(1L)
+            .setCacheToken(1L)
+            .build();
+
+    testStub.injectResponse(createResponse(workItem));
+
+    // await() returns false on timeout; assert on it so a timeout fails the test here.
+    assertThat(waitForRequests.await(5, TimeUnit.SECONDS)).isTrue();
+
+    assertThat(scheduledWorkItems).containsExactly(workItem);
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Header and extension. The expected extension is one item and the consumed work item's
+    // serialized size in bytes.
+    assertThat(requests).hasSize(expectedRequests);
+    assertThat(Iterables.getLast(requests).getRequestExtension())
+        .isEqualTo(
+            extension(
+                GetWorkBudget.builder()
+                    .setItems(1)
+                    .setBytes(workItem.getSerializedSize())
+                    .build()));
+  }
+
+  @Test
+  public void testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh()
+      throws InterruptedException {
+    int expectedRequests = 1;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>();
+    stream =
+        createGetWorkStream(
+            testStub,
+            GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(),
+            (work, watermarks, processingContext, getWorkStreamLatencies) ->
+                scheduledWorkItems.add(work));
+    Windmill.WorkItem workItem =
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.copyFromUtf8("somewhat_long_key"))
+            .setWorkToken(1L)
+            .setShardingKey(1L)
+            .setCacheToken(1L)
+            .build();
+
+    testStub.injectResponse(createResponse(workItem));
+
+    // Only the header is awaited; fail here if even that never arrived.
+    assertThat(waitForRequests.await(5, TimeUnit.SECONDS)).isTrue();
+
+    assertThat(scheduledWorkItems).containsExactly(workItem);
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+
+    // Assert that the extension was never sent, only the header. Proto getters never return null,
+    // so getRequest() is always a GetWorkRequest; check field presence instead of the type.
+    assertThat(requests).hasSize(expectedRequests);
+    Windmill.StreamingGetWorkRequest onlyRequest = Iterables.getOnlyElement(requests);
+    assertThat(onlyRequest.hasRequest()).isTrue();
+    assertThat(onlyRequest.hasRequestExtension()).isFalse();
+  }
+
+  @Test
+  public void testOnResponse_stopsThrottling() {
+    ThrottleTimer throttleTimer = new ThrottleTimer();
+    TestGetWorkRequestObserver requestObserver =
+        new TestGetWorkRequestObserver(new CountDownLatch(1));
+    GetWorkStreamTestStub testStub = new 
GetWorkStreamTestStub(requestObserver);
+    stream = createGetWorkStream(testStub, GetWorkBudget.noBudget(), 
throttleTimer);
+    stream.startThrottleTimer();
+    
testStub.injectResponse(Windmill.StreamingGetWorkResponseChunk.getDefaultInstance());

Review Comment:
   Does this run inline? Otherwise it seems like the check below that unthrottling happened could be racy.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java:
##########
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcDirectGetWorkStreamTest {
+  /** Job identity used to build the stream factory and the initial GetWorkRequest. */
+  private static final Windmill.JobHeader TEST_JOB_HEADER =
+      Windmill.JobHeader.newBuilder()
+          .setProjectId("test_project")
+          .setJobId("test_job")
+          .setWorkerId("test_worker")
+          .build();
+  /** Shared by the in-process server and channel so the channel connects to that server. */
+  private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest";
+  /** Releases the in-process server and channel registered in setUp. */
+  @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+  /** Fallback registry of the fake server; each test adds its stub via createGetWorkStream. */
+  private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
+  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+  private ManagedChannel inProcessChannel;
+  /** Stream under test; assigned by each test and shut down in cleanUp. */
+  private GrpcDirectGetWorkStream stream;
+
+  /** Builds a request extension whose max items and max bytes mirror {@code budget}. */
+  private static Windmill.StreamingGetWorkRequestExtension extension(GetWorkBudget budget) {
+    Windmill.StreamingGetWorkRequestExtension.Builder builder =
+        Windmill.StreamingGetWorkRequestExtension.newBuilder();
+    builder.setMaxItems(budget.items());
+    builder.setMaxBytes(budget.bytes());
+    return builder.build();
+  }
+
+  @Before
+  public void setUp() throws IOException {
+    // The server and channel share FAKE_SERVER_NAME so the in-process channel reaches this server.
+    Server server =
+        InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+            .fallbackHandlerRegistry(serviceRegistry)
+            .directExecutor()
+            .build()
+            .start();
+    // Register each resource with grpcCleanup exactly once, right after it is created, so it is
+    // released even if a later step of this method throws.
+    grpcCleanup.register(server);
+    inProcessChannel =
+        grpcCleanup.register(
+            InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+  }
+
+  @After
+  public void cleanUp() {
+    inProcessChannel.shutdownNow();
+    checkNotNull(stream).shutdown();
+  }
+
+  /** Creates a stream for {@code testStub} using a newly constructed {@link ThrottleTimer}. */
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) {
+    ThrottleTimer freshThrottleTimer = new ThrottleTimer();
+    return createGetWorkStream(testStub, initialGetWorkBudget, freshThrottleTimer);
+  }
+
+  /** Creates a stream that hands work items to {@code workItemScheduler}; new throttle timer. */
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      WorkItemScheduler workItemScheduler) {
+    ThrottleTimer freshThrottleTimer = new ThrottleTimer();
+    return createGetWorkStream(
+        testStub, initialGetWorkBudget, freshThrottleTimer, workItemScheduler);
+  }
+
+  /** Creates a stream using {@code throttleTimer} and a scheduler that ignores every work item. */
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      ThrottleTimer throttleTimer) {
+    WorkItemScheduler ignoreAllWork =
+        (workItem, watermarks, processingContext, getWorkStreamLatencies) -> {};
+    return createGetWorkStream(testStub, initialGetWorkBudget, throttleTimer, ignoreAllWork);
+  }
+
+  /**
+   * Registers {@code testStub} with the fake server and creates a direct GetWork stream over the
+   * in-process channel, whose initial GetWorkRequest asks for {@code initialGetWorkBudget}.
+   */
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      ThrottleTimer throttleTimer,
+      WorkItemScheduler workItemScheduler) {
+    serviceRegistry.addService(testStub);
+    Windmill.GetWorkRequest initialRequest =
+        Windmill.GetWorkRequest.newBuilder()
+            .setClientId(TEST_JOB_HEADER.getClientId())
+            .setJobId(TEST_JOB_HEADER.getJobId())
+            .setProjectId(TEST_JOB_HEADER.getProjectId())
+            .setWorkerId(TEST_JOB_HEADER.getWorkerId())
+            .setMaxItems(initialGetWorkBudget.items())
+            .setMaxBytes(initialGetWorkBudget.bytes())
+            .build();
+    return (GrpcDirectGetWorkStream)
+        GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+            .build()
+            .createDirectGetWorkStream(
+                WindmillConnection.builder()
+                    .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel))
+                    .build(),
+                initialRequest,
+                throttleTimer,
+                mock(HeartbeatSender.class),
+                mock(GetDataClient.class),
+                mock(WorkCommitter.class),
+                workItemScheduler);
+  }
+
+  /** Wraps {@code workItem} in one response chunk for "compId" with no remaining bytes. */
+  private Windmill.StreamingGetWorkResponseChunk createResponse(Windmill.WorkItem workItem) {
+    Windmill.ComputationWorkItemMetadata metadata =
+        Windmill.ComputationWorkItemMetadata.newBuilder()
+            .setComputationId("compId")
+            .setInputDataWatermark(1L)
+            .setDependentRealtimeInputWatermark(1L)
+            .build();
+    return Windmill.StreamingGetWorkResponseChunk.newBuilder()
+        .setStreamId(1L)
+        .setComputationMetadata(metadata)
+        .setSerializedWorkItem(workItem.toByteString())
+        .setRemainingBytesForWorkItem(0)
+        .build();
+  }
+
+  @Test
+  public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget()
+      throws InterruptedException {
+    int expectedRequests = 2;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    stream = createGetWorkStream(testStub, GetWorkBudget.noBudget());
+    GetWorkBudget newBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build();
+    stream.setBudget(newBudget);
+
+    // await() returns false on timeout; assert on it so a timeout fails the test here.
+    assertThat(waitForRequests.await(5, TimeUnit.SECONDS)).isTrue();
+
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Header and extension.
+    assertThat(requests).hasSize(expectedRequests);
+    assertThat(Iterables.getLast(requests).getRequestExtension())
+        .isEqualTo(extension(newBudget));
+  }
+
+  @Test
+  public void testSetBudget_computesAndSendsCorrectExtension_existingBudget()
+      throws InterruptedException {
+    int expectedRequests = 2;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build();
+    stream = createGetWorkStream(testStub, initialBudget);
+    GetWorkBudget newBudget = GetWorkBudget.builder().setItems(100).setBytes(100).build();
+    stream.setBudget(newBudget);
+    GetWorkBudget diff = newBudget.subtract(initialBudget);
+
+    // await() returns false on timeout; assert on it so a timeout fails the test here.
+    assertThat(waitForRequests.await(5, TimeUnit.SECONDS)).isTrue();
+
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Header and extension; the extension is expected to carry only newBudget - initialBudget.
+    assertThat(requests).hasSize(expectedRequests);
+    assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff));
+  }
+
+  @Test
+  public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh()
+      throws InterruptedException {
+    int expectedRequests = 1;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    stream =
+        createGetWorkStream(
+            testStub,
+            GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build());
+    stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build());
+
+    // Only the header is awaited; fail here if even that never arrived.
+    assertThat(waitForRequests.await(5, TimeUnit.SECONDS)).isTrue();
+
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Assert that the extension was never sent, only the header. Proto getters never return null,
+    // so getRequest() is always a GetWorkRequest; check field presence instead of the type.
+    assertThat(requests).hasSize(expectedRequests);
+    Windmill.StreamingGetWorkRequest onlyRequest = Iterables.getOnlyElement(requests);
+    assertThat(onlyRequest.hasRequest()).isTrue();
+    assertThat(onlyRequest.hasRequestExtension()).isFalse();
+  }
+
+  @Test
+  public void testSetBudget_doesNothingIfStreamShutdown() throws 
InterruptedException {
+    int expectedRequests = 1;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new 
TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new 
GetWorkStreamTestStub(requestObserver);
+    stream = createGetWorkStream(testStub, GetWorkBudget.noBudget());
+    stream.shutdown();
+    stream.setBudget(
+        
GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build());
+
+    waitForRequests.await(5, TimeUnit.SECONDS);
+
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Assert that the extension was never sent, only the header.
+    assertThat(requests).hasSize(1);
+    assertThat(Iterables.getOnlyElement(requests).getRequest())
+        .isInstanceOf(Windmill.GetWorkRequest.class);
+  }
+
+  @Test
+  public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws 
InterruptedException {
+    int expectedRequests = 2;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new 
TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new 
GetWorkStreamTestStub(requestObserver);
+    GetWorkBudget initialBudget = 
GetWorkBudget.builder().setItems(1).setBytes(100).build();
+    Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>();
+    stream =
+        createGetWorkStream(
+            testStub,
+            initialBudget,
+            (work, watermarks, processingContext, getWorkStreamLatencies) -> {
+              scheduledWorkItems.add(work);
+            });
+    Windmill.WorkItem workItem =
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.copyFromUtf8("somewhat_long_key"))
+            .setWorkToken(1L)
+            .setShardingKey(1L)
+            .setCacheToken(1L)
+            .build();
+
+    testStub.injectResponse(createResponse(workItem));
+
+    waitForRequests.await(5, TimeUnit.SECONDS);
+
+    assertThat(scheduledWorkItems).containsExactly(workItem);
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    long inFlightBytes = initialBudget.bytes() - workItem.getSerializedSize();
+
+    assertThat(requests).hasSize(expectedRequests);
+    assertThat(Iterables.getLast(requests).getRequestExtension())
+        .isEqualTo(
+            extension(
+                GetWorkBudget.builder()
+                    .setItems(1)
+                    .setBytes(initialBudget.bytes() - inFlightBytes)
+                    .build()));
+  }
+
+  @Test
+  public void 
testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh()
+      throws InterruptedException {
+    int expectedRequests = 1;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new 
TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new 
GetWorkStreamTestStub(requestObserver);
+    Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>();
+    stream =
+        createGetWorkStream(
+            testStub,
+            
GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(),
+            (work, watermarks, processingContext, getWorkStreamLatencies) -> {
+              scheduledWorkItems.add(work);
+            });
+    Windmill.WorkItem workItem =
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.copyFromUtf8("somewhat_long_key"))
+            .setWorkToken(1L)
+            .setShardingKey(1L)
+            .setCacheToken(1L)
+            .build();
+
+    testStub.injectResponse(createResponse(workItem));
+
+    waitForRequests.await(5, TimeUnit.SECONDS);
+
+    assertThat(scheduledWorkItems).containsExactly(workItem);
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+
+    // Assert that the extension was never sent, only the header.
+    assertThat(requests).hasSize(expectedRequests);
+    assertThat(Iterables.getOnlyElement(requests).getRequest())
+        .isInstanceOf(Windmill.GetWorkRequest.class);
+  }
+
+  @Test
+  public void testOnResponse_stopsThrottling() {
+    ThrottleTimer throttleTimer = new ThrottleTimer();
+    TestGetWorkRequestObserver requestObserver =
+        new TestGetWorkRequestObserver(new CountDownLatch(1));
+    GetWorkStreamTestStub testStub = new 
GetWorkStreamTestStub(requestObserver);
+    stream = createGetWorkStream(testStub, GetWorkBudget.noBudget(), 
throttleTimer);
+    stream.startThrottleTimer();
+    
testStub.injectResponse(Windmill.StreamingGetWorkResponseChunk.getDefaultInstance());
+    assertFalse(throttleTimer.throttled());
+  }
+
+  private static class GetWorkStreamTestStub
+      extends 
CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase {
+    private final TestGetWorkRequestObserver requestObserver;

Review Comment:
   Might as well mark it volatile to prevent races.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@beam.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to