This is an automated email from the ASF dual-hosted git repository. zhuzh pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 76ebeb93257369c136a9eeabd5bacb71a9699968 Author: Zhu Zhu <[email protected]> AuthorDate: Thu Jan 19 17:13:50 2023 +0800 [FLINK-15325][coordination] Ignores the input locations of a ConsumePartitionGroup if the corresponding ConsumerVertexGroup is too large This closes #21743. --- .../runtime/executiongraph/ExecutionGraph.java | 10 ++ .../runtime/executiongraph/ExecutionVertex.java | 2 - .../AvailableInputsLocationsRetriever.java | 12 +- .../DefaultPreferredLocationsRetriever.java | 27 +++- .../flink/runtime/scheduler/DefaultScheduler.java | 12 +- ...tionGraphToInputsLocationsRetrieverAdapter.java | 38 +++--- .../scheduler/InputsLocationsRetriever.java | 26 ++-- .../AvailableInputsLocationsRetrieverTest.java | 20 ++- .../DefaultPreferredLocationsRetrieverTest.java | 144 ++++++++++++++------- ...DefaultSyncPreferredLocationsRetrieverTest.java | 31 +++-- ...GraphToInputsLocationsRetrieverAdapterTest.java | 72 ++++++++--- .../scheduler/TestingInputsLocationsRetriever.java | 84 +++++++++--- .../adaptive/StateTrackingMockExecutionGraph.java | 7 + 13 files changed, 339 insertions(+), 146 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java index 2f90291e75c..cc1f5fa33c6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java @@ -33,6 +33,7 @@ import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor; import org.apache.flink.runtime.executiongraph.failover.flip1.ResultPartitionAvailabilityChecker; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration; @@ -124,6 +125,15 @@ public interface ExecutionGraph extends AccessExecutionGraph { Map<IntermediateDataSetID, IntermediateResult> getAllIntermediateResults(); + /** + * Gets the intermediate result partition by the given partition ID, or throw an exception if + * the partition is not found. + * + * @param id of the intermediate result partition + * @return intermediate result partition + */ + IntermediateResultPartition getResultPartitionOrThrow(final IntermediateResultPartitionID id); + /** * Merges all accumulator results from the tasks previously executed in the Executions. * diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java index 5ff589669b4..28f1d745e37 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java @@ -62,8 +62,6 @@ public class ExecutionVertex public static final long NUM_BYTES_UNKNOWN = -1; - public static final int MAX_DISTINCT_LOCATIONS_TO_CONSIDER = 8; - // -------------------------------------------------------------------------------------------- final ExecutionJobVertex jobVertex; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/AvailableInputsLocationsRetriever.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/AvailableInputsLocationsRetriever.java index 4008b1cef01..2c2a0f44741 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/AvailableInputsLocationsRetriever.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/AvailableInputsLocationsRetriever.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.scheduler; +import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; @@ -34,9 +35,16 @@ class AvailableInputsLocationsRetriever implements InputsLocationsRetriever { } @Override - public Collection<Collection<ExecutionVertexID>> getConsumedResultPartitionsProducers( + public Collection<ConsumedPartitionGroup> getConsumedPartitionGroups( ExecutionVertexID executionVertexId) { - return inputsLocationsRetriever.getConsumedResultPartitionsProducers(executionVertexId); + return inputsLocationsRetriever.getConsumedPartitionGroups(executionVertexId); + } + + @Override + public Collection<ExecutionVertexID> getProducersOfConsumedPartitionGroup( + ConsumedPartitionGroup consumedPartitionGroup) { + return inputsLocationsRetriever.getProducersOfConsumedPartitionGroup( + consumedPartitionGroup); } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultPreferredLocationsRetriever.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultPreferredLocationsRetriever.java index 0f7fdeb1316..f3b71366682 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultPreferredLocationsRetriever.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultPreferredLocationsRetriever.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.scheduler; +import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.util.concurrent.FutureUtils; @@ -30,7 +31,6 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.CompletableFuture; -import static org.apache.flink.runtime.executiongraph.ExecutionVertex.MAX_DISTINCT_LOCATIONS_TO_CONSIDER; import static org.apache.flink.util.Preconditions.checkNotNull; /** @@ -39,6 +39,10 @@ import static org.apache.flink.util.Preconditions.checkNotNull; */ public class DefaultPreferredLocationsRetriever implements PreferredLocationsRetriever { + static final int MAX_DISTINCT_LOCATIONS_TO_CONSIDER = 8; + + static final int MAX_DISTINCT_CONSUMERS_TO_CONSIDER = 8; + private final StateLocationRetriever stateLocationRetriever; private final InputsLocationsRetriever inputsLocationsRetriever; @@ -84,11 +88,24 @@ public class DefaultPreferredLocationsRetriever implements PreferredLocationsRet CompletableFuture<Collection<TaskManagerLocation>> preferredLocations = CompletableFuture.completedFuture(Collections.emptyList()); - final Collection<Collection<ExecutionVertexID>> allProducers = - inputsLocationsRetriever.getConsumedResultPartitionsProducers(executionVertexId); - for (Collection<ExecutionVertexID> producers : allProducers) { + final Collection<ConsumedPartitionGroup> consumedPartitionGroups = + inputsLocationsRetriever.getConsumedPartitionGroups(executionVertexId); + for (ConsumedPartitionGroup consumedPartitionGroup : consumedPartitionGroups) { + // Ignore the location of a consumed partition group if it has too many distinct + // consumers compared to the consumed partition group size. This is to avoid tasks + // unevenly distributed on nodes when running batch jobs or running jobs in + // session/standalone mode. + if ((double) consumedPartitionGroup.getConsumerVertexGroup().size() + / consumedPartitionGroup.size() + > MAX_DISTINCT_CONSUMERS_TO_CONSIDER) { + continue; + } + final Collection<CompletableFuture<TaskManagerLocation>> locationsFutures = - getInputLocationFutures(producersToIgnore, producers); + getInputLocationFutures( + producersToIgnore, + inputsLocationsRetriever.getProducersOfConsumedPartitionGroup( + consumedPartitionGroup)); preferredLocations = combineLocations(preferredLocations, locationsFutures); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java index 2136479a2c7..ecd2a2467cb 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java @@ -44,6 +44,7 @@ import org.apache.flink.runtime.jobmanager.scheduler.CoLocationGroup; import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup; import org.apache.flink.runtime.metrics.groups.JobManagerJobMetricGroup; import org.apache.flink.runtime.scheduler.exceptionhistory.FailureHandlingResultSnapshot; +import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.scheduler.strategy.SchedulingStrategy; import org.apache.flink.runtime.scheduler.strategy.SchedulingStrategyFactory; @@ -511,9 +512,16 @@ public class DefaultScheduler extends SchedulerBase implements SchedulerOperatio } @Override - public Collection<Collection<ExecutionVertexID>> getConsumedResultPartitionsProducers( + public Collection<ConsumedPartitionGroup> getConsumedPartitionGroups( ExecutionVertexID executionVertexId) { - return inputsLocationsRetriever.getConsumedResultPartitionsProducers(executionVertexId); + return inputsLocationsRetriever.getConsumedPartitionGroups(executionVertexId); + } + + @Override + public Collection<ExecutionVertexID> getProducersOfConsumedPartitionGroup( + ConsumedPartitionGroup consumedPartitionGroup) { + return inputsLocationsRetriever.getProducersOfConsumedPartitionGroup( + consumedPartitionGroup); } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapter.java index 0a6786c8ac6..35f3da6c868 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapter.java @@ -22,17 +22,15 @@ import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.executiongraph.ExecutionGraph; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.ExecutionVertex; -import org.apache.flink.runtime.executiongraph.InternalExecutionGraphAccessor; -import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; +import org.apache.flink.util.IterableUtils; -import java.util.ArrayList; import java.util.Collection; -import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; @@ -47,26 +45,22 @@ public class ExecutionGraphToInputsLocationsRetrieverAdapter implements InputsLo } @Override - public Collection<Collection<ExecutionVertexID>> getConsumedResultPartitionsProducers( + public Collection<ConsumedPartitionGroup> getConsumedPartitionGroups( ExecutionVertexID executionVertexId) { - ExecutionVertex ev = getExecutionVertex(executionVertexId); - - InternalExecutionGraphAccessor executionGraphAccessor = ev.getExecutionGraphAccessor(); + return getExecutionVertex(executionVertexId).getAllConsumedPartitionGroups(); + } - List<Collection<ExecutionVertexID>> resultPartitionProducers = - new ArrayList<>(ev.getNumberOfInputs()); - for (ConsumedPartitionGroup consumedPartitions : ev.getAllConsumedPartitionGroups()) { - List<ExecutionVertexID> producers = new ArrayList<>(consumedPartitions.size()); - for (IntermediateResultPartitionID consumedPartitionId : consumedPartitions) { - ExecutionVertex producer = - executionGraphAccessor - .getResultPartitionOrThrow(consumedPartitionId) - .getProducer(); - producers.add(producer.getID()); - } - resultPartitionProducers.add(producers); - } - return resultPartitionProducers; + @Override + public Collection<ExecutionVertexID> getProducersOfConsumedPartitionGroup( + ConsumedPartitionGroup consumedPartitionGroup) { + return IterableUtils.toStream(consumedPartitionGroup) + .map( + partition -> + executionGraph + .getResultPartitionOrThrow(partition) + .getProducer() + .getID()) + .collect(Collectors.toList()); } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/InputsLocationsRetriever.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/InputsLocationsRetriever.java index ea143bb5e59..0c49f4bf5df 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/InputsLocationsRetriever.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/InputsLocationsRetriever.java @@ -18,7 +18,8 @@ package org.apache.flink.runtime.scheduler; -import org.apache.flink.runtime.executiongraph.Execution; +import org.apache.flink.runtime.executiongraph.ExecutionVertex; +import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; @@ -26,22 +27,31 @@ import java.util.Collection; import java.util.Optional; import java.util.concurrent.CompletableFuture; -/** Component to retrieve the inputs locations of a {@link Execution}. */ +/** Component to retrieve the inputs locations of an {@link ExecutionVertex}. */ public interface InputsLocationsRetriever { /** - * Get the producers of the result partitions consumed by an execution. + * Get the consumed result partition groups of an execution vertex. * - * @param executionVertexId identifies the execution - * @return the producers of the result partitions group by job vertex id + * @param executionVertexId identifies the execution vertex + * @return the consumed result partition groups */ - Collection<Collection<ExecutionVertexID>> getConsumedResultPartitionsProducers( + Collection<ConsumedPartitionGroup> getConsumedPartitionGroups( ExecutionVertexID executionVertexId); /** - * Get the task manager location future for an execution. + * Get the producer execution vertices of a consumed result partition group. * - * @param executionVertexId identifying the execution + * @param consumedPartitionGroup the consumed result partition group + * @return the ids of producer execution vertices + */ + Collection<ExecutionVertexID> getProducersOfConsumedPartitionGroup( + ConsumedPartitionGroup consumedPartitionGroup); + + /** + * Get the task manager location future for an execution vertex. + * + * @param executionVertexId identifying the execution vertex * @return the task manager location future */ Optional<CompletableFuture<TaskManagerLocation>> getTaskManagerLocation( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/AvailableInputsLocationsRetrieverTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/AvailableInputsLocationsRetrieverTest.java index d2d5be5b18e..0ebcffdf570 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/AvailableInputsLocationsRetrieverTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/AvailableInputsLocationsRetrieverTest.java @@ -18,8 +18,11 @@ package org.apache.flink.runtime.scheduler; +import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; +import org.apache.flink.shaded.guava30.com.google.common.collect.Iterables; + import org.junit.jupiter.api.Test; import java.util.Collection; @@ -68,15 +71,20 @@ class AvailableInputsLocationsRetrieverTest { } @Test - void testConsumedResultPartitionsProducers() { + void testGetConsumedPartitionGroupAndProducers() { TestingInputsLocationsRetriever originalLocationRetriever = getOriginalLocationRetriever(); InputsLocationsRetriever availableInputsLocationsRetriever = new AvailableInputsLocationsRetriever(originalLocationRetriever); - Collection<Collection<ExecutionVertexID>> producers = - availableInputsLocationsRetriever.getConsumedResultPartitionsProducers(EV2); - assertThat(producers).hasSize(1); - Collection<ExecutionVertexID> resultProducers = producers.iterator().next(); - assertThat(resultProducers).containsExactly(EV1); + + ConsumedPartitionGroup consumedPartitionGroup = + Iterables.getOnlyElement( + (availableInputsLocationsRetriever.getConsumedPartitionGroups(EV2))); + assertThat(consumedPartitionGroup).hasSize(1); + + Collection<ExecutionVertexID> producers = + availableInputsLocationsRetriever.getProducersOfConsumedPartitionGroup( + consumedPartitionGroup); + assertThat(producers).containsExactly(EV1); } private static TestingInputsLocationsRetriever getOriginalLocationRetriever() { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultPreferredLocationsRetrieverTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultPreferredLocationsRetrieverTest.java index 64944dcea00..6fda52dc173 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultPreferredLocationsRetrieverTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultPreferredLocationsRetrieverTest.java @@ -18,7 +18,6 @@ package org.apache.flink.runtime.scheduler; -import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.taskmanager.LocalTaskManagerLocation; @@ -27,12 +26,18 @@ import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.junit.jupiter.api.Test; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import static org.apache.flink.runtime.scheduler.DefaultPreferredLocationsRetriever.MAX_DISTINCT_CONSUMERS_TO_CONSIDER; +import static org.apache.flink.runtime.scheduler.DefaultPreferredLocationsRetriever.MAX_DISTINCT_LOCATIONS_TO_CONSIDER; import static org.assertj.core.api.Assertions.assertThat; /** Tests {@link DefaultPreferredLocationsRetriever}. */ @@ -65,36 +70,38 @@ class DefaultPreferredLocationsRetrieverTest { } @Test - void testInputLocationsIgnoresEdgeOfTooManyLocations() { - final TestingInputsLocationsRetriever.Builder locationRetrieverBuilder = - new TestingInputsLocationsRetriever.Builder(); - - final ExecutionVertexID consumerId = new ExecutionVertexID(new JobVertexID(), 0); - - final int producerParallelism = ExecutionVertex.MAX_DISTINCT_LOCATIONS_TO_CONSIDER + 1; - final List<ExecutionVertexID> producerIds = new ArrayList<>(producerParallelism); - final JobVertexID producerJobVertexId = new JobVertexID(); - for (int i = 0; i < producerParallelism; i++) { - final ExecutionVertexID producerId = new ExecutionVertexID(producerJobVertexId, i); - locationRetrieverBuilder.connectConsumerToProducer(consumerId, producerId); - producerIds.add(producerId); + void testInputLocations() { + { + final List<TaskManagerLocation> producerLocations = + Collections.singletonList(new LocalTaskManagerLocation()); + testInputLocationsInternal( + 1, + MAX_DISTINCT_CONSUMERS_TO_CONSIDER, + producerLocations, + producerLocations, + Collections.emptySet()); } - - final TestingInputsLocationsRetriever inputsLocationsRetriever = - locationRetrieverBuilder.build(); - - for (int i = 0; i < producerParallelism; i++) { - inputsLocationsRetriever.markScheduled(producerIds.get(i)); + { + final List<TaskManagerLocation> producerLocations = + Arrays.asList(new LocalTaskManagerLocation(), new LocalTaskManagerLocation()); + testInputLocationsInternal( + 2, + MAX_DISTINCT_CONSUMERS_TO_CONSIDER * 2, + producerLocations, + producerLocations, + Collections.emptySet()); } + } - final PreferredLocationsRetriever locationsRetriever = - new DefaultPreferredLocationsRetriever( - id -> Optional.empty(), inputsLocationsRetriever); - - final CompletableFuture<Collection<TaskManagerLocation>> preferredLocations = - locationsRetriever.getPreferredLocations(consumerId, Collections.emptySet()); + @Test + void testInputLocationsIgnoresEdgeOfTooManyProducers() { + testNoPreferredInputLocationsInternal(MAX_DISTINCT_LOCATIONS_TO_CONSIDER + 1, 1); + } - assertThat(preferredLocations.getNow(null)).isEmpty(); + @Test + void testInputLocationsIgnoresEdgeOfTooManyConsumers() { + testNoPreferredInputLocationsInternal(1, MAX_DISTINCT_CONSUMERS_TO_CONSIDER + 1); + testNoPreferredInputLocationsInternal(2, MAX_DISTINCT_CONSUMERS_TO_CONSIDER * 2 + 1); } @Test @@ -110,8 +117,8 @@ class DefaultPreferredLocationsRetrieverTest { for (int i = 0; i < parallelism1; i++) { final ExecutionVertexID producerId = new ExecutionVertexID(jobVertexId1, i); producers1.add(producerId); - locationRetrieverBuilder.connectConsumerToProducer(consumerId, producerId); } + locationRetrieverBuilder.connectConsumerToProducers(consumerId, producers1); final JobVertexID jobVertexId2 = new JobVertexID(); int parallelism2 = 5; @@ -119,8 +126,8 @@ class DefaultPreferredLocationsRetrieverTest { for (int i = 0; i < parallelism2; i++) { final ExecutionVertexID producerId = new ExecutionVertexID(jobVertexId2, i); producers2.add(producerId); - locationRetrieverBuilder.connectConsumerToProducer(consumerId, producerId); } + locationRetrieverBuilder.connectConsumerToProducers(consumerId, producers2); final TestingInputsLocationsRetriever inputsLocationsRetriever = locationRetrieverBuilder.build(); @@ -152,40 +159,83 @@ class DefaultPreferredLocationsRetrieverTest { @Test void testInputLocationsIgnoresExcludedProducers() { - final TestingInputsLocationsRetriever.Builder locationRetrieverBuilder = - new TestingInputsLocationsRetriever.Builder(); + final List<TaskManagerLocation> producerLocations = + Arrays.asList(new LocalTaskManagerLocation(), new LocalTaskManagerLocation()); + final Set<Integer> producersToIgnore = Collections.singleton(0); + testInputLocationsInternal( + 2, 1, producerLocations, producerLocations.subList(1, 2), producersToIgnore); + } - final ExecutionVertexID consumerId = new ExecutionVertexID(new JobVertexID(), 0); + private void testNoPreferredInputLocationsInternal( + final int producerParallelism, final int consumerParallelism) { + testInputLocationsInternal( + producerParallelism, + consumerParallelism, + Collections.emptyList(), + Collections.emptyList(), + Collections.emptySet()); + } + + private void testInputLocationsInternal( + final int producerParallelism, + final int consumerParallelism, + final List<TaskManagerLocation> producerLocations, + final List<TaskManagerLocation> expectedPreferredLocations, + final Set<Integer> indicesOfProducersToIgnore) { final JobVertexID producerJobVertexId = new JobVertexID(); + final List<ExecutionVertexID> producerIds = + IntStream.range(0, producerParallelism) + .mapToObj(i -> new ExecutionVertexID(producerJobVertexId, i)) + .collect(Collectors.toList()); - final ExecutionVertexID producerId1 = new ExecutionVertexID(producerJobVertexId, 0); - locationRetrieverBuilder.connectConsumerToProducer(consumerId, producerId1); + final JobVertexID consumerJobVertexId = new JobVertexID(); + final List<ExecutionVertexID> consumerIds = + IntStream.range(0, consumerParallelism) + .mapToObj(i -> new ExecutionVertexID(consumerJobVertexId, i)) + .collect(Collectors.toList()); - final ExecutionVertexID producerId2 = new ExecutionVertexID(producerJobVertexId, 1); - locationRetrieverBuilder.connectConsumerToProducer(consumerId, producerId2); + final TestingInputsLocationsRetriever.Builder locationRetrieverBuilder = + new TestingInputsLocationsRetriever.Builder(); + locationRetrieverBuilder.connectConsumersToProducers(consumerIds, producerIds); final TestingInputsLocationsRetriever inputsLocationsRetriever = locationRetrieverBuilder.build(); + for (int i = 0; i < producerParallelism; i++) { + TaskManagerLocation producerLocation; + if (producerLocations.isEmpty()) { + // generate a random location if not specified + producerLocation = new LocalTaskManagerLocation(); + } else { + producerLocation = producerLocations.get(i); + } + inputsLocationsRetriever.assignTaskManagerLocation( + producerIds.get(i), producerLocation); + } - inputsLocationsRetriever.markScheduled(producerId1); - inputsLocationsRetriever.markScheduled(producerId2); + checkInputLocations( + consumerIds.get(0), + inputsLocationsRetriever, + expectedPreferredLocations, + indicesOfProducersToIgnore.stream() + .map(index -> new ExecutionVertexID(producerJobVertexId, index)) + .collect(Collectors.toSet())); + } - inputsLocationsRetriever.assignTaskManagerLocation(producerId1); - inputsLocationsRetriever.assignTaskManagerLocation(producerId2); + private void checkInputLocations( + final ExecutionVertexID consumerId, + final TestingInputsLocationsRetriever inputsLocationsRetriever, + final List<TaskManagerLocation> expectedPreferredLocations, + final Set<ExecutionVertexID> producersToIgnore) { final PreferredLocationsRetriever locationsRetriever = new DefaultPreferredLocationsRetriever( id -> Optional.empty(), inputsLocationsRetriever); final CompletableFuture<Collection<TaskManagerLocation>> preferredLocations = - locationsRetriever.getPreferredLocations( - consumerId, Collections.singleton(producerId1)); + locationsRetriever.getPreferredLocations(consumerId, producersToIgnore); - assertThat(preferredLocations.getNow(null)).hasSize(1); - - final TaskManagerLocation producerLocation2 = - inputsLocationsRetriever.getTaskManagerLocation(producerId2).get().getNow(null); - assertThat(preferredLocations.getNow(null)).containsExactly(producerLocation2); + assertThat(preferredLocations.getNow(null)) + .containsExactlyInAnyOrderElementsOf(expectedPreferredLocations); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSyncPreferredLocationsRetrieverTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSyncPreferredLocationsRetrieverTest.java index 43a61ae37fb..db7469d1001 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSyncPreferredLocationsRetrieverTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSyncPreferredLocationsRetrieverTest.java @@ -18,49 +18,48 @@ package org.apache.flink.runtime.scheduler; +import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.junit.jupiter.api.Test; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.Optional; -import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.createRandomExecutionVertexId; import static org.assertj.core.api.Assertions.assertThat; /** Tests for {@link DefaultSyncPreferredLocationsRetriever}. */ class DefaultSyncPreferredLocationsRetrieverTest { - private static final ExecutionVertexID EV1 = createRandomExecutionVertexId(); - private static final ExecutionVertexID EV2 = createRandomExecutionVertexId(); - private static final ExecutionVertexID EV3 = createRandomExecutionVertexId(); - private static final ExecutionVertexID EV4 = createRandomExecutionVertexId(); - private static final ExecutionVertexID EV5 = createRandomExecutionVertexId(); + private static final JobVertexID JV1 = new JobVertexID(); + private static final ExecutionVertexID EV11 = new ExecutionVertexID(JV1, 0); + private static final ExecutionVertexID EV12 = new ExecutionVertexID(JV1, 1); + private static final ExecutionVertexID EV13 = new ExecutionVertexID(JV1, 2); + private static final ExecutionVertexID EV14 = new ExecutionVertexID(JV1, 3); + private static final ExecutionVertexID EV21 = new ExecutionVertexID(new JobVertexID(), 0); @Test void testAvailableInputLocationRetrieval() { TestingInputsLocationsRetriever originalLocationRetriever = new TestingInputsLocationsRetriever.Builder() - .connectConsumerToProducer(EV5, EV1) - .connectConsumerToProducer(EV5, EV2) - .connectConsumerToProducer(EV5, EV3) - .connectConsumerToProducer(EV5, EV4) + .connectConsumerToProducers(EV21, Arrays.asList(EV11, EV12, EV13, EV14)) .build(); - originalLocationRetriever.assignTaskManagerLocation(EV1); - originalLocationRetriever.markScheduled(EV2); - originalLocationRetriever.failTaskManagerLocation(EV3, new Throwable()); - originalLocationRetriever.cancelTaskManagerLocation(EV4); + originalLocationRetriever.assignTaskManagerLocation(EV11); + originalLocationRetriever.markScheduled(EV12); + originalLocationRetriever.failTaskManagerLocation(EV13, new Throwable()); + originalLocationRetriever.cancelTaskManagerLocation(EV14); SyncPreferredLocationsRetriever locationsRetriever = new DefaultSyncPreferredLocationsRetriever( executionVertexId -> Optional.empty(), originalLocationRetriever); Collection<TaskManagerLocation> preferredLocations = - locationsRetriever.getPreferredLocations(EV5, Collections.emptySet()); + locationsRetriever.getPreferredLocations(EV21, Collections.emptySet()); TaskManagerLocation expectedLocation = - originalLocationRetriever.getTaskManagerLocation(EV1).get().join(); + originalLocationRetriever.getTaskManagerLocation(EV11).get().join(); assertThat(preferredLocations).containsExactly(expectedLocation); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapterTest.java index f01fb65248c..066cf9a0035 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapterTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapterTest.java @@ -24,23 +24,27 @@ import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.runtime.jobgraph.IntermediateDataSet; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobmaster.TestingLogicalSlot; import org.apache.flink.runtime.jobmaster.TestingLogicalSlotBuilder; +import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.testutils.TestingUtils; import org.apache.flink.testutils.executor.TestExecutorExtension; +import org.apache.flink.util.IterableUtils; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import java.util.Collection; -import java.util.Collections; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ScheduledExecutorService; +import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -52,16 +56,23 @@ class ExecutionGraphToInputsLocationsRetrieverAdapterTest { static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_EXTENSION = TestingUtils.defaultExecutorExtension(); - /** Tests that can get the producers of consumed result partitions. */ @Test - void testGetConsumedResultPartitionsProducers() throws Exception { + void testGetConsumedPartitionGroupsAndProducers() throws Exception { final JobVertex producer1 = ExecutionGraphTestUtils.createNoOpVertex(1); final JobVertex producer2 = ExecutionGraphTestUtils.createNoOpVertex(1); final JobVertex consumer = ExecutionGraphTestUtils.createNoOpVertex(1); - consumer.connectNewDataSetAsInput( - producer1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED); - consumer.connectNewDataSetAsInput( - producer2, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED); + final IntermediateDataSet dataSet1 = + consumer.connectNewDataSetAsInput( + producer1, + DistributionPattern.ALL_TO_ALL, + ResultPartitionType.PIPELINED) + .getSource(); + final IntermediateDataSet dataSet2 = + consumer.connectNewDataSetAsInput( + producer2, + DistributionPattern.ALL_TO_ALL, + ResultPartitionType.PIPELINED) + .getSource(); final ExecutionGraph eg = ExecutionGraphTestUtils.createExecutionGraph( @@ -73,20 +84,39 @@ class ExecutionGraphToInputsLocationsRetrieverAdapterTest { ExecutionVertexID evIdOfProducer2 = new ExecutionVertexID(producer2.getID(), 0); ExecutionVertexID evIdOfConsumer = new ExecutionVertexID(consumer.getID(), 0); - Collection<Collection<ExecutionVertexID>> producersOfProducer1 = - inputsLocationsRetriever.getConsumedResultPartitionsProducers(evIdOfProducer1); - Collection<Collection<ExecutionVertexID>> producersOfProducer2 = - inputsLocationsRetriever.getConsumedResultPartitionsProducers(evIdOfProducer2); - Collection<Collection<ExecutionVertexID>> producersOfConsumer = - inputsLocationsRetriever.getConsumedResultPartitionsProducers(evIdOfConsumer); - - assertThat(producersOfProducer1).isEmpty(); - assertThat(producersOfProducer2).isEmpty(); - assertThat(producersOfConsumer).hasSize(2); - assertThat(producersOfConsumer) - .containsExactlyInAnyOrder( - Collections.singletonList(evIdOfProducer1), - Collections.singletonList(evIdOfProducer2)); + Collection<ConsumedPartitionGroup> consumedPartitionGroupsOfProducer1 = + inputsLocationsRetriever.getConsumedPartitionGroups(evIdOfProducer1); + Collection<ConsumedPartitionGroup> consumedPartitionGroupsOfProducer2 = + inputsLocationsRetriever.getConsumedPartitionGroups(evIdOfProducer2); + Collection<ConsumedPartitionGroup> consumedPartitionGroupsOfConsumer = + inputsLocationsRetriever.getConsumedPartitionGroups(evIdOfConsumer); + + IntermediateResultPartitionID partitionId1 = + new IntermediateResultPartitionID(dataSet1.getId(), 0); + IntermediateResultPartitionID partitionId2 = + new IntermediateResultPartitionID(dataSet2.getId(), 0); + assertThat(consumedPartitionGroupsOfProducer1).isEmpty(); + assertThat(consumedPartitionGroupsOfProducer2).isEmpty(); + assertThat(consumedPartitionGroupsOfConsumer).hasSize(2); + assertThat( + consumedPartitionGroupsOfConsumer.stream() + .flatMap(IterableUtils::toStream) + .collect(Collectors.toSet())) + .containsExactlyInAnyOrder(partitionId1, partitionId2); + + for (ConsumedPartitionGroup consumedPartitionGroup : consumedPartitionGroupsOfConsumer) { + if (consumedPartitionGroup.getFirst().equals(partitionId1)) { + assertThat( + inputsLocationsRetriever.getProducersOfConsumedPartitionGroup( + consumedPartitionGroup)) + .containsExactly(evIdOfProducer1); + } else { + assertThat( + inputsLocationsRetriever.getProducersOfConsumedPartitionGroup( + consumedPartitionGroup)) + .containsExactly(evIdOfProducer2); + } + } } /** Tests that it will get empty task manager location if vertex is not scheduled. */ diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/TestingInputsLocationsRetriever.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/TestingInputsLocationsRetriever.java index 977a53a2a05..139bbc8b856 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/TestingInputsLocationsRetriever.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/TestingInputsLocationsRetriever.java @@ -18,10 +18,15 @@ package org.apache.flink.runtime.scheduler; -import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; +import org.apache.flink.runtime.scheduler.strategy.TestingSchedulingTopology; import org.apache.flink.runtime.taskmanager.LocalTaskManagerLocation; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; +import org.apache.flink.util.IterableUtils; import java.util.ArrayList; import java.util.Collection; @@ -33,28 +38,40 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; +import static org.apache.flink.runtime.scheduler.strategy.TestingSchedulingTopology.connectConsumersToProducersById; + /** A simple inputs locations retriever for testing purposes. */ class TestingInputsLocationsRetriever implements InputsLocationsRetriever { - private final Map<ExecutionVertexID, List<ExecutionVertexID>> producersByConsumer; + private final Map<ExecutionVertexID, Collection<ConsumedPartitionGroup>> + vertexToConsumedPartitionGroups; + + private final Map<IntermediateResultPartitionID, ExecutionVertexID> partitionToProducer; private final Map<ExecutionVertexID, CompletableFuture<TaskManagerLocation>> taskManagerLocationsByVertex = new HashMap<>(); TestingInputsLocationsRetriever( - final Map<ExecutionVertexID, List<ExecutionVertexID>> producersByConsumer) { - this.producersByConsumer = new HashMap<>(producersByConsumer); + final Map<ExecutionVertexID, Collection<ConsumedPartitionGroup>> + vertexToConsumedPartitionGroups, + final Map<IntermediateResultPartitionID, ExecutionVertexID> partitionToProducer) { + + this.vertexToConsumedPartitionGroups = vertexToConsumedPartitionGroups; + this.partitionToProducer = partitionToProducer; } @Override - public Collection<Collection<ExecutionVertexID>> getConsumedResultPartitionsProducers( + public Collection<ConsumedPartitionGroup> getConsumedPartitionGroups( final ExecutionVertexID executionVertexId) { - final Map<JobVertexID, List<ExecutionVertexID>> executionVerticesByJobVertex = - producersByConsumer.getOrDefault(executionVertexId, Collections.emptyList()) - .stream() - .collect(Collectors.groupingBy(ExecutionVertexID::getJobVertexId)); + return vertexToConsumedPartitionGroups.get(executionVertexId); + } - return new ArrayList<>(executionVerticesByJobVertex.values()); + @Override + public Collection<ExecutionVertexID> getProducersOfConsumedPartitionGroup( + ConsumedPartitionGroup consumedPartitionGroup) { + return IterableUtils.toStream(consumedPartitionGroup) + .map(partitionToProducer::get) + .collect(Collectors.toList()); } @Override @@ -68,13 +85,18 @@ class TestingInputsLocationsRetriever implements InputsLocationsRetriever { } public void assignTaskManagerLocation(final ExecutionVertexID executionVertexId) { + assignTaskManagerLocation(executionVertexId, new LocalTaskManagerLocation()); + } + + public void assignTaskManagerLocation( + final ExecutionVertexID executionVertexId, TaskManagerLocation location) { taskManagerLocationsByVertex.compute( executionVertexId, (key, future) -> { if (future == null) { - return CompletableFuture.completedFuture(new LocalTaskManagerLocation()); + return CompletableFuture.completedFuture(location); } - future.complete(new LocalTaskManagerLocation()); + future.complete(location); return future; }); } @@ -107,17 +129,49 @@ class TestingInputsLocationsRetriever implements InputsLocationsRetriever { static class Builder { - private final Map<ExecutionVertexID, List<ExecutionVertexID>> producersByConsumer = + private final Map<ExecutionVertexID, Collection<ConsumedPartitionGroup>> + vertexToConsumedPartitionGroups = new HashMap<>(); + + private final Map<IntermediateResultPartitionID, ExecutionVertexID> partitionToProducer = new HashMap<>(); public Builder connectConsumerToProducer( final ExecutionVertexID consumer, final ExecutionVertexID producer) { - producersByConsumer.computeIfAbsent(consumer, (key) -> new ArrayList<>()).add(producer); + return connectConsumerToProducers(consumer, Collections.singletonList(producer)); + } + + public Builder connectConsumerToProducers( + final ExecutionVertexID consumer, final List<ExecutionVertexID> producers) { + return connectConsumersToProducers(Collections.singletonList(consumer), producers); + } + + public Builder connectConsumersToProducers( + final List<ExecutionVertexID> consumers, final List<ExecutionVertexID> producers) { + TestingSchedulingTopology.ConnectionResult connectionResult = + connectConsumersToProducersById( + consumers, + producers, + new IntermediateDataSetID(), + ResultPartitionType.PIPELINED); + + for (int i = 0; i < producers.size(); i++) { + partitionToProducer.put( + connectionResult.getResultPartitions().get(i), producers.get(i)); + } + + for (ExecutionVertexID consumer : consumers) { + final Collection<ConsumedPartitionGroup> consumedPartitionGroups = + vertexToConsumedPartitionGroups.computeIfAbsent( + consumer, ignore -> new ArrayList<>()); + consumedPartitionGroups.add(connectionResult.getConsumedPartitionGroup()); + } + return this; } public TestingInputsLocationsRetriever build() { - return new TestingInputsLocationsRetriever(producersByConsumer); + return new TestingInputsLocationsRetriever( + vertexToConsumedPartitionGroups, partitionToProducer); } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/StateTrackingMockExecutionGraph.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/StateTrackingMockExecutionGraph.java index 639c846cbec..c87e3473ab3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/StateTrackingMockExecutionGraph.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/StateTrackingMockExecutionGraph.java @@ -43,11 +43,13 @@ import org.apache.flink.runtime.executiongraph.ExecutionGraph; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.executiongraph.IntermediateResult; +import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; import org.apache.flink.runtime.executiongraph.JobStatusListener; import org.apache.flink.runtime.executiongraph.JobVertexInputInfo; import org.apache.flink.runtime.executiongraph.TaskExecutionStateTransition; import org.apache.flink.runtime.executiongraph.failover.flip1.ResultPartitionAvailabilityChecker; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration; @@ -314,6 +316,11 @@ class StateTrackingMockExecutionGraph implements ExecutionGraph { throw new UnsupportedOperationException(); } + @Override + public IntermediateResultPartition getResultPartitionOrThrow(IntermediateResultPartitionID id) { + throw new UnsupportedOperationException(); + } + @Override public Map<String, OptionalFailure<Accumulator<?, ?>>> aggregateUserAccumulators() { throw new UnsupportedOperationException();
