scwhittle commented on code in PR #32775: URL: https://github.com/apache/beam/pull/32775#discussion_r1804316412
########## runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java: ########## @@ -280,29 +300,30 @@ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWi } /** Close the streams that are no longer valid asynchronously. */ - @SuppressWarnings("FutureReturnValueIgnored") - private void closeStaleStreams(WindmillEndpoints newWindmillEndpoints) { + private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) { StreamingEngineBackends currentBackends = backends.get(); - ImmutableMap<Endpoint, WindmillStreamSender> currentWindmillStreams = - currentBackends.windmillStreams(); - currentWindmillStreams.entrySet().stream() + currentBackends.windmillStreams().entrySet().stream() .filter( connectionAndStream -> !newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey())) .forEach( - entry -> - CompletableFuture.runAsync( - () -> closeStreamSender(entry.getKey(), entry.getValue()), - windmillStreamManager)); + entry -> { + CompletableFuture<Void> ignored = Review Comment: Is this any different from just executing directly? If not, it seems simpler to avoid the future, e.g.: windmillStreamManager.execute( () -> closeStreamSender(entry.getKey(), entry.getValue())) ########## runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java: ########## @@ -445,10 +400,14 @@ private void waitForBudgetDistribution() throws InterruptedException { getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS); Review Comment: Seems like this should return the result of await() so that tests can assert whether the wait actually succeeded (rather than timing out)?
########## runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java: ########## @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.junit.Assert.assertFalse; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcDirectGetWorkStreamTest { + private static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest"; + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + private GrpcDirectGetWorkStream stream; + + private static Windmill.StreamingGetWorkRequestExtension extension(GetWorkBudget budget) { + return Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(budget.items()) + .setMaxBytes(budget.bytes()) + .build(); + } + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + checkNotNull(stream).shutdown(); + } + + private GrpcDirectGetWorkStream createGetWorkStream( Review Comment: This is a lot of variants of the create method; can we get rid of this one? It's easy for tests to call new ThrottleTimer if they don't care.
########## runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java: ########## @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.junit.Assert.assertFalse; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcDirectGetWorkStreamTest { + private static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest"; + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + private GrpcDirectGetWorkStream stream; + + private static Windmill.StreamingGetWorkRequestExtension extension(GetWorkBudget budget) { + return Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(budget.items()) + .setMaxBytes(budget.bytes()) + .build(); + } + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + checkNotNull(stream).shutdown(); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) { + return createGetWorkStream(testStub, initialGetWorkBudget, new ThrottleTimer()); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + 
GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + WorkItemScheduler workItemScheduler) { + return createGetWorkStream( + testStub, initialGetWorkBudget, new ThrottleTimer(), workItemScheduler); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer) { + return createGetWorkStream( + testStub, + initialGetWorkBudget, + throttleTimer, + (workItem, watermarks, processingContext, getWorkStreamLatencies) -> {}); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer, + WorkItemScheduler workItemScheduler) { + serviceRegistry.addService(testStub); + return (GrpcDirectGetWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .build() + .createDirectGetWorkStream( + WindmillConnection.builder() + .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)) + .build(), + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build(), + throttleTimer, + mock(HeartbeatSender.class), + mock(GetDataClient.class), + mock(WorkCommitter.class), + workItemScheduler); + } + + private Windmill.StreamingGetWorkResponseChunk createResponse(Windmill.WorkItem workItem) { + return Windmill.StreamingGetWorkResponseChunk.newBuilder() + .setStreamId(1L) + .setComputationMetadata( + Windmill.ComputationWorkItemMetadata.newBuilder() + .setComputationId("compId") + .setInputDataWatermark(1L) + .setDependentRealtimeInputWatermark(1L) + .build()) + .setSerializedWorkItem(workItem.toByteString()) + .setRemainingBytesForWorkItem(0) + .build(); + } + + @Test + public void 
testSetBudget_computesAndSendsCorrectExtension_noExistingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = createGetWorkStream(testStub, GetWorkBudget.noBudget()); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream.setBudget(newBudget); + + waitForRequests.await(5, TimeUnit.SECONDS); + + // Header and extension. + assertThat(requestObserver.sent()).hasSize(expectedRequests); + assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension()) Review Comment: Would be better to verify the header as well. Ditto for the other tests. ########## runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java: ########## @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.junit.Assert.assertFalse; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcDirectGetWorkStreamTest { + private static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest"; + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + private GrpcDirectGetWorkStream stream; + + private static Windmill.StreamingGetWorkRequestExtension extension(GetWorkBudget budget) { + return Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(budget.items()) + .setMaxBytes(budget.bytes()) + .build(); + } + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + checkNotNull(stream).shutdown(); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) { + return createGetWorkStream(testStub, initialGetWorkBudget, new ThrottleTimer()); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + 
GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + WorkItemScheduler workItemScheduler) { + return createGetWorkStream( + testStub, initialGetWorkBudget, new ThrottleTimer(), workItemScheduler); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer) { + return createGetWorkStream( + testStub, + initialGetWorkBudget, + throttleTimer, + (workItem, watermarks, processingContext, getWorkStreamLatencies) -> {}); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer, + WorkItemScheduler workItemScheduler) { + serviceRegistry.addService(testStub); + return (GrpcDirectGetWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .build() + .createDirectGetWorkStream( + WindmillConnection.builder() + .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)) + .build(), + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build(), + throttleTimer, + mock(HeartbeatSender.class), + mock(GetDataClient.class), + mock(WorkCommitter.class), + workItemScheduler); + } + + private Windmill.StreamingGetWorkResponseChunk createResponse(Windmill.WorkItem workItem) { + return Windmill.StreamingGetWorkResponseChunk.newBuilder() + .setStreamId(1L) + .setComputationMetadata( + Windmill.ComputationWorkItemMetadata.newBuilder() + .setComputationId("compId") + .setInputDataWatermark(1L) + .setDependentRealtimeInputWatermark(1L) + .build()) + .setSerializedWorkItem(workItem.toByteString()) + .setRemainingBytesForWorkItem(0) + .build(); + } + + @Test + public void 
testSetBudget_computesAndSendsCorrectExtension_noExistingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = createGetWorkStream(testStub, GetWorkBudget.noBudget()); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream.setBudget(newBudget); + + waitForRequests.await(5, TimeUnit.SECONDS); + + // Header and extension. + assertThat(requestObserver.sent()).hasSize(expectedRequests); + assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension()) + .isEqualTo(extension(newBudget)); + } + + @Test + public void testSetBudget_computesAndSendsCorrectExtension_existingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream = createGetWorkStream(testStub, initialBudget); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(100).setBytes(100).build(); + stream.setBudget(newBudget); + GetWorkBudget diff = newBudget.subtract(initialBudget); + + waitForRequests.await(5, TimeUnit.SECONDS); + + List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent(); + // Header and extension. 
+ assertThat(requests).hasSize(expectedRequests); + assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff)); + } + + @Test + public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build()); + stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build()); + + waitForRequests.await(5, TimeUnit.SECONDS); + + List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. + assertThat(requests).hasSize(expectedRequests); + assertThat(Iterables.getOnlyElement(requests).getRequest()) Review Comment: This may never succeed, since getRequest will return the default object? Or perhaps it throws? In either case, you could just check the type of the oneof via assertTrue(...hasRequest()) or something similar. But it could be better to have a matcher that you can use here and in other tests that lets you verify the budget within the initial request.
ditto for other instanceof checks ########## runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java: ########## @@ -40,147 +38,39 @@ public class EvenGetWorkBudgetDistributorTest { @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); - private static GetWorkBudgetDistributor createBudgetDistributor(GetWorkBudget activeWorkBudget) { - return GetWorkBudgetDistributors.distributeEvenly(() -> activeWorkBudget); - } + private static GetWorkBudgetSpender createGetWorkBudgetOwner() { + // Lambdas are final and cannot be spied. + return spy( + new GetWorkBudgetSpender() { - private static GetWorkBudgetDistributor createBudgetDistributor(long activeWorkItemsAndBytes) { - return createBudgetDistributor( - GetWorkBudget.builder() - .setItems(activeWorkItemsAndBytes) - .setBytes(activeWorkItemsAndBytes) - .build()); + @Override + public void setBudget(long items, long bytes) {} + }); } @Test public void testDistributeBudget_doesNothingWhenPassedInStreamsEmpty() { - createBudgetDistributor(1L) + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( ImmutableList.of(), GetWorkBudget.builder().setItems(10L).setBytes(10L).build()); } @Test public void testDistributeBudget_doesNothingWithNoBudget() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget())); - createBudgetDistributor(1L) + GetWorkBudgetSpender getWorkBudgetSpender = spy(createGetWorkBudgetOwner()); + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget(ImmutableList.of(getWorkBudgetSpender), GetWorkBudget.noBudget()); verifyNoInteractions(getWorkBudgetSpender); } - @Test - public void testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighNoActiveWork() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy( - 
createGetWorkBudgetOwnerWithRemainingBudgetOf( - GetWorkBudget.builder().setItems(10L).setBytes(10L).build())); - createBudgetDistributor(0L) - .distributeBudget( - ImmutableList.of(getWorkBudgetSpender), - GetWorkBudget.builder().setItems(10L).setBytes(10L).build()); - - verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong()); - } - - @Test - public void - testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighWithActiveWork() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy( - createGetWorkBudgetOwnerWithRemainingBudgetOf( - GetWorkBudget.builder().setItems(5L).setBytes(5L).build())); - createBudgetDistributor(10L) - .distributeBudget( - ImmutableList.of(getWorkBudgetSpender), - GetWorkBudget.builder().setItems(20L).setBytes(20L).build()); - - verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong()); - } - - @Test - public void - testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithNoActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(1L).setBytes(10L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(0L) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), - eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); - } - - @Test - public void - testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(1L).setBytes(10L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - long activeWorkItemsAndBytes = 2L; - GetWorkBudgetSpender 
getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(activeWorkItemsAndBytes) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq( - totalGetWorkBudget.items() - - streamRemainingBudget.items() - - activeWorkItemsAndBytes), - eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); - } - - @Test - public void testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowNoActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(10L).setBytes(1L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(0L) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), - eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); - } - - @Test - public void - testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowWithActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(10L).setBytes(1L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - long activeWorkItemsAndBytes = 2L; - - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(activeWorkItemsAndBytes) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), - eq( - totalGetWorkBudget.bytes() - - 
streamRemainingBudget.bytes() - - activeWorkItemsAndBytes)); - } - @Test public void testDistributeBudget_distributesBudgetEvenlyIfPossible() { long totalItemsAndBytes = 10L; Review Comment: would be better to have different items and bytes values to confirm distributor doesn't mix them up internally ########## runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java: ########## @@ -61,38 +43,15 @@ public <T extends GetWorkBudgetSpender> void distributeBudget( return; } - Map<T, GetWorkBudget> desiredBudgets = computeDesiredBudgets(budgetOwners, getWorkBudget); - - for (Entry<T, GetWorkBudget> streamAndDesiredBudget : desiredBudgets.entrySet()) { - GetWorkBudgetSpender getWorkBudgetSpender = streamAndDesiredBudget.getKey(); - GetWorkBudget desired = streamAndDesiredBudget.getValue(); - GetWorkBudget remaining = getWorkBudgetSpender.remainingBudget(); - if (isBelowFiftyPercentOfTarget(remaining, desired)) { - GetWorkBudget adjustment = desired.subtract(remaining); - getWorkBudgetSpender.adjustBudget(adjustment); - } - } + GetWorkBudget budgetPerStream = computeDesiredBudgets(budgetSpenders, getWorkBudget); + budgetSpenders.forEach(getWorkBudgetSpender -> getWorkBudgetSpender.setBudget(budgetPerStream)); } - private <T extends GetWorkBudgetSpender> ImmutableMap<T, GetWorkBudget> computeDesiredBudgets( + private <T extends GetWorkBudgetSpender> GetWorkBudget computeDesiredBudgets( Review Comment: nit: maybe name computeDesiredPerStreamBudget? or just inline? budgets makes it sound like it is computing multiple. 
########## runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java: ########## @@ -267,25 +277,100 @@ protected void startThrottleTimer() { } @Override - public void adjustBudget(long itemsDelta, long bytesDelta) { - GetWorkBudget adjustment = - nextBudgetAdjustment - // Get the current value, and reset the nextBudgetAdjustment. This will be set again - // when adjustBudget is called. - .getAndUpdate(unused -> GetWorkBudget.noBudget()) - .apply(itemsDelta, bytesDelta); - sendRequestExtension(adjustment); + public void setBudget(long newItems, long newBytes) { + GetWorkBudget currentMaxGetWorkBudget = + maxGetWorkBudget.updateAndGet( + ignored -> GetWorkBudget.builder().setItems(newItems).setBytes(newBytes).build()); + GetWorkBudget extension = budgetTracker.computeBudgetExtension(currentMaxGetWorkBudget); + maybeSendRequestExtension(extension); } - @Override - public GetWorkBudget remainingBudget() { - // Snapshot the current budgets. - GetWorkBudget currentPendingResponseBudget = pendingResponseBudget.get(); - GetWorkBudget currentNextBudgetAdjustment = nextBudgetAdjustment.get(); - GetWorkBudget currentInflightBudget = inFlightBudget.get(); - - return currentPendingResponseBudget - .apply(currentNextBudgetAdjustment) - .apply(currentInflightBudget); + /** + * Tracks sent and received GetWorkBudget and uses this information to generate request + * extensions. + */ + @AutoValue + abstract static class GetWorkBudgetTracker { + + private static GetWorkBudgetTracker create() { + return new AutoValue_GrpcDirectGetWorkStream_GetWorkBudgetTracker( + new AtomicLong(), new AtomicLong(), new AtomicLong(), new AtomicLong()); + } + + abstract AtomicLong itemsRequested(); Review Comment: can the members be changed to just raw longs/objects? The accessors just need to be synchronized as well. Seems like this could be easier without AutoValue since we don't need the accessors either.
########## runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java: ########## @@ -231,7 +236,7 @@ public void appendSpecificHtml(PrintWriter writer) { + "total budget received: %s," + "last sent request: %s. ", workItemAssemblers.size(), - maxGetWorkBudget.get(), + budgetTracker.maxGetWorkBudget().get(), Review Comment: could move html generation into GetWorkBudgetTracker and not need all the accessors. If we change how the tracker works in the future we might want to show more too. ########## runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java: ########## @@ -41,6 +41,14 @@ public abstract class WindmillEndpoints { private static final Logger LOG = LoggerFactory.getLogger(WindmillEndpoints.class); + public static WindmillEndpoints none() { + return WindmillEndpoints.builder() + .setVersion(Long.MAX_VALUE) Review Comment: min seems safer. Otherwise if somehow none() was observed, the logic to ensure the version is increasing means we'd never process another endpoint set. ########## runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java: ########## @@ -280,29 +300,30 @@ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWi } /** Close the streams that are no longer valid asynchronously.
*/ - @SuppressWarnings("FutureReturnValueIgnored") - private void closeStaleStreams(WindmillEndpoints newWindmillEndpoints) { + private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) { StreamingEngineBackends currentBackends = backends.get(); - ImmutableMap<Endpoint, WindmillStreamSender> currentWindmillStreams = - currentBackends.windmillStreams(); - currentWindmillStreams.entrySet().stream() + currentBackends.windmillStreams().entrySet().stream() .filter( connectionAndStream -> !newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey())) .forEach( - entry -> - CompletableFuture.runAsync( - () -> closeStreamSender(entry.getKey(), entry.getValue()), - windmillStreamManager)); + entry -> { + CompletableFuture<Void> ignored = + CompletableFuture.runAsync( + () -> closeStreamSender(entry.getKey(), entry.getValue()), + windmillStreamManager); + }); Set<Endpoint> newGlobalDataEndpoints = new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values()); currentBackends.globalDataStreams().values().stream() .filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint())) .forEach( - sender -> - CompletableFuture.runAsync( - () -> closeStreamSender(sender.endpoint(), sender), windmillStreamManager)); + sender -> { + CompletableFuture<Void> ignored = Review Comment: ditto ########## runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java: ########## @@ -192,17 +82,17 @@ public void testDistributeBudget_distributesBudgetEvenlyIfPossible() { streams.forEach( Review Comment: just skip the math in the test and inline the right values? The math is just copying what we have in the impl, if there is some bug in the impl hard coding the values at least is a sanity check. ditto below. 
########## runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java: ########## @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.junit.Assert.assertFalse; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcDirectGetWorkStreamTest { + private static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest"; + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + private GrpcDirectGetWorkStream stream; + + private static Windmill.StreamingGetWorkRequestExtension extension(GetWorkBudget budget) { + return Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(budget.items()) + .setMaxBytes(budget.bytes()) + .build(); + } + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + checkNotNull(stream).shutdown(); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) { + return createGetWorkStream(testStub, initialGetWorkBudget, new ThrottleTimer()); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + 
GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + WorkItemScheduler workItemScheduler) { + return createGetWorkStream( + testStub, initialGetWorkBudget, new ThrottleTimer(), workItemScheduler); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer) { + return createGetWorkStream( + testStub, + initialGetWorkBudget, + throttleTimer, + (workItem, watermarks, processingContext, getWorkStreamLatencies) -> {}); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer, + WorkItemScheduler workItemScheduler) { + serviceRegistry.addService(testStub); + return (GrpcDirectGetWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .build() + .createDirectGetWorkStream( + WindmillConnection.builder() + .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)) + .build(), + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build(), + throttleTimer, + mock(HeartbeatSender.class), + mock(GetDataClient.class), + mock(WorkCommitter.class), + workItemScheduler); + } + + private Windmill.StreamingGetWorkResponseChunk createResponse(Windmill.WorkItem workItem) { + return Windmill.StreamingGetWorkResponseChunk.newBuilder() + .setStreamId(1L) + .setComputationMetadata( + Windmill.ComputationWorkItemMetadata.newBuilder() + .setComputationId("compId") + .setInputDataWatermark(1L) + .setDependentRealtimeInputWatermark(1L) + .build()) + .setSerializedWorkItem(workItem.toByteString()) + .setRemainingBytesForWorkItem(0) + .build(); + } + + @Test + public void 
testSetBudget_computesAndSendsCorrectExtension_noExistingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = createGetWorkStream(testStub, GetWorkBudget.noBudget()); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream.setBudget(newBudget); + + waitForRequests.await(5, TimeUnit.SECONDS); Review Comment: assertTrue for awaits that should succeed. If they fail the rest will fail and maybe be confusing. ########## runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java: ########## @@ -40,147 +38,39 @@ public class EvenGetWorkBudgetDistributorTest { @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); - private static GetWorkBudgetDistributor createBudgetDistributor(GetWorkBudget activeWorkBudget) { - return GetWorkBudgetDistributors.distributeEvenly(() -> activeWorkBudget); - } + private static GetWorkBudgetSpender createGetWorkBudgetOwner() { + // Lambdas are final and cannot be spied. 
+ return spy( + new GetWorkBudgetSpender() { - private static GetWorkBudgetDistributor createBudgetDistributor(long activeWorkItemsAndBytes) { - return createBudgetDistributor( - GetWorkBudget.builder() - .setItems(activeWorkItemsAndBytes) - .setBytes(activeWorkItemsAndBytes) - .build()); + @Override + public void setBudget(long items, long bytes) {} + }); } @Test public void testDistributeBudget_doesNothingWhenPassedInStreamsEmpty() { - createBudgetDistributor(1L) + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( ImmutableList.of(), GetWorkBudget.builder().setItems(10L).setBytes(10L).build()); } @Test public void testDistributeBudget_doesNothingWithNoBudget() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget())); - createBudgetDistributor(1L) + GetWorkBudgetSpender getWorkBudgetSpender = spy(createGetWorkBudgetOwner()); Review Comment: remove spy here? already done in the helper method ditto for below ########## runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java: ########## @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.junit.Assert.assertFalse; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcDirectGetWorkStreamTest { + private static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest"; + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + private GrpcDirectGetWorkStream stream; + + private static Windmill.StreamingGetWorkRequestExtension extension(GetWorkBudget budget) { + return Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(budget.items()) + .setMaxBytes(budget.bytes()) + .build(); + } + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + checkNotNull(stream).shutdown(); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) { + return createGetWorkStream(testStub, initialGetWorkBudget, new 
ThrottleTimer()); + } + + private GrpcDirectGetWorkStream createGetWorkStream( Review Comment: ditto, tests can just use full 4 params and make a throttle timer ########## runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java: ########## @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.junit.Assert.assertFalse; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcDirectGetWorkStreamTest { + private static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest"; + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + private GrpcDirectGetWorkStream stream; + + private static Windmill.StreamingGetWorkRequestExtension extension(GetWorkBudget budget) { + return Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(budget.items()) + .setMaxBytes(budget.bytes()) + .build(); + } + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + checkNotNull(stream).shutdown(); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) { + return createGetWorkStream(testStub, initialGetWorkBudget, new ThrottleTimer()); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + 
GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + WorkItemScheduler workItemScheduler) { + return createGetWorkStream( + testStub, initialGetWorkBudget, new ThrottleTimer(), workItemScheduler); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer) { + return createGetWorkStream( + testStub, + initialGetWorkBudget, + throttleTimer, + (workItem, watermarks, processingContext, getWorkStreamLatencies) -> {}); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer, + WorkItemScheduler workItemScheduler) { + serviceRegistry.addService(testStub); + return (GrpcDirectGetWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .build() + .createDirectGetWorkStream( + WindmillConnection.builder() + .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)) + .build(), + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build(), + throttleTimer, + mock(HeartbeatSender.class), + mock(GetDataClient.class), + mock(WorkCommitter.class), + workItemScheduler); + } + + private Windmill.StreamingGetWorkResponseChunk createResponse(Windmill.WorkItem workItem) { + return Windmill.StreamingGetWorkResponseChunk.newBuilder() + .setStreamId(1L) + .setComputationMetadata( + Windmill.ComputationWorkItemMetadata.newBuilder() + .setComputationId("compId") + .setInputDataWatermark(1L) + .setDependentRealtimeInputWatermark(1L) + .build()) + .setSerializedWorkItem(workItem.toByteString()) + .setRemainingBytesForWorkItem(0) + .build(); + } + + @Test + public void 
testSetBudget_computesAndSendsCorrectExtension_noExistingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = createGetWorkStream(testStub, GetWorkBudget.noBudget()); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream.setBudget(newBudget); + + waitForRequests.await(5, TimeUnit.SECONDS); + + // Header and extension. + assertThat(requestObserver.sent()).hasSize(expectedRequests); + assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension()) + .isEqualTo(extension(newBudget)); + } + + @Test + public void testSetBudget_computesAndSendsCorrectExtension_existingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream = createGetWorkStream(testStub, initialBudget); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(100).setBytes(100).build(); + stream.setBudget(newBudget); + GetWorkBudget diff = newBudget.subtract(initialBudget); + + waitForRequests.await(5, TimeUnit.SECONDS); + + List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent(); + // Header and extension. 
+ assertThat(requests).hasSize(expectedRequests); + assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff)); + } + + @Test + public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build()); + stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build()); + + waitForRequests.await(5, TimeUnit.SECONDS); + + List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. + assertThat(requests).hasSize(expectedRequests); + assertThat(Iterables.getOnlyElement(requests).getRequest()) + .isInstanceOf(Windmill.GetWorkRequest.class); + } + + @Test + public void testSetBudget_doesNothingIfStreamShutdown() throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = createGetWorkStream(testStub, GetWorkBudget.noBudget()); + stream.shutdown(); + stream.setBudget( + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build()); + + waitForRequests.await(5, TimeUnit.SECONDS); + + List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. 
+ assertThat(requests).hasSize(1); + assertThat(Iterables.getOnlyElement(requests).getRequest()) + .isInstanceOf(Windmill.GetWorkRequest.class); + } + + @Test + public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(1).setBytes(100).build(); + Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>(); + stream = + createGetWorkStream( + testStub, + initialBudget, + (work, watermarks, processingContext, getWorkStreamLatencies) -> { + scheduledWorkItems.add(work); + }); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("somewhat_long_key")) + .setWorkToken(1L) + .setShardingKey(1L) + .setCacheToken(1L) + .build(); + + testStub.injectResponse(createResponse(workItem)); + + waitForRequests.await(5, TimeUnit.SECONDS); + + assertThat(scheduledWorkItems).containsExactly(workItem); + List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent(); + long inFlightBytes = initialBudget.bytes() - workItem.getSerializedSize(); + + assertThat(requests).hasSize(expectedRequests); + assertThat(Iterables.getLast(requests).getRequestExtension()) + .isEqualTo( + extension( + GetWorkBudget.builder() + .setItems(1) + .setBytes(initialBudget.bytes() - inFlightBytes) + .build())); + } + + @Test + public void testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new 
GetWorkStreamTestStub(requestObserver); + Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>(); + stream = + createGetWorkStream( + testStub, + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(), + (work, watermarks, processingContext, getWorkStreamLatencies) -> { + scheduledWorkItems.add(work); + }); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("somewhat_long_key")) + .setWorkToken(1L) + .setShardingKey(1L) + .setCacheToken(1L) + .build(); + + testStub.injectResponse(createResponse(workItem)); + + waitForRequests.await(5, TimeUnit.SECONDS); + + assertThat(scheduledWorkItems).containsExactly(workItem); + List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent(); + + // Assert that the extension was never sent, only the header. + assertThat(requests).hasSize(expectedRequests); + assertThat(Iterables.getOnlyElement(requests).getRequest()) + .isInstanceOf(Windmill.GetWorkRequest.class); + } + + @Test + public void testOnResponse_stopsThrottling() { + ThrottleTimer throttleTimer = new ThrottleTimer(); + TestGetWorkRequestObserver requestObserver = + new TestGetWorkRequestObserver(new CountDownLatch(1)); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = createGetWorkStream(testStub, GetWorkBudget.noBudget(), throttleTimer); + stream.startThrottleTimer(); Review Comment: assertTrue(throttleTimer.throttled()) ########## runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java: ########## @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.junit.Assert.assertFalse; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; 
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcDirectGetWorkStreamTest { + private static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest"; + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + private GrpcDirectGetWorkStream stream; + + private static Windmill.StreamingGetWorkRequestExtension extension(GetWorkBudget budget) { + return Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(budget.items()) + .setMaxBytes(budget.bytes()) + .build(); + } + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + 
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + checkNotNull(stream).shutdown(); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) { + return createGetWorkStream(testStub, initialGetWorkBudget, new ThrottleTimer()); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + WorkItemScheduler workItemScheduler) { + return createGetWorkStream( + testStub, initialGetWorkBudget, new ThrottleTimer(), workItemScheduler); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer) { + return createGetWorkStream( + testStub, + initialGetWorkBudget, + throttleTimer, + (workItem, watermarks, processingContext, getWorkStreamLatencies) -> {}); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer, + WorkItemScheduler workItemScheduler) { + serviceRegistry.addService(testStub); + return (GrpcDirectGetWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .build() + .createDirectGetWorkStream( + WindmillConnection.builder() + .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)) + .build(), + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build(), + throttleTimer, + mock(HeartbeatSender.class), + mock(GetDataClient.class), + mock(WorkCommitter.class), + 
workItemScheduler); + } + + private Windmill.StreamingGetWorkResponseChunk createResponse(Windmill.WorkItem workItem) { + return Windmill.StreamingGetWorkResponseChunk.newBuilder() + .setStreamId(1L) + .setComputationMetadata( + Windmill.ComputationWorkItemMetadata.newBuilder() + .setComputationId("compId") + .setInputDataWatermark(1L) + .setDependentRealtimeInputWatermark(1L) + .build()) + .setSerializedWorkItem(workItem.toByteString()) + .setRemainingBytesForWorkItem(0) + .build(); + } + + @Test + public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = createGetWorkStream(testStub, GetWorkBudget.noBudget()); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream.setBudget(newBudget); + + waitForRequests.await(5, TimeUnit.SECONDS); + + // Header and extension. 
+ assertThat(requestObserver.sent()).hasSize(expectedRequests); + assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension()) + .isEqualTo(extension(newBudget)); + } + + @Test + public void testSetBudget_computesAndSendsCorrectExtension_existingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream = createGetWorkStream(testStub, initialBudget); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(100).setBytes(100).build(); + stream.setBudget(newBudget); + GetWorkBudget diff = newBudget.subtract(initialBudget); + + waitForRequests.await(5, TimeUnit.SECONDS); + + List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent(); + // Header and extension. + assertThat(requests).hasSize(expectedRequests); + assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff)); + } + + @Test + public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build()); + stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build()); + + waitForRequests.await(5, TimeUnit.SECONDS); + + List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. 
+ assertThat(requests).hasSize(expectedRequests); + assertThat(Iterables.getOnlyElement(requests).getRequest()) + .isInstanceOf(Windmill.GetWorkRequest.class); + } + + @Test + public void testSetBudget_doesNothingIfStreamShutdown() throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = createGetWorkStream(testStub, GetWorkBudget.noBudget()); + stream.shutdown(); + stream.setBudget( + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build()); + + waitForRequests.await(5, TimeUnit.SECONDS); + + List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. + assertThat(requests).hasSize(1); + assertThat(Iterables.getOnlyElement(requests).getRequest()) + .isInstanceOf(Windmill.GetWorkRequest.class); + } + + @Test + public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(1).setBytes(100).build(); + Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>(); + stream = + createGetWorkStream( + testStub, + initialBudget, + (work, watermarks, processingContext, getWorkStreamLatencies) -> { + scheduledWorkItems.add(work); + }); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("somewhat_long_key")) + .setWorkToken(1L) + .setShardingKey(1L) + .setCacheToken(1L) + .build(); + + 
testStub.injectResponse(createResponse(workItem)); + + waitForRequests.await(5, TimeUnit.SECONDS); + + assertThat(scheduledWorkItems).containsExactly(workItem); + List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent(); + long inFlightBytes = initialBudget.bytes() - workItem.getSerializedSize(); + + assertThat(requests).hasSize(expectedRequests); + assertThat(Iterables.getLast(requests).getRequestExtension()) + .isEqualTo( + extension( + GetWorkBudget.builder() + .setItems(1) + .setBytes(initialBudget.bytes() - inFlightBytes) + .build())); + } + + @Test + public void testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>(); + stream = + createGetWorkStream( + testStub, + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(), + (work, watermarks, processingContext, getWorkStreamLatencies) -> { + scheduledWorkItems.add(work); + }); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("somewhat_long_key")) + .setWorkToken(1L) + .setShardingKey(1L) + .setCacheToken(1L) + .build(); + + testStub.injectResponse(createResponse(workItem)); + + waitForRequests.await(5, TimeUnit.SECONDS); + + assertThat(scheduledWorkItems).containsExactly(workItem); + List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent(); + + // Assert that the extension was never sent, only the header. 
+ assertThat(requests).hasSize(expectedRequests); + assertThat(Iterables.getOnlyElement(requests).getRequest()) + .isInstanceOf(Windmill.GetWorkRequest.class); + } + + @Test + public void testOnResponse_stopsThrottling() { + ThrottleTimer throttleTimer = new ThrottleTimer(); + TestGetWorkRequestObserver requestObserver = + new TestGetWorkRequestObserver(new CountDownLatch(1)); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = createGetWorkStream(testStub, GetWorkBudget.noBudget(), throttleTimer); + stream.startThrottleTimer(); + testStub.injectResponse(Windmill.StreamingGetWorkResponseChunk.getDefaultInstance()); Review Comment: Does this run inline? If not, the assertion below that unthrottling happened could be racy. ########## runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java: ########## @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.junit.Assert.assertFalse; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcDirectGetWorkStreamTest { + private static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest"; + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + private GrpcDirectGetWorkStream stream; + + private static Windmill.StreamingGetWorkRequestExtension extension(GetWorkBudget budget) { + return Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(budget.items()) + .setMaxBytes(budget.bytes()) + .build(); + } + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + checkNotNull(stream).shutdown(); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) { + return createGetWorkStream(testStub, initialGetWorkBudget, new ThrottleTimer()); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + 
GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + WorkItemScheduler workItemScheduler) { + return createGetWorkStream( + testStub, initialGetWorkBudget, new ThrottleTimer(), workItemScheduler); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer) { + return createGetWorkStream( + testStub, + initialGetWorkBudget, + throttleTimer, + (workItem, watermarks, processingContext, getWorkStreamLatencies) -> {}); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer, + WorkItemScheduler workItemScheduler) { + serviceRegistry.addService(testStub); + return (GrpcDirectGetWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .build() + .createDirectGetWorkStream( + WindmillConnection.builder() + .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)) + .build(), + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build(), + throttleTimer, + mock(HeartbeatSender.class), + mock(GetDataClient.class), + mock(WorkCommitter.class), + workItemScheduler); + } + + private Windmill.StreamingGetWorkResponseChunk createResponse(Windmill.WorkItem workItem) { + return Windmill.StreamingGetWorkResponseChunk.newBuilder() + .setStreamId(1L) + .setComputationMetadata( + Windmill.ComputationWorkItemMetadata.newBuilder() + .setComputationId("compId") + .setInputDataWatermark(1L) + .setDependentRealtimeInputWatermark(1L) + .build()) + .setSerializedWorkItem(workItem.toByteString()) + .setRemainingBytesForWorkItem(0) + .build(); + } + + @Test + public void 
testSetBudget_computesAndSendsCorrectExtension_noExistingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = createGetWorkStream(testStub, GetWorkBudget.noBudget()); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream.setBudget(newBudget); + + waitForRequests.await(5, TimeUnit.SECONDS); + + // Header and extension. + assertThat(requestObserver.sent()).hasSize(expectedRequests); + assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension()) + .isEqualTo(extension(newBudget)); + } + + @Test + public void testSetBudget_computesAndSendsCorrectExtension_existingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream = createGetWorkStream(testStub, initialBudget); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(100).setBytes(100).build(); + stream.setBudget(newBudget); + GetWorkBudget diff = newBudget.subtract(initialBudget); + + waitForRequests.await(5, TimeUnit.SECONDS); + + List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent(); + // Header and extension. 
+ assertThat(requests).hasSize(expectedRequests); + assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff)); + } + + @Test + public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build()); + stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build()); + + waitForRequests.await(5, TimeUnit.SECONDS); + + List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. + assertThat(requests).hasSize(expectedRequests); + assertThat(Iterables.getOnlyElement(requests).getRequest()) + .isInstanceOf(Windmill.GetWorkRequest.class); + } + + @Test + public void testSetBudget_doesNothingIfStreamShutdown() throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = createGetWorkStream(testStub, GetWorkBudget.noBudget()); + stream.shutdown(); + stream.setBudget( + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build()); + + waitForRequests.await(5, TimeUnit.SECONDS); + + List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. 
+ assertThat(requests).hasSize(1); + assertThat(Iterables.getOnlyElement(requests).getRequest()) + .isInstanceOf(Windmill.GetWorkRequest.class); + } + + @Test + public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(1).setBytes(100).build(); + Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>(); + stream = + createGetWorkStream( + testStub, + initialBudget, + (work, watermarks, processingContext, getWorkStreamLatencies) -> { + scheduledWorkItems.add(work); + }); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("somewhat_long_key")) + .setWorkToken(1L) + .setShardingKey(1L) + .setCacheToken(1L) + .build(); + + testStub.injectResponse(createResponse(workItem)); + + waitForRequests.await(5, TimeUnit.SECONDS); + + assertThat(scheduledWorkItems).containsExactly(workItem); + List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent(); + long inFlightBytes = initialBudget.bytes() - workItem.getSerializedSize(); + + assertThat(requests).hasSize(expectedRequests); + assertThat(Iterables.getLast(requests).getRequestExtension()) + .isEqualTo( + extension( + GetWorkBudget.builder() + .setItems(1) + .setBytes(initialBudget.bytes() - inFlightBytes) + .build())); + } + + @Test + public void testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new 
GetWorkStreamTestStub(requestObserver); + Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>(); + stream = + createGetWorkStream( + testStub, + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(), + (work, watermarks, processingContext, getWorkStreamLatencies) -> { + scheduledWorkItems.add(work); + }); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("somewhat_long_key")) + .setWorkToken(1L) + .setShardingKey(1L) + .setCacheToken(1L) + .build(); + + testStub.injectResponse(createResponse(workItem)); + + waitForRequests.await(5, TimeUnit.SECONDS); + + assertThat(scheduledWorkItems).containsExactly(workItem); + List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent(); + + // Assert that the extension was never sent, only the header. + assertThat(requests).hasSize(expectedRequests); + assertThat(Iterables.getOnlyElement(requests).getRequest()) + .isInstanceOf(Windmill.GetWorkRequest.class); + } + + @Test + public void testOnResponse_stopsThrottling() { + ThrottleTimer throttleTimer = new ThrottleTimer(); + TestGetWorkRequestObserver requestObserver = + new TestGetWorkRequestObserver(new CountDownLatch(1)); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = createGetWorkStream(testStub, GetWorkBudget.noBudget(), throttleTimer); + stream.startThrottleTimer(); + testStub.injectResponse(Windmill.StreamingGetWorkResponseChunk.getDefaultInstance()); + assertFalse(throttleTimer.throttled()); + } + + private static class GetWorkStreamTestStub + extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + private final TestGetWorkRequestObserver requestObserver; Review Comment: might as well mark volatile to prevent races -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: github-unsubscr...@beam.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org