Repository: samza Updated Branches: refs/heads/master 4915baac5 -> d2c9e8162
SAMZA-1889: Extend ExecutionPlanner to support Stream-Table Joins Extend ExecutionPlanner to verify agreement in partition count among the stream(s) behind Tables, including side-input streams, and other streams participating in Stream-Table Joins. Author: Ahmed Abdul Hamid <[email protected]> Author: Ahmed Elbahtemy <[email protected]> Author: Ahmed Abdul Hamid <[email protected]> Reviewers: Xinyu Liu <[email protected]> Closes #665 from ahmedahamid/dev/ahabdulh/extend-exec-planner Project: http://git-wip-us.apache.org/repos/asf/samza/repo Commit: http://git-wip-us.apache.org/repos/asf/samza/commit/d2c9e816 Tree: http://git-wip-us.apache.org/repos/asf/samza/tree/d2c9e816 Diff: http://git-wip-us.apache.org/repos/asf/samza/diff/d2c9e816 Branch: refs/heads/master Commit: d2c9e81626539016756c3a93876c9f079b77e0f4 Parents: 4915baa Author: Ahmed Abdul Hamid <[email protected]> Authored: Wed Oct 10 14:48:14 2018 -0700 Committer: xiliu <[email protected]> Committed: Wed Oct 10 14:48:14 2018 -0700 ---------------------------------------------------------------------- .../samza/execution/ExecutionPlanner.java | 144 ++++++- .../execution/IntermediateStreamManager.java | 253 +++--------- .../org/apache/samza/execution/JobGraph.java | 34 +- .../execution/OperatorSpecGraphAnalyzer.java | 134 ++++++- .../execution/ExecutionPlannerTestBase.java | 2 +- .../samza/execution/TestExecutionPlanner.java | 399 +++++++++++++++---- .../TestIntermediateStreamManager.java | 68 ---- 7 files changed, 663 insertions(+), 371 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/samza/blob/d2c9e816/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java b/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java index eea6387..b80f7df 100644 --- a/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java +++ b/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java @@ -20,14 +20,20 @@ package org.apache.samza.execution; import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; import com.google.common.collect.Multimap; import com.google.common.collect.Sets; +import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import org.apache.commons.collections4.ListUtils; import org.apache.samza.SamzaException; import org.apache.samza.application.ApplicationDescriptor; import org.apache.samza.application.ApplicationDescriptorImpl; @@ -38,12 +44,16 @@ import org.apache.samza.config.Config; import org.apache.samza.config.JobConfig; import org.apache.samza.config.StreamConfig; import org.apache.samza.operators.BaseTableDescriptor; +import org.apache.samza.operators.spec.InputOperatorSpec; +import org.apache.samza.operators.spec.OperatorSpec; +import org.apache.samza.operators.spec.StreamTableJoinOperatorSpec; import org.apache.samza.system.StreamSpec; import org.apache.samza.table.TableSpec; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import static org.apache.samza.util.StreamUtil.*; +import static org.apache.samza.util.StreamUtil.getStreamSpec; +import static
org.apache.samza.util.StreamUtil.getStreamSpecs; /** @@ -56,23 +66,33 @@ public class ExecutionPlanner { private final Config config; private final StreamManager streamManager; + private final StreamConfig streamConfig; public ExecutionPlanner(Config config, StreamManager streamManager) { this.config = config; this.streamManager = streamManager; + this.streamConfig = new StreamConfig(config); } public ExecutionPlan plan(ApplicationDescriptorImpl<? extends ApplicationDescriptor> appDesc) { validateConfig(); - // create physical job graph based on stream graph - JobGraph jobGraph = createJobGraph(config, appDesc); + // Create physical job graph based on stream graph + JobGraph jobGraph = createJobGraph(appDesc); - // fetch the external streams partition info - setInputAndOutputStreamPartitionCount(jobGraph, streamManager); + // Fetch the external streams partition info + setInputAndOutputStreamPartitionCount(jobGraph); - // figure out the partitions for internal streams - new IntermediateStreamManager(config, appDesc).calculatePartitions(jobGraph); + // Group streams participating in joins together into sets + List<StreamSet> joinedStreamSets = groupJoinedStreams(jobGraph); + + // Set partitions of intermediate streams if any + if (!jobGraph.getIntermediateStreamEdges().isEmpty()) { + new IntermediateStreamManager(config).calculatePartitions(jobGraph, joinedStreamSets); + } + + // Verify every group of joined streams has the same partition count + joinedStreamSets.forEach(ExecutionPlanner::validatePartitions); return jobGraph; } @@ -88,12 +108,11 @@ public class ExecutionPlanner { } /** - * Create the physical graph from {@link ApplicationDescriptorImpl} + * Creates the physical graph from {@link ApplicationDescriptorImpl} */ /* package private */ - JobGraph createJobGraph(Config config, ApplicationDescriptorImpl<? extends ApplicationDescriptor> appDesc) { + JobGraph createJobGraph(ApplicationDescriptorImpl<? extends ApplicationDescriptor> appDesc) { JobGraph jobGraph = new JobGraph(config, appDesc); - StreamConfig streamConfig = new StreamConfig(config); // Source streams contain both input and intermediate streams. Set<StreamSpec> sourceStreams = getStreamSpecs(appDesc.getInputStreamIds(), streamConfig); // Sink streams contain both output and intermediate streams. @@ -106,7 +125,7 @@ public class ExecutionPlanner { Set<TableSpec> tables = appDesc.getTableDescriptors().stream() .map(tableDescriptor -> ((BaseTableDescriptor) tableDescriptor).getTableSpec()).collect(Collectors.toSet()); - // For this phase, we have a single job node for the whole dag + // For this phase, we have a single job node for the whole DAG String jobName = config.get(JobConfig.JOB_NAME()); String jobId = config.get(JobConfig.JOB_ID(), "1"); JobNode node = jobGraph.getOrCreateJobNode(jobName, jobId); @@ -121,7 +140,14 @@ public class ExecutionPlanner { intermediateStreams.forEach(spec -> jobGraph.addIntermediateStream(spec, node, node)); // Add tables - tables.forEach(spec -> jobGraph.addTable(spec, node)); + for (TableSpec table : tables) { + jobGraph.addTable(table, node); + // Add side-input streams (if any) + Iterable<String> sideInputs = ListUtils.emptyIfNull(table.getSideInputs()); + for (String sideInput : sideInputs) { + jobGraph.addSideInputStream(getStreamSpec(sideInput, streamConfig)); + } + } if (!LegacyTaskApplication.class.isAssignableFrom(appDesc.getAppClass())) { // skip the validation when input streamIds are empty. 
This is only possible for LegacyTaskApplication @@ -132,13 +158,12 @@ public class ExecutionPlanner { } /** - * Fetch the partitions of source/sink streams and update the StreamEdges. - * @param jobGraph {@link JobGraph} - * @param streamManager the {@link StreamManager} to interface with the streams. + * Fetches the partitions of input, side-input, and output streams and updates their corresponding StreamEdges. */ - /* package private */ static void setInputAndOutputStreamPartitionCount(JobGraph jobGraph, StreamManager streamManager) { + /* package private */ void setInputAndOutputStreamPartitionCount(JobGraph jobGraph) { Set<StreamEdge> existingStreams = new HashSet<>(); existingStreams.addAll(jobGraph.getInputStreams()); + existingStreams.addAll(jobGraph.getSideInputStreams()); existingStreams.addAll(jobGraph.getOutputStreams()); // System to StreamEdges @@ -152,7 +177,7 @@ public class ExecutionPlanner { // Fetch partition count for every set of StreamEdges belonging to a particular system. for (String system : systemToStreamEdges.keySet()) { - Collection<StreamEdge> streamEdges = systemToStreamEdges.get(system); + Iterable<StreamEdge> streamEdges = systemToStreamEdges.get(system); // Map every stream to its corresponding StreamEdge so we can retrieve a StreamEdge given its stream. Map<String, StreamEdge> streamToStreamEdge = new HashMap<>(); @@ -174,4 +199,89 @@ public class ExecutionPlanner { } } + /** + * Groups streams participating in joins together. + */ + private static List<StreamSet> groupJoinedStreams(JobGraph jobGraph) { + // Group input operator specs (input/intermediate streams) by the joins they participate in. + Multimap<OperatorSpec, InputOperatorSpec> joinOpSpecToInputOpSpecs = + OperatorSpecGraphAnalyzer.getJoinToInputOperatorSpecs( + jobGraph.getApplicationDescriptorImpl().getInputOperators().values()); + + // Convert every group of input operator specs into a group of corresponding stream edges. + List<StreamSet> streamSets = new ArrayList<>(); + for (OperatorSpec joinOpSpec : joinOpSpecToInputOpSpecs.keySet()) { + Collection<InputOperatorSpec> joinedInputOpSpecs = joinOpSpecToInputOpSpecs.get(joinOpSpec); + StreamSet streamSet = getStreamSet(joinOpSpec.getOpId(), joinedInputOpSpecs, jobGraph); + + // If current join is a stream-table join, add the stream edges corresponding to side-input + // streams associated with the joined table (if any). + if (joinOpSpec instanceof StreamTableJoinOperatorSpec) { + StreamTableJoinOperatorSpec streamTableJoinOperatorSpec = (StreamTableJoinOperatorSpec) joinOpSpec; + + Collection<String> sideInputs = ListUtils.emptyIfNull(streamTableJoinOperatorSpec.getTableSpec().getSideInputs()); + Iterable<StreamEdge> sideInputStreams = sideInputs.stream().map(jobGraph::getStreamEdge)::iterator; + Iterable<StreamEdge> streams = streamSet.getStreamEdges(); + streamSet = new StreamSet(streamSet.getSetId(), Iterables.concat(streams, sideInputStreams)); + } + + streamSets.add(streamSet); + } + + return Collections.unmodifiableList(streamSets); + } + + /** + * Creates a {@link StreamSet} whose Id is {@code setId}, and {@link StreamEdge}s + * correspond to the provided {@code inputOpSpecs}. 
+ */ + private static StreamSet getStreamSet(String setId, Iterable<InputOperatorSpec> inputOpSpecs, JobGraph jobGraph) { + Set<StreamEdge> streamEdges = new HashSet<>(); + for (InputOperatorSpec inputOpSpec : inputOpSpecs) { + StreamEdge streamEdge = jobGraph.getStreamEdge(inputOpSpec.getStreamId()); + streamEdges.add(streamEdge); + } + return new StreamSet(setId, streamEdges); + } + + /** + * Verifies all {@link StreamEdge}s in the supplied {@code streamSet} agree in + * partition count, or throws. + */ + private static void validatePartitions(StreamSet streamSet) { + Collection<StreamEdge> streamEdges = streamSet.getStreamEdges(); + StreamEdge referenceStreamEdge = streamEdges.stream().findFirst().get(); + int referencePartitions = referenceStreamEdge.getPartitionCount(); + + for (StreamEdge streamEdge : streamEdges) { + int partitions = streamEdge.getPartitionCount(); + if (partitions != referencePartitions) { + throw new SamzaException(String.format( + "Unable to resolve input partitions of stream %s for the join %s. Expected: %d, Actual: %d", + referenceStreamEdge.getName(), streamSet.getSetId(), referencePartitions, partitions)); + } + } + } + + /** + * Represents a set of {@link StreamEdge}s. + */ + /* package private */ static class StreamSet { + + private final String setId; + private final Set<StreamEdge> streamEdges; + + StreamSet(String setId, Iterable<StreamEdge> streamEdges) { + this.setId = setId; + this.streamEdges = ImmutableSet.copyOf(streamEdges); + } + + Set<StreamEdge> getStreamEdges() { + return Collections.unmodifiableSet(streamEdges); + } + + String getSetId() { + return setId; + } + } } http://git-wip-us.apache.org/repos/asf/samza/blob/d2c9e816/samza-core/src/main/java/org/apache/samza/execution/IntermediateStreamManager.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/execution/IntermediateStreamManager.java b/samza-core/src/main/java/org/apache/samza/execution/IntermediateStreamManager.java index 66cbe6a..64fc7b3 100644 --- a/samza-core/src/main/java/org/apache/samza/execution/IntermediateStreamManager.java +++ b/samza-core/src/main/java/org/apache/samza/execution/IntermediateStreamManager.java @@ -20,25 +20,21 @@ package org.apache.samza.execution; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.HashMultimap; import com.google.common.collect.Multimap; -import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; -import java.util.Comparator; import java.util.HashSet; -import java.util.List; -import java.util.Map; +import java.util.Optional; import java.util.Set; import org.apache.samza.SamzaException; -import org.apache.samza.application.ApplicationDescriptor; -import org.apache.samza.application.ApplicationDescriptorImpl; import org.apache.samza.config.Config; import org.apache.samza.config.JobConfig; -import org.apache.samza.operators.spec.InputOperatorSpec; -import org.apache.samza.operators.spec.JoinOperatorSpec; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static org.apache.samza.execution.ExecutionPlanner.StreamSet; + + /** * {@link IntermediateStreamManager} calculates intermediate stream partitions based on the high-level application graph. 
*/ @@ -47,107 +43,31 @@ class IntermediateStreamManager { private static final Logger log = LoggerFactory.getLogger(IntermediateStreamManager.class); private final Config config; - private final Map<String, InputOperatorSpec> inputOperators; @VisibleForTesting static final int MAX_INFERRED_PARTITIONS = 256; - IntermediateStreamManager(Config config, ApplicationDescriptorImpl<? extends ApplicationDescriptor> appDesc) { + IntermediateStreamManager(Config config) { this.config = config; - this.inputOperators = appDesc.getInputOperators(); - } - - /** - * Figure out the number of partitions of all streams - */ - /* package private */ void calculatePartitions(JobGraph jobGraph) { - - // Verify agreement in partition count between all joined input/intermediate streams - validateJoinInputStreamPartitions(jobGraph); - - if (!jobGraph.getIntermediateStreamEdges().isEmpty()) { - // Set partition count of intermediate streams not participating in joins - setIntermediateStreamPartitions(jobGraph); - - // Validate partition counts were assigned for all intermediate streams - validateIntermediateStreamPartitions(jobGraph); - } } /** - * Validates agreement in partition count between input/intermediate streams participating in join operations. + * Calculates the number of partitions of all intermediate streams */ - private void validateJoinInputStreamPartitions(JobGraph jobGraph) { - // Group input operator specs (input/intermediate streams) by the joins they participate in. - Multimap<JoinOperatorSpec, InputOperatorSpec> joinOpSpecToInputOpSpecs = - OperatorSpecGraphAnalyzer.getJoinToInputOperatorSpecs(inputOperators.values()); + /* package private */ void calculatePartitions(JobGraph jobGraph, Collection<StreamSet> joinedStreamSets) { - // Convert every group of input operator specs into a group of corresponding stream edges. - List<StreamEdgeSet> streamEdgeSets = new ArrayList<>(); - for (JoinOperatorSpec joinOpSpec : joinOpSpecToInputOpSpecs.keySet()) { - Collection<InputOperatorSpec> joinedInputOpSpecs = joinOpSpecToInputOpSpecs.get(joinOpSpec); - StreamEdgeSet streamEdgeSet = getStreamEdgeSet(joinOpSpec.getOpId(), joinedInputOpSpecs, jobGraph); - streamEdgeSets.add(streamEdgeSet); - } + // Set partition count of intermediate streams participating in joins + setJoinedIntermediateStreamPartitions(joinedStreamSets); - /* - * Sort the stream edge groups by their category so they appear in this order: - * 1. groups composed exclusively of stream edges with set partition counts - * 2. groups composed of a mix of stream edges with set/unset partition counts - * 3. groups composed exclusively of stream edges with unset partition counts - * - * This guarantees that we process the most constrained stream edge groups first, - * which is crucial for intermediate stream edges that are members of multiple - * stream edge groups. For instance, if we have the following groups of stream - * edges (partition counts in parentheses, question marks for intermediate streams): - * - * a. e1 (16), e2 (16) - * b. e2 (16), e3 (?) - * c. e3 (?), e4 (?) - * - * processing them in the above order (most constrained first) is guaranteed to - * yield correct assignment of partition counts of e3 and e4 in a single scan. - */ - Collections.sort(streamEdgeSets, Comparator.comparingInt(e -> e.getCategory().getSortOrder())); + // Set partition count of intermediate streams not participating in joins + setIntermediateStreamPartitions(jobGraph); - // Verify agreement between joined input/intermediate streams. 
- // This may involve setting partition counts of intermediate stream edges. - streamEdgeSets.forEach(IntermediateStreamManager::validateAndAssignStreamEdgeSetPartitions); + // Validate partition counts were assigned for all intermediate streams + validateIntermediateStreamPartitions(jobGraph); } /** - * Creates a {@link StreamEdgeSet} whose Id is {@code setId}, and {@link StreamEdge}s - * correspond to the provided {@code inputOpSpecs}. - */ - private StreamEdgeSet getStreamEdgeSet(String setId, Iterable<InputOperatorSpec> inputOpSpecs, - JobGraph jobGraph) { - - int countStreamEdgeWithSetPartitions = 0; - Set<StreamEdge> streamEdges = new HashSet<>(); - - for (InputOperatorSpec inputOpSpec : inputOpSpecs) { - StreamEdge streamEdge = jobGraph.getStreamEdge(inputOpSpec.getStreamId()); - if (streamEdge.getPartitionCount() != StreamEdge.PARTITIONS_UNKNOWN) { - ++countStreamEdgeWithSetPartitions; - } - streamEdges.add(streamEdge); - } - - // Determine category of stream group based on stream partition counts. - StreamEdgeSet.StreamEdgeSetCategory category; - if (countStreamEdgeWithSetPartitions == 0) { - category = StreamEdgeSet.StreamEdgeSetCategory.NO_PARTITION_COUNT_SET; - } else if (countStreamEdgeWithSetPartitions == streamEdges.size()) { - category = StreamEdgeSet.StreamEdgeSetCategory.ALL_PARTITION_COUNT_SET; - } else { - category = StreamEdgeSet.StreamEdgeSetCategory.SOME_PARTITION_COUNT_SET; - } - - return new StreamEdgeSet(setId, streamEdges, category); - } - - /** - * Sets partition count of intermediate streams which have not been assigned partition counts. + * Sets partition counts of intermediate streams which have not been assigned partition counts. */ private void setIntermediateStreamPartitions(JobGraph jobGraph) { final String defaultPartitionsConfigProperty = JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS(); @@ -186,112 +106,67 @@ class IntermediateStreamManager { } /** - * Ensures all intermediate streams have been assigned partition counts. + * Sets partition counts of intermediate streams participating in join operations. */ - private static void validateIntermediateStreamPartitions(JobGraph jobGraph) { - for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) { - if (edge.getPartitionCount() <= 0) { - throw new SamzaException(String.format("Failed to assign valid partition count to Stream %s", edge.getName())); + private static void setJoinedIntermediateStreamPartitions(Collection<StreamSet> joinedStreamSets) { + // Map every intermediate stream to all the stream-sets it appears in + Multimap<StreamEdge, StreamSet> intermediateStreamToStreamSets = HashMultimap.create(); + for (StreamSet streamSet : joinedStreamSets) { + for (StreamEdge streamEdge : streamSet.getStreamEdges()) { + if (streamEdge.getPartitionCount() == StreamEdge.PARTITIONS_UNKNOWN) { + intermediateStreamToStreamSets.put(streamEdge, streamSet); + } } } - } - - /** - * Ensures that all streams in the supplied {@link StreamEdgeSet} agree in partition count. - * This may include setting partition counts of intermediate streams in this set that do not - * have their partition counts set.
- */ - private static void validateAndAssignStreamEdgeSetPartitions(StreamEdgeSet streamEdgeSet) { - Set<StreamEdge> streamEdges = streamEdgeSet.getStreamEdges(); - StreamEdge firstStreamEdgeWithSetPartitions = - streamEdges.stream() - .filter(streamEdge -> streamEdge.getPartitionCount() != StreamEdge.PARTITIONS_UNKNOWN) - .findFirst() - .orElse(null); - - // This group consists exclusively of intermediate streams with unknown partition counts. - // We cannot do any validation/computation of partition counts of such streams right here, - // but they are tackled later in the ExecutionPlanner. - if (firstStreamEdgeWithSetPartitions == null) { - return; - } - // Make sure all other stream edges in this group have the same partition count. - int partitions = firstStreamEdgeWithSetPartitions.getPartitionCount(); - for (StreamEdge streamEdge : streamEdges) { - int streamPartitions = streamEdge.getPartitionCount(); - if (streamPartitions == StreamEdge.PARTITIONS_UNKNOWN) { - streamEdge.setPartitionCount(partitions); - log.info("Inferred the partition count {} for the join operator {} from {}.", - new Object[] {partitions, streamEdgeSet.getSetId(), firstStreamEdgeWithSetPartitions.getName()}); - } else if (streamPartitions != partitions) { - throw new SamzaException(String.format( - "Unable to resolve input partitions of stream %s for the join %s. Expected: %d, Actual: %d", - streamEdge.getName(), streamEdgeSet.getSetId(), partitions, streamPartitions)); + Set<StreamSet> streamSets = new HashSet<>(joinedStreamSets); + Set<StreamSet> processedStreamSets = new HashSet<>(); + + while (!streamSets.isEmpty()) { + // Retrieve and remove one stream set + StreamSet streamSet = streamSets.iterator().next(); + streamSets.remove(streamSet); + + // Find any stream with set partitions in this set + Optional<StreamEdge> streamWithSetPartitions = + streamSet.getStreamEdges().stream() + .filter(streamEdge -> streamEdge.getPartitionCount() != StreamEdge.PARTITIONS_UNKNOWN) + .findAny(); + + if (streamWithSetPartitions.isPresent()) { + // Mark this stream-set as processed since we won't need to re-examine it ever again. + // It is important that we do this first before processing any intermediate streams + // that may be in this stream-set. + processedStreamSets.add(streamSet); + + // Set partitions of all intermediate streams in this set (if any) + int partitions = streamWithSetPartitions.get().getPartitionCount(); + for (StreamEdge streamEdge : streamSet.getStreamEdges()) { + if (streamEdge.getPartitionCount() == StreamEdge.PARTITIONS_UNKNOWN) { + streamEdge.setPartitionCount(partitions); + // Add all unprocessed stream-sets in which this intermediate stream appears + Collection<StreamSet> streamSetsIncludingIntStream = intermediateStreamToStreamSets.get(streamEdge); + streamSetsIncludingIntStream.stream() + .filter(s -> !processedStreamSets.contains(s)) + .forEach(streamSets::add); + } + } } } } - /* package private */ static int maxPartitions(Collection<StreamEdge> edges) { - return edges.stream().mapToInt(StreamEdge::getPartitionCount).max().orElse(StreamEdge.PARTITIONS_UNKNOWN); - } - /** - * Represents a set of {@link StreamEdge}s. + * Ensures all intermediate streams have been assigned partition counts. */ - /* package private */ static class StreamEdgeSet { - - /** - * Indicates whether all stream edges in this group have their partition counts assigned. - */ - public enum StreamEdgeSetCategory { - /** - * All stream edges in this group have their partition counts assigned. 
- */ - ALL_PARTITION_COUNT_SET(0), - - /** - * Only some stream edges in this group have their partition counts assigned. - */ - SOME_PARTITION_COUNT_SET(1), - - /** - * No stream edge in this group is assigned a partition count. - */ - NO_PARTITION_COUNT_SET(2); - - - private final int sortOrder; - - StreamEdgeSetCategory(int sortOrder) { - this.sortOrder = sortOrder; - } - - public int getSortOrder() { - return sortOrder; + private static void validateIntermediateStreamPartitions(JobGraph jobGraph) { + for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) { + if (edge.getPartitionCount() <= 0) { + throw new SamzaException(String.format("Failed to assign valid partition count to Stream %s", edge.getName())); } } + } - private final String setId; - private final Set<StreamEdge> streamEdges; - private final StreamEdgeSetCategory category; - - StreamEdgeSet(String setId, Set<StreamEdge> streamEdges, StreamEdgeSetCategory category) { - this.setId = setId; - this.streamEdges = streamEdges; - this.category = category; - } - - Set<StreamEdge> getStreamEdges() { - return streamEdges; - } - - String getSetId() { - return setId; - } - - StreamEdgeSetCategory getCategory() { - return category; - } + /* package private */ static int maxPartitions(Collection<StreamEdge> edges) { + return edges.stream().mapToInt(StreamEdge::getPartitionCount).max().orElse(StreamEdge.PARTITIONS_UNKNOWN); } } http://git-wip-us.apache.org/repos/asf/samza/blob/d2c9e816/samza-core/src/main/java/org/apache/samza/execution/JobGraph.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/execution/JobGraph.java b/samza-core/src/main/java/org/apache/samza/execution/JobGraph.java index d975188..f43b24e 100644 --- a/samza-core/src/main/java/org/apache/samza/execution/JobGraph.java +++ b/samza-core/src/main/java/org/apache/samza/execution/JobGraph.java @@ -30,7 +30,6 @@ import java.util.Map; import java.util.Queue; import java.util.Set; import java.util.stream.Collectors; - import org.apache.samza.application.ApplicationDescriptor; import org.apache.samza.application.ApplicationDescriptorImpl; import org.apache.samza.config.ApplicationConfig; @@ -58,6 +57,7 @@ import org.slf4j.LoggerFactory; private final Set<StreamEdge> inputStreams = new HashSet<>(); private final Set<StreamEdge> outputStreams = new HashSet<>(); private final Set<StreamEdge> intermediateStreams = new HashSet<>(); + private final Set<StreamEdge> sideInputStreams = new HashSet<>(); private final Set<TableSpec> tables = new HashSet<>(); private final Config config; private final JobGraphJsonGenerator jsonGenerator; @@ -156,6 +156,15 @@ import org.slf4j.LoggerFactory; } /** + * Add a side-input stream to graph + * @param streamSpec side-input stream + */ + void addSideInputStream(StreamSpec streamSpec) { + StreamEdge edge = getOrCreateStreamEdge(streamSpec, false); + sideInputStreams.add(edge); + } + + /** * Get the {@link JobNode}. Create one if it does not exist. * @param jobName name of the job * @param jobId id of the job @@ -176,6 +185,14 @@ import org.slf4j.LoggerFactory; } /** + * Returns the {@link ApplicationDescriptorImpl} of this graph. + * @return Application descriptor implementation + */ + ApplicationDescriptorImpl<? extends ApplicationDescriptor> getApplicationDescriptorImpl() { + return appDesc; + } + + /** * Get the {@link StreamEdge} for {@code streamId}. 
* * @param streamId the streamId for the {@link StreamEdge} @@ -203,7 +220,15 @@ import org.slf4j.LoggerFactory; } /** - * Return the output streams in the graph + * Returns the side-input streams in the graph + * @return unmodifiable set of {@link StreamEdge} + */ + Set<StreamEdge> getSideInputStreams() { + return Collections.unmodifiableSet(sideInputStreams); + } + + /** + * Returns the output streams in the graph * @return unmodifiable set of {@link StreamEdge} */ Set<StreamEdge> getOutputStreams() { @@ -211,7 +236,7 @@ import org.slf4j.LoggerFactory; } /** - * Return the tables in the graph + * Returns the tables in the graph * @return unmodifiable set of {@link TableSpec} */ Set<TableSpec> getTables() { @@ -219,7 +244,7 @@ import org.slf4j.LoggerFactory; } /** - * Return the intermediate streams in the graph + * Returns the intermediate streams in the graph * @return unmodifiable set of {@link StreamEdge} */ Set<StreamEdge> getIntermediateStreamEdges() { @@ -293,6 +318,7 @@ import org.slf4j.LoggerFactory; private void validateInternalStreams() { Set<StreamEdge> internalEdges = new HashSet<>(edges.values()); internalEdges.removeAll(inputStreams); + internalEdges.removeAll(sideInputStreams); internalEdges.removeAll(outputStreams); internalEdges.forEach(edge -> { http://git-wip-us.apache.org/repos/asf/samza/blob/d2c9e816/samza-core/src/main/java/org/apache/samza/execution/OperatorSpecGraphAnalyzer.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/execution/OperatorSpecGraphAnalyzer.java b/samza-core/src/main/java/org/apache/samza/execution/OperatorSpecGraphAnalyzer.java index ca91214..123244b 100644 --- a/samza-core/src/main/java/org/apache/samza/execution/OperatorSpecGraphAnalyzer.java +++ b/samza-core/src/main/java/org/apache/samza/execution/OperatorSpecGraphAnalyzer.java @@ -21,6 +21,7 @@ package org.apache.samza.execution; import com.google.common.collect.HashMultimap; import com.google.common.collect.Multimap; +import com.google.common.collect.Multimaps; import java.util.Collection; import java.util.Collections; import java.util.HashSet; @@ -30,6 +31,9 @@ import java.util.function.Function; import org.apache.samza.operators.spec.InputOperatorSpec; import org.apache.samza.operators.spec.JoinOperatorSpec; import org.apache.samza.operators.spec.OperatorSpec; +import org.apache.samza.operators.spec.SendToTableOperatorSpec; +import org.apache.samza.operators.spec.StreamTableJoinOperatorSpec; +import org.apache.samza.table.TableSpec; /** @@ -39,27 +43,40 @@ import org.apache.samza.operators.spec.OperatorSpec; /* package private */ class OperatorSpecGraphAnalyzer { /** - * Returns a grouping of {@link InputOperatorSpec}s by the joins, i.e. {@link JoinOperatorSpec}s, they participate in. + * Returns a grouping of {@link InputOperatorSpec}s by the joins, i.e. {@link JoinOperatorSpec}s and + * {@link StreamTableJoinOperatorSpec}s, they participate in. + * + * The key of the returned Multimap is of type {@link OperatorSpec} due to the lack of a stricter + * base type for {@link JoinOperatorSpec} and {@link StreamTableJoinOperatorSpec}. However, key + * objects are guaranteed to be of either type only. 
*/ - public static Multimap<JoinOperatorSpec, InputOperatorSpec> getJoinToInputOperatorSpecs( - Collection<InputOperatorSpec> inputOperatorSpecs) { + public static Multimap<OperatorSpec, InputOperatorSpec> getJoinToInputOperatorSpecs( + Collection<InputOperatorSpec> inputOpSpecs) { - Multimap<JoinOperatorSpec, InputOperatorSpec> joinOpSpecToInputOpSpecs = HashMultimap.create(); + Multimap<OperatorSpec, InputOperatorSpec> joinToInputOpSpecs = HashMultimap.create(); + + // Create a getNextOpSpecs() function that emulates connections between every SendToTableOperatorSpec + // (a terminal OperatorSpec) and all StreamTableJoinOperatorSpecs referencing the same TableSpec. + // + // This is necessary to support Stream-Table Join scenarios because it allows us to associate streams behind + // SendToTableOperatorSpecs with streams participating in Stream-Table Joins, a connection that would not be + // easy to make otherwise since SendToTableOperatorSpecs are terminal operator specs. + Function<OperatorSpec, Iterable<OperatorSpec>> getNextOpSpecs = getCustomGetNextOpSpecs(inputOpSpecs); // Traverse graph starting from every input operator spec, observing connectivity between input operator specs - // and Join operator specs. - for (InputOperatorSpec inputOpSpec : inputOperatorSpecs) { - // Observe all join operator specs reachable from this input operator spec. - JoinOperatorSpecVisitor joinOperatorSpecVisitor = new JoinOperatorSpecVisitor(); - traverse(inputOpSpec, joinOperatorSpecVisitor, opSpec -> opSpec.getRegisteredOperatorSpecs()); - - // Associate every encountered join operator spec with this input operator spec. - for (JoinOperatorSpec joinOpSpec : joinOperatorSpecVisitor.getJoinOperatorSpecs()) { - joinOpSpecToInputOpSpecs.put(joinOpSpec, inputOpSpec); + // and join-related operator specs. + for (InputOperatorSpec inputOpSpec : inputOpSpecs) { + // Observe all join-related operator specs reachable from this input operator spec. + JoinVisitor joinVisitor = new JoinVisitor(); + traverse(inputOpSpec, joinVisitor, getNextOpSpecs); + + // Associate every encountered join-related operator spec with this input operator spec. + for (OperatorSpec joinOpSpec : joinVisitor.getJoins()) { + joinToInputOpSpecs.put(joinOpSpec, inputOpSpec); } } - return joinOpSpecToInputOpSpecs; + return joinToInputOpSpecs; } /** @@ -67,7 +84,7 @@ import org.apache.samza.operators.spec.OperatorSpec; * {@link OperatorSpec}, and using {@code getNextOpSpecs} to determine the set of {@link OperatorSpec}s to visit next. */ private static void traverse(OperatorSpec startOpSpec, Consumer<OperatorSpec> visitor, - Function<OperatorSpec, Collection<OperatorSpec>> getNextOpSpecs) { + Function<OperatorSpec, Iterable<OperatorSpec>> getNextOpSpecs) { visitor.accept(startOpSpec); for (OperatorSpec nextOpSpec : getNextOpSpecs.apply(startOpSpec)) { traverse(nextOpSpec, visitor, getNextOpSpecs); @@ -75,20 +92,93 @@ } /** - * An visitor that records all {@link JoinOperatorSpec}s encountered in the graph of {@link OperatorSpec}s + * Creates a function that retrieves the next {@link OperatorSpec}s of any given {@link OperatorSpec} in the specified + * {@code operatorSpecGraph}. + * + * Calling the returned function with any {@link SendToTableOperatorSpec} will return a collection of all + * {@link StreamTableJoinOperatorSpec}s that reference the same {@link TableSpec} as the specified + * {@link SendToTableOperatorSpec}, as if they were actually connected.
*/ - private static class JoinOperatorSpecVisitor implements Consumer<OperatorSpec> { - private Set<JoinOperatorSpec> joinOpSpecs = new HashSet<>(); + private static Function<OperatorSpec, Iterable<OperatorSpec>> getCustomGetNextOpSpecs( + Iterable<InputOperatorSpec> inputOpSpecs) { + + // Traverse operatorSpecGraph to create mapping between every SendToTableOperatorSpec and all + // StreamTableJoinOperatorSpecs referencing the same TableSpec. + TableJoinVisitor tableJoinVisitor = new TableJoinVisitor(); + for (InputOperatorSpec inputOpSpec : inputOpSpecs) { + traverse(inputOpSpec, tableJoinVisitor, opSpec -> opSpec.getRegisteredOperatorSpecs()); + } + + Multimap<SendToTableOperatorSpec, StreamTableJoinOperatorSpec> sendToTableOpSpecToStreamTableJoinOpSpecs = + tableJoinVisitor.getSendToTableOpSpecToStreamTableJoinOpSpecs(); + + return operatorSpec -> { + // If this is a SendToTableOperatorSpec, return all StreamTableJoinSpecs referencing the same TableSpec. + // For all other types of operator specs, return the next registered operator specs. + if (operatorSpec instanceof SendToTableOperatorSpec) { + SendToTableOperatorSpec sendToTableOperatorSpec = (SendToTableOperatorSpec) operatorSpec; + return Collections.unmodifiableCollection(sendToTableOpSpecToStreamTableJoinOpSpecs.get(sendToTableOperatorSpec)); + } + + return operatorSpec.getRegisteredOperatorSpecs(); + }; + } + + /** + * An {@link OperatorSpec} visitor that records all {@link JoinOperatorSpec}s and {@link StreamTableJoinOperatorSpec}s + * encountered in the graph. + */ + private static class JoinVisitor implements Consumer<OperatorSpec> { + private Set<OperatorSpec> joinOpSpecs = new HashSet<>(); @Override - public void accept(OperatorSpec operatorSpec) { - if (operatorSpec instanceof JoinOperatorSpec) { - joinOpSpecs.add((JoinOperatorSpec) operatorSpec); + public void accept(OperatorSpec opSpec) { + if (opSpec instanceof JoinOperatorSpec || opSpec instanceof StreamTableJoinOperatorSpec) { + joinOpSpecs.add(opSpec); } } - public Set<JoinOperatorSpec> getJoinOperatorSpecs() { + public Set<OperatorSpec> getJoins() { return Collections.unmodifiableSet(joinOpSpecs); } } + + /** + * An {@link OperatorSpec} visitor that records associations between every {@link SendToTableOperatorSpec} + * and all {@link StreamTableJoinOperatorSpec}s that reference the same {@link TableSpec}. + */ + private static class TableJoinVisitor implements Consumer<OperatorSpec> { + private final Multimap<TableSpec, SendToTableOperatorSpec> tableSpecToSendToTableOpSpecs = HashMultimap.create(); + private final Multimap<TableSpec, StreamTableJoinOperatorSpec> tableSpecToStreamTableJoinOpSpecs = HashMultimap.create(); + + @Override + public void accept(OperatorSpec opSpec) { + // Record all SendToTableOperatorSpecs, StreamTableJoinOperatorSpecs, and their corresponding TableSpecs. 
+ if (opSpec instanceof SendToTableOperatorSpec) { + SendToTableOperatorSpec sendToTableOperatorSpec = (SendToTableOperatorSpec) opSpec; + tableSpecToSendToTableOpSpecs.put(sendToTableOperatorSpec.getTableSpec(), sendToTableOperatorSpec); + } else if (opSpec instanceof StreamTableJoinOperatorSpec) { + StreamTableJoinOperatorSpec streamTableJoinOpSpec = (StreamTableJoinOperatorSpec) opSpec; + tableSpecToStreamTableJoinOpSpecs.put(streamTableJoinOpSpec.getTableSpec(), streamTableJoinOpSpec); + } + } + + public Multimap<SendToTableOperatorSpec, StreamTableJoinOperatorSpec> getSendToTableOpSpecToStreamTableJoinOpSpecs() { + Multimap<SendToTableOperatorSpec, StreamTableJoinOperatorSpec> sendToTableOpSpecToStreamTableJoinOpSpecs = + HashMultimap.create(); + + // Map every SendToTableOperatorSpec to all StreamTableJoinOperatorSpecs referencing the same TableSpec. + for (TableSpec tableSpec : tableSpecToSendToTableOpSpecs.keySet()) { + Collection<SendToTableOperatorSpec> sendToTableOpSpecs = tableSpecToSendToTableOpSpecs.get(tableSpec); + Collection<StreamTableJoinOperatorSpec> streamTableJoinOpSpecs = + tableSpecToStreamTableJoinOpSpecs.get(tableSpec); + + for (SendToTableOperatorSpec sendToTableOpSpec : sendToTableOpSpecs) { + sendToTableOpSpecToStreamTableJoinOpSpecs.putAll(sendToTableOpSpec, streamTableJoinOpSpecs); + } + } + + return Multimaps.unmodifiableMultimap(sendToTableOpSpecToStreamTableJoinOpSpecs); + } + } } http://git-wip-us.apache.org/repos/asf/samza/blob/d2c9e816/samza-core/src/test/java/org/apache/samza/execution/ExecutionPlannerTestBase.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/execution/ExecutionPlannerTestBase.java b/samza-core/src/test/java/org/apache/samza/execution/ExecutionPlannerTestBase.java index d172005..6308589 100644 --- a/samza-core/src/test/java/org/apache/samza/execution/ExecutionPlannerTestBase.java +++ b/samza-core/src/test/java/org/apache/samza/execution/ExecutionPlannerTestBase.java @@ -102,7 +102,7 @@ class ExecutionPlannerTestBase { void configureJobNode(ApplicationDescriptorImpl mockStreamAppDesc) { JobGraph jobGraph = new ExecutionPlanner(mockConfig, mock(StreamManager.class)) - .createJobGraph(mockConfig, mockStreamAppDesc); + .createJobGraph(mockStreamAppDesc); mockJobNode = spy(jobGraph.getJobNodes().get(0)); } http://git-wip-us.apache.org/repos/asf/samza/blob/d2c9e816/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java b/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java index 213efde..4cfcfd2 100644 --- a/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java +++ b/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java @@ -21,6 +21,7 @@ package org.apache.samza.execution; import java.time.Duration; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; @@ -40,6 +41,7 @@ import org.apache.samza.config.Config; import org.apache.samza.config.JobConfig; import org.apache.samza.config.MapConfig; import org.apache.samza.config.TaskConfig; +import org.apache.samza.operators.BaseTableDescriptor; import org.apache.samza.operators.KV; import org.apache.samza.operators.MessageStream; import org.apache.samza.operators.OutputStream; @@ 
-51,15 +53,19 @@ import org.apache.samza.operators.descriptors.base.stream.InputDescriptor; import org.apache.samza.operators.descriptors.base.stream.OutputDescriptor; import org.apache.samza.operators.descriptors.base.system.SystemDescriptor; import org.apache.samza.operators.functions.JoinFunction; +import org.apache.samza.operators.functions.StreamTableJoinFunction; import org.apache.samza.operators.windows.Windows; import org.apache.samza.serializers.KVSerde; import org.apache.samza.serializers.NoOpSerde; import org.apache.samza.serializers.Serde; +import org.apache.samza.storage.SideInputsProcessor; import org.apache.samza.system.StreamSpec; import org.apache.samza.system.SystemAdmin; import org.apache.samza.system.SystemAdmins; import org.apache.samza.system.SystemStreamMetadata; import org.apache.samza.system.SystemStreamPartition; +import org.apache.samza.table.Table; +import org.apache.samza.table.TableSpec; import org.apache.samza.testUtils.StreamTestUtils; import org.junit.Before; import org.junit.Test; @@ -145,7 +151,7 @@ public class TestExecutionPlanner { }, config); } - private StreamApplicationDescriptorImpl createStreamGraphWithJoin() { + private StreamApplicationDescriptorImpl createStreamGraphWithStreamStreamJoin() { /** * the graph looks like the following. number of partitions in parentheses. quotes indicate expected value. @@ -175,17 +181,38 @@ public class TestExecutionPlanner { messageStream1 .join(messageStream2, - (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class), - mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(2), "j1") + mock(JoinFunction.class), mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(2), "j1") .sendTo(output1); messageStream3 .join(messageStream2, - (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class), - mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j2") + mock(JoinFunction.class), mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j2") .sendTo(output2); }, config); } + private StreamApplicationDescriptorImpl createStreamGraphWithInvalidStreamStreamJoin() { + /** + * Creates the following stream-stream join which is invalid due to partition count disagreement + * between the 2 input streams. 
+ * + * input1 (64) -- + * | + * join -> output1 (8) + * | + * input3 (32) -- + */ + return new StreamApplicationDescriptorImpl(appDesc -> { + MessageStream<KV<Object, Object>> messageStream1 = appDesc.getInputStream(input1Descriptor); + MessageStream<KV<Object, Object>> messageStream3 = appDesc.getInputStream(input3Descriptor); + OutputStream<KV<Object, Object>> output1 = appDesc.getOutputStream(output1Descriptor); + + messageStream1 + .join(messageStream3, + mock(JoinFunction.class), mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(2), "j1") + .sendTo(output1); + }, config); + } + private StreamApplicationDescriptorImpl createStreamGraphWithJoinAndWindow() { return new StreamApplicationDescriptorImpl(appDesc -> { @@ -210,33 +237,180 @@ public class TestExecutionPlanner { .filter(m -> true) .window(Windows.keyedTumblingWindow(m -> m, Duration.ofMillis(16), (Serde<KV<Object, Object>>) mock(Serde.class), (Serde<KV<Object, Object>>) mock(Serde.class)), "w2"); - messageStream1.join(messageStream2, (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class), - mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofMillis(1600), "j1").sendTo(output1); - messageStream3.join(messageStream2, (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class), - mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofMillis(100), "j2").sendTo(output2); - messageStream3.join(messageStream2, (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class), - mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofMillis(252), "j3").sendTo(output2); + messageStream1.join(messageStream2, mock(JoinFunction.class), mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofMillis(1600), "j1").sendTo(output1); + messageStream3.join(messageStream2, mock(JoinFunction.class), mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofMillis(100), "j2").sendTo(output2); + messageStream3.join(messageStream2, mock(JoinFunction.class), mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofMillis(252), "j3").sendTo(output2); }, config); } - private StreamApplicationDescriptorImpl createStreamGraphWithInvalidJoin() { + private StreamApplicationDescriptorImpl createStreamGraphWithStreamTableJoin() { /** - * input1 (64) -- - * | - * join -> output1 (8) - * | - * input3 (32) -- + * Example stream-table join app. Expected partition counts of intermediate streams introduced + * by partitionBy operations are enclosed in quotes. 
+ * + * input2 (16) -> partitionBy ("32") -> send-to-table t + * + * join-table t ----- + * | | + * input1 (64) -> partitionBy ("32") _| | + * join -> output1 (8) + * | + * input3 (32) ------ + * */ return new StreamApplicationDescriptorImpl(appDesc -> { MessageStream<KV<Object, Object>> messageStream1 = appDesc.getInputStream(input1Descriptor); + MessageStream<KV<Object, Object>> messageStream2 = appDesc.getInputStream(input2Descriptor); MessageStream<KV<Object, Object>> messageStream3 = appDesc.getInputStream(input3Descriptor); OutputStream<KV<Object, Object>> output1 = appDesc.getOutputStream(output1Descriptor); + TableDescriptor tableDescriptor = new TestTableDescriptor("table-id"); + Table table = appDesc.getTable(tableDescriptor); + + messageStream2 + .partitionBy(m -> m.key, m -> m.value, mock(KVSerde.class), "p1") + .sendTo(table); + messageStream1 - .join(messageStream3, - (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class), - mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(2), "j1") - .sendTo(output1); + .partitionBy(m -> m.key, m -> m.value, mock(KVSerde.class), "p2") + .join(table, mock(StreamTableJoinFunction.class)) + .join(messageStream3, + mock(JoinFunction.class), mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j2") + .sendTo(output1); }, config); } + + private StreamApplicationDescriptorImpl createStreamGraphWithComplexStreamStreamJoin() { + /** + * Example stream-stream join app. Expected partition counts of intermediate streams introduced + * by partitionBy operations are enclosed in quotes. + * + * input1 (64) ________________________ + * | + * join -----> output1 (8) + * | + * input2 (16) -> partitionBy ("64") --| + * | + * join -----> output1 (8) + * | + * input3 (32) -> partitionBy ("64") --| + * | + * join -----> output1 (8) + * | + * input4 (512) -> partitionBy ("64") __| + * + * + */ + return new StreamApplicationDescriptorImpl(appDesc -> { + MessageStream<KV<Object, Object>> messageStream1 = appDesc.getInputStream(input1Descriptor); + + MessageStream<KV<Object, Object>> messageStream2 = + appDesc.getInputStream(input2Descriptor) + .partitionBy(m -> m.key, m -> m.value, mock(KVSerde.class), "p2"); + + MessageStream<KV<Object, Object>> messageStream3 = + appDesc.getInputStream(input3Descriptor) + .partitionBy(m -> m.key, m -> m.value, mock(KVSerde.class), "p3"); + + MessageStream<KV<Object, Object>> messageStream4 = + appDesc.getInputStream(input4Descriptor) + .partitionBy(m -> m.key, m -> m.value, mock(KVSerde.class), "p4"); + + OutputStream<KV<Object, Object>> output1 = appDesc.getOutputStream(output1Descriptor); + + messageStream1 + .join(messageStream2, + mock(JoinFunction.class), mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j1") + .sendTo(output1); + + messageStream3 + .join(messageStream4, + mock(JoinFunction.class), mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j2") + .sendTo(output1); + + messageStream2 + .join(messageStream3, + mock(JoinFunction.class), mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j3") + .sendTo(output1); + }, config); + } + + private StreamApplicationDescriptorImpl createStreamGraphWithInvalidStreamTableJoin() { + /** + * Example stream-table join that is invalid due to disagreement in partition count + * between the two input streams.
+ * + * input1 (64) -> send-to-table t + * + * join-table t -> output1 (8) + * | + * input2 (16) --------- + * + */ + return new StreamApplicationDescriptorImpl(appDesc -> { + MessageStream<KV<Object, Object>> messageStream1 = appDesc.getInputStream(input1Descriptor); + MessageStream<KV<Object, Object>> messageStream2 = appDesc.getInputStream(input2Descriptor); + OutputStream<KV<Object, Object>> output1 = appDesc.getOutputStream(output1Descriptor); + + TableDescriptor tableDescriptor = new TestTableDescriptor("table-id"); + Table table = appDesc.getTable(tableDescriptor); + + messageStream1.sendTo(table); + + messageStream1 + .join(table, mock(StreamTableJoinFunction.class)) + .join(messageStream2, + mock(JoinFunction.class), mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j2") + .sendTo(output1); + }, config); + } + + private StreamApplicationDescriptorImpl createStreamGraphWithStreamTableJoinWithSideInputs() { + /** + * Example stream-table join where table t is configured with input1 (64) as a side-input stream. + * + * join-table t -> output1 (8) + * | + * input2 (16) -> partitionBy ("64") __| + * + */ + return new StreamApplicationDescriptorImpl(appDesc -> { + MessageStream<KV<Object, Object>> messageStream2 = appDesc.getInputStream(input2Descriptor); + OutputStream<KV<Object, Object>> output1 = appDesc.getOutputStream(output1Descriptor); + + TableDescriptor tableDescriptor = new TestTableDescriptor("table-id", Arrays.asList("input1"), + (message, store) -> Collections.emptyList()); + Table table = appDesc.getTable(tableDescriptor); + + messageStream2 + .partitionBy(m -> m.key, m -> m.value, mock(KVSerde.class), "p1") + .join(table, mock(StreamTableJoinFunction.class)) + .sendTo(output1); + }, config); + } + + private StreamApplicationDescriptorImpl createStreamGraphWithInvalidStreamTableJoinWithSideInputs() { + /** + * Example stream-table join that is invalid due to disagreement in partition count between the + * stream behind table t and another joined stream. Table t is configured with input2 (16) as + * a side-input stream.
+ * + * join-table t -> output1 (8) + * | + * input1 (64) --------- + * + */ + return new StreamApplicationDescriptorImpl(appDesc -> { + MessageStream<KV<Object, Object>> messageStream1 = appDesc.getInputStream(input1Descriptor); + OutputStream<KV<Object, Object>> output1 = appDesc.getOutputStream(output1Descriptor); + + TableDescriptor tableDescriptor = new TestTableDescriptor("table-id", Arrays.asList("input2"), + (message, store) -> Collections.emptyList()); + Table table = appDesc.getTable(tableDescriptor); + + messageStream1 + .join(table, mock(StreamTableJoinFunction.class)) + .sendTo(output1); }, config); } @@ -307,9 +481,9 @@ public class TestExecutionPlanner { @Test public void testCreateProcessorGraph() { ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); - StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithJoin(); + StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithStreamStreamJoin(); - JobGraph jobGraph = planner.createJobGraph(graphSpec.getConfig(), graphSpec); + JobGraph jobGraph = planner.createJobGraph(graphSpec); assertTrue(jobGraph.getInputStreams().size() == 3); assertTrue(jobGraph.getOutputStreams().size() == 2); assertTrue(jobGraph.getIntermediateStreams().size() == 2); // two streams generated by partitionBy @@ -318,10 +492,10 @@ @Test public void testFetchExistingStreamPartitions() { ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); - StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithJoin(); - JobGraph jobGraph = planner.createJobGraph(graphSpec.getConfig(), graphSpec); + StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithStreamStreamJoin(); + JobGraph jobGraph = planner.createJobGraph(graphSpec); - ExecutionPlanner.setInputAndOutputStreamPartitionCount(jobGraph, streamManager); + planner.setInputAndOutputStreamPartitionCount(jobGraph); assertTrue(jobGraph.getOrCreateStreamEdge(input1Spec).getPartitionCount() == 64); assertTrue(jobGraph.getOrCreateStreamEdge(input2Spec).getPartitionCount() == 16); assertTrue(jobGraph.getOrCreateStreamEdge(input3Spec).getPartitionCount() == 32); @@ -336,24 +510,74 @@ @Test public void testCalculateJoinInputPartitions() { ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); - StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithJoin(); - JobGraph jobGraph = planner.createJobGraph(graphSpec.getConfig(), graphSpec); + StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithStreamStreamJoin(); + JobGraph jobGraph = (JobGraph) planner.plan(graphSpec); - ExecutionPlanner.setInputAndOutputStreamPartitionCount(jobGraph, streamManager); - new IntermediateStreamManager(config, graphSpec).calculatePartitions(jobGraph); + // Partitions should be the same as input1 + jobGraph.getIntermediateStreams().forEach(edge -> { + assertEquals(64, edge.getPartitionCount()); + }); + } - // the partitions should be the same as input1 + @Test + public void testCalculateOrderSensitiveJoinInputPartitions() { + // This test ensures that the ExecutionPlanner can resolve groups of joined stream edges + // no matter the order in which it encounters them. It creates an example stream-stream + // join application that has the following sets of joined streams (notice the order): + // + // a. e1 (16), e2` (?) + // b. e3` (?), e4` (?) + // c. e2` (?), e3` (?) + // + // If the groups were processed strictly in the above order in a single pass, the partitions + // of e3` and e4` could not be resolved correctly.
+ ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); + StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithComplexStreamStreamJoin(); + JobGraph jobGraph = (JobGraph) planner.plan(graphSpec); + + // Partitions should be the same as input1 jobGraph.getIntermediateStreams().forEach(edge -> { assertEquals(64, edge.getPartitionCount()); }); } - @Test(expected = SamzaException.class) - public void testRejectsInvalidJoin() { + @Test + public void testCalculateIntStreamPartitions() { ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); - StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithInvalidJoin(); + StreamApplicationDescriptorImpl graphSpec = createSimpleGraph(); - planner.plan(graphSpec); + JobGraph jobGraph = (JobGraph) planner.plan(graphSpec); + + // Partitions should be the same as input1 + jobGraph.getIntermediateStreams().forEach(edge -> { + assertEquals(64, edge.getPartitionCount()); // max of input1 and output1 + }); + } + + @Test + public void testCalculateInStreamPartitionsBehindTables() { + ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); + StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithStreamTableJoin(); + + JobGraph jobGraph = (JobGraph) planner.plan(graphSpec); + + // Partitions should be the same as input3 + jobGraph.getIntermediateStreams().forEach(edge -> { + assertEquals(32, edge.getPartitionCount()); + }); + } + + @Test + public void testCalculateInStreamPartitionsBehindTablesWithSideInputs() { + ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); + StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithStreamTableJoinWithSideInputs(); + + JobGraph jobGraph = (JobGraph) planner.plan(graphSpec); + + // Partitions should be the same as input1 + jobGraph.getIntermediateStreams().forEach(edge -> { + assertEquals(64, edge.getPartitionCount()); + }); } @Test @@ -366,21 +590,65 @@ StreamApplicationDescriptorImpl graphSpec = createSimpleGraph(); JobGraph jobGraph = (JobGraph) planner.plan(graphSpec); - // the partitions should be the same as input1 + // Partitions should be the same as input1 jobGraph.getIntermediateStreams().forEach(edge -> { assertTrue(edge.getPartitionCount() == DEFAULT_PARTITIONS); }); } @Test + public void testMaxPartitionLimit() { + int partitionLimit = IntermediateStreamManager.MAX_INFERRED_PARTITIONS; + + ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); + StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> { + MessageStream<KV<Object, Object>> input1 = appDesc.getInputStream(input4Descriptor); + OutputStream<KV<Object, Object>> output1 = appDesc.getOutputStream(output1Descriptor); + input1.partitionBy(m -> m.key, m -> m.value, mock(KVSerde.class), "p1").map(kv -> kv).sendTo(output1); + }, config); + + JobGraph jobGraph = (JobGraph) planner.plan(graphSpec); + + // Partitions should be capped at MAX_INFERRED_PARTITIONS + jobGraph.getIntermediateStreams().forEach(edge -> { + assertEquals(partitionLimit, edge.getPartitionCount()); // input4 has 512 partitions, above the cap + }); + } + + @Test(expected = SamzaException.class) + public void testRejectsInvalidStreamStreamJoin() { + ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); + StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithInvalidStreamStreamJoin(); + + planner.plan(graphSpec); + } + + @Test(expected = SamzaException.class) + public void
testRejectsInvalidStreamTableJoin() { + ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); + StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithInvalidStreamTableJoin(); + + planner.plan(graphSpec); + } + + @Test(expected = SamzaException.class) + public void testRejectsInvalidStreamTableJoinWithSideInputs() { + ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); + StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithInvalidStreamTableJoinWithSideInputs(); + + planner.plan(graphSpec); + } + + @Test public void testTriggerIntervalForJoins() { Map<String, String> map = new HashMap<>(config); map.put(JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS(), String.valueOf(DEFAULT_PARTITIONS)); Config cfg = new MapConfig(map); ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager); - StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithJoin(); + StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithStreamStreamJoin(); ExecutionPlan plan = planner.plan(graphSpec); + List<JobConfig> jobConfigs = plan.getJobConfigs(); for (JobConfig config : jobConfigs) { System.out.println(config); @@ -450,18 +718,6 @@ public class TestExecutionPlanner { } @Test - public void testCalculateIntStreamPartitions() { - ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); - StreamApplicationDescriptorImpl graphSpec = createSimpleGraph(); - JobGraph jobGraph = (JobGraph) planner.plan(graphSpec); - - // the partitions should be the same as input1 - jobGraph.getIntermediateStreams().forEach(edge -> { - assertEquals(64, edge.getPartitionCount()); // max of input1 and output1 - }); - } - - @Test public void testMaxPartition() { Collection<StreamEdge> edges = new ArrayList<>(); StreamEdge edge = new StreamEdge(input1Spec, false, false, config); @@ -481,25 +737,6 @@ public class TestExecutionPlanner { } @Test - public void testMaxPartitionLimit() throws Exception { - int partitionLimit = IntermediateStreamManager.MAX_INFERRED_PARTITIONS; - - ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); - StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> { - MessageStream<KV<Object, Object>> input1 = appDesc.getInputStream(input4Descriptor); - OutputStream<KV<Object, Object>> output1 = appDesc.getOutputStream(output1Descriptor); - input1.partitionBy(m -> m.key, m -> m.value, mock(KVSerde.class), "p1").map(kv -> kv).sendTo(output1); - }, config); - - JobGraph jobGraph = (JobGraph) planner.plan(graphSpec); - - // the partitions should be the same as input1 - jobGraph.getIntermediateStreams().forEach(edge -> { - assertEquals(partitionLimit, edge.getPartitionCount()); // max of input1 and output1 - }); - } - - @Test public void testCreateJobGraphForTaskApplication() { TaskApplicationDescriptorImpl taskAppDesc = mock(TaskApplicationDescriptorImpl.class); // add interemediate streams @@ -537,7 +774,7 @@ public class TestExecutionPlanner { systemDescriptors.forEach(sd -> systemStreamConfigs.putAll(sd.toConfig())); ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); - JobGraph jobGraph = planner.createJobGraph(config, taskAppDesc); + JobGraph jobGraph = planner.createJobGraph(taskAppDesc); assertEquals(1, jobGraph.getJobNodes().size()); assertTrue(jobGraph.getInputStreams().stream().map(edge -> edge.getName()) .filter(streamId -> inputDescriptors.containsKey(streamId)).collect(Collectors.toList()).isEmpty()); @@ -568,7 +805,7 @@ public class 
TestExecutionPlanner { systemDescriptors.forEach(sd -> systemStreamConfigs.putAll(sd.toConfig())); ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); - JobGraph jobGraph = planner.createJobGraph(config, taskAppDesc); + JobGraph jobGraph = planner.createJobGraph(taskAppDesc); assertEquals(1, jobGraph.getJobNodes().size()); JobNode jobNode = jobGraph.getJobNodes().get(0); assertEquals("test-app", jobNode.getJobName()); @@ -586,4 +823,26 @@ public class TestExecutionPlanner { } } + + private static class TestTableDescriptor extends BaseTableDescriptor implements TableDescriptor { + private final List<String> sideInputs; + private final SideInputsProcessor sideInputsProcessor; + + public TestTableDescriptor(String tableId) { + this(tableId, Collections.emptyList(), null); + } + + public TestTableDescriptor(String tableId, List<String> sideInputs, SideInputsProcessor sideInputsProcessor) { + super(tableId); + this.sideInputs = sideInputs; + this.sideInputsProcessor = sideInputsProcessor; + } + + @Override + public TableSpec getTableSpec() { + validate(); + return new TableSpec(tableId, serde, "dummyTableProviderFactoryClassName", + Collections.emptyMap(), sideInputs, sideInputsProcessor); + } + } } http://git-wip-us.apache.org/repos/asf/samza/blob/d2c9e816/samza-core/src/test/java/org/apache/samza/execution/TestIntermediateStreamManager.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/execution/TestIntermediateStreamManager.java b/samza-core/src/test/java/org/apache/samza/execution/TestIntermediateStreamManager.java deleted file mode 100644 index bc15709..0000000 --- a/samza-core/src/test/java/org/apache/samza/execution/TestIntermediateStreamManager.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */
-package org.apache.samza.execution;
-
-import org.apache.samza.application.StreamApplicationDescriptorImpl;
-import org.junit.Test;
-
-import static org.junit.Assert.assertEquals;
-import static org.mockito.Mockito.mock;
-
-/**
- * Unit tests for {@link IntermediateStreamManager}
- */
-public class TestIntermediateStreamManager extends ExecutionPlannerTestBase {
-
-  @Test
-  public void testCalculateRepartitionJoinTopicPartitions() {
-    mockStreamAppDesc = new StreamApplicationDescriptorImpl(getRepartitionJoinStreamApplication(), mockConfig);
-    IntermediateStreamManager partitionPlanner = new IntermediateStreamManager(mockConfig, mockStreamAppDesc);
-    JobGraph mockGraph = new ExecutionPlanner(mockConfig, mock(StreamManager.class))
-        .createJobGraph(mockConfig, mockStreamAppDesc);
-    // set the input stream partitions
-    mockGraph.getInputStreams().forEach(inEdge -> {
-        if (inEdge.getStreamSpec().getId().equals(input1Descriptor.getStreamId())) {
-          inEdge.setPartitionCount(6);
-        } else if (inEdge.getStreamSpec().getId().equals(input2Descriptor.getStreamId())) {
-          inEdge.setPartitionCount(5);
-        }
-      });
-    partitionPlanner.calculatePartitions(mockGraph);
-    assertEquals(1, mockGraph.getIntermediateStreamEdges().size());
-    assertEquals(5, mockGraph.getIntermediateStreamEdges().stream()
-        .filter(inEdge -> inEdge.getStreamSpec().getId().equals(intermediateInputDescriptor.getStreamId()))
-        .findFirst().get().getPartitionCount());
-  }
-
-  @Test
-  public void testCalculateRepartitionIntermediateTopicPartitions() {
-    mockStreamAppDesc = new StreamApplicationDescriptorImpl(getRepartitionOnlyStreamApplication(), mockConfig);
-    IntermediateStreamManager partitionPlanner = new IntermediateStreamManager(mockConfig, mockStreamAppDesc);
-    JobGraph mockGraph = new ExecutionPlanner(mockConfig, mock(StreamManager.class))
-        .createJobGraph(mockConfig, mockStreamAppDesc);
-    // set the input stream partitions
-    mockGraph.getInputStreams().forEach(inEdge -> inEdge.setPartitionCount(7));
-    partitionPlanner.calculatePartitions(mockGraph);
-    assertEquals(1, mockGraph.getIntermediateStreamEdges().size());
-    assertEquals(7, mockGraph.getIntermediateStreamEdges().stream()
-        .filter(inEdge -> inEdge.getStreamSpec().getId().equals(intermediateInputDescriptor.getStreamId()))
-        .findFirst().get().getPartitionCount());
-  }
-
-}
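----------------------------------------------------------------------

A note for readers skimming the diff: the new partition tests above all exercise one underlying
rule. Streams that participate in a join -- now including the stream(s) feeding a table and the
table's side-input streams -- form a group whose members must agree on partition count, and
unknown counts (intermediate streams created by partitionBy) are inferred from the group. The
sketch below illustrates that rule in isolation. It is not Samza's actual ExecutionPlanner or
IntermediateStreamManager code; the class and method names (JoinGroupPartitionResolver, resolve)
and the "-1 means unknown" convention are illustrative assumptions.

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Illustrative sketch: each join group is a set of stream ids that must agree on
 * partition count; a count of -1 means "not yet known".
 */
public class JoinGroupPartitionResolver {

  static void resolve(List<Set<String>> joinGroups, Map<String, Integer> counts) {
    boolean changed = true;
    while (changed) {           // iterate to a fixed point so group order does not matter
      changed = false;
      for (Set<String> group : joinGroups) {
        // Find a member whose partition count is already known, rejecting disagreements
        int known = -1;
        for (String stream : group) {
          int c = counts.getOrDefault(stream, -1);
          if (c != -1) {
            if (known != -1 && known != c) {
              // Mirrors the SamzaException the invalid-join tests above expect
              throw new IllegalStateException("Join group disagrees on partition count: "
                  + stream + " has " + c + ", expected " + known);
            }
            known = c;
          }
        }
        if (known == -1) {
          continue;             // nothing known yet; revisit this group on a later pass
        }
        // Propagate the agreed count to members that are still unknown
        for (String stream : group) {
          if (counts.getOrDefault(stream, -1) == -1) {
            counts.put(stream, known);
            changed = true;
          }
        }
      }
    }
  }

  public static void main(String[] args) {
    // The adversarial ordering from testCalculateOrderSensitiveJoinInputPartitions:
    // only e1 is known up front, and the middle group has no known member on pass one.
    Map<String, Integer> counts = new HashMap<>();
    counts.put("e1", 64);
    List<Set<String>> groups = Arrays.asList(
        new HashSet<>(Arrays.asList("e1", "e2")),
        new HashSet<>(Arrays.asList("e3", "e4")),
        new HashSet<>(Arrays.asList("e2", "e3")));
    resolve(groups, counts);
    System.out.println(counts); // all four streams resolve to 64
  }
}

The fixed-point loop is the point of the exercise: it is why the order-sensitive test can list
its join groups adversarially and still expect every intermediate stream to land on the same
partition count.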
