Repository: samza
Updated Branches:
  refs/heads/master aff805d07 -> 5431350b7


SAMZA-1627: Watermark broadcast enhancements

Currently, each upstream task needs to broadcast to every partition of the
intermediate streams so that the consumers can aggregate watermarks. A
better approach is to have only one downstream consumer do the aggregation
and then broadcast the result to all the other partitions. This is safe
because we can prove that the broadcast watermark message is sent after all
the upstream tasks have finished producing the events whose event times are
before the watermark. This reduces the total message count from O(n^2) to O(n).

Author: xinyuiscool <[email protected]>

Reviewers: Boris S <[email protected]>

Closes #456 from xinyuiscool/SAMZA-1627


Project: http://git-wip-us.apache.org/repos/asf/samza/repo
Commit: http://git-wip-us.apache.org/repos/asf/samza/commit/5431350b
Tree: http://git-wip-us.apache.org/repos/asf/samza/tree/5431350b
Diff: http://git-wip-us.apache.org/repos/asf/samza/diff/5431350b

Branch: refs/heads/master
Commit: 5431350b7390704e395c947834b01a5f2e76d906
Parents: aff805d
Author: xinyuiscool <[email protected]>
Authored: Wed Mar 28 10:25:15 2018 -0700
Committer: xiliu <[email protected]>
Committed: Wed Mar 28 10:25:15 2018 -0700

----------------------------------------------------------------------
 .../operators/impl/ControlMessageSender.java    | 38 ++++++++++++++------
 .../samza/operators/impl/EndOfStreamStates.java |  6 +++-
 .../samza/operators/impl/OperatorImpl.java      | 14 ++++++++
 .../samza/operators/impl/WatermarkStates.java   | 12 +++----
 .../impl/TestControlMessageSender.java          | 32 ++++++++++++++++-
 5 files changed, 84 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/samza/blob/5431350b/samza-core/src/main/java/org/apache/samza/operators/impl/ControlMessageSender.java
----------------------------------------------------------------------
diff --git 
a/samza-core/src/main/java/org/apache/samza/operators/impl/ControlMessageSender.java
 
b/samza-core/src/main/java/org/apache/samza/operators/impl/ControlMessageSender.java
index 4afca92..d4782b0 100644
--- 
a/samza-core/src/main/java/org/apache/samza/operators/impl/ControlMessageSender.java
+++ 
b/samza-core/src/main/java/org/apache/samza/operators/impl/ControlMessageSender.java
@@ -26,6 +26,7 @@ import org.apache.samza.system.OutgoingMessageEnvelope;
 import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.task.MessageCollector;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -35,7 +36,7 @@ import java.util.concurrent.ConcurrentHashMap;
 
 
 /**
- * This is a helper class to broadcast control messages to each partition of 
an intermediate stream
+ * This is a helper class to send control messages to an intermediate stream
  */
 class ControlMessageSender {
   private static final Logger LOG = 
LoggerFactory.getLogger(ControlMessageSender.class);
@@ -48,20 +49,37 @@ class ControlMessageSender {
   }
 
   void send(ControlMessage message, SystemStream systemStream, 
MessageCollector collector) {
-    Integer partitionCount = 
PARTITION_COUNT_CACHE.computeIfAbsent(systemStream, ss -> {
+    int partitionCount = getPartitionCount(systemStream);
+    // We pick a partition based on topic hashcode to aggregate the control 
messages from upstream tasks
+    // After aggregation the task will broadcast the results to other 
partitions
+    int aggregatePartition = systemStream.getStream().hashCode() % 
partitionCount;
+
+    LOG.debug(String.format("Send %s message from task %s to %s partition %s 
for aggregation",
+        MessageType.of(message).name(), message.getTaskName(), systemStream, 
aggregatePartition));
+
+    OutgoingMessageEnvelope envelopeOut = new 
OutgoingMessageEnvelope(systemStream, aggregatePartition, null, message);
+    collector.send(envelopeOut);
+  }
+
+  void broadcastToOtherPartitions(ControlMessage message, 
SystemStreamPartition ssp, MessageCollector collector) {
+    SystemStream systemStream = ssp.getSystemStream();
+    int partitionCount = getPartitionCount(systemStream);
+    int currentPartition = ssp.getPartition().getPartitionId();
+    for (int i = 0; i < partitionCount; i++) {
+      if (i != currentPartition) {
+        OutgoingMessageEnvelope envelopeOut = new 
OutgoingMessageEnvelope(systemStream, i, null, message);
+        collector.send(envelopeOut);
+      }
+    }
+  }
+
+  private int getPartitionCount(SystemStream systemStream) {
+    return PARTITION_COUNT_CACHE.computeIfAbsent(systemStream, ss -> {
         SystemStreamMetadata metadata = 
metadataCache.getSystemStreamMetadata(ss, true);
         if (metadata == null) {
           throw new SamzaException("Unable to find metadata for stream " + 
systemStream);
         }
         return metadata.getSystemStreamPartitionMetadata().size();
       });
-
-    LOG.debug(String.format("Broadcast %s message from task %s to %s with %s 
partition",
-        MessageType.of(message).name(), message.getTaskName(), systemStream, 
partitionCount));
-
-    for (int i = 0; i < partitionCount; i++) {
-      OutgoingMessageEnvelope envelopeOut = new 
OutgoingMessageEnvelope(systemStream, i, null, message);
-      collector.send(envelopeOut);
-    }
   }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/5431350b/samza-core/src/main/java/org/apache/samza/operators/impl/EndOfStreamStates.java
----------------------------------------------------------------------
diff --git 
a/samza-core/src/main/java/org/apache/samza/operators/impl/EndOfStreamStates.java
 
b/samza-core/src/main/java/org/apache/samza/operators/impl/EndOfStreamStates.java
index a69b234..8c9db61 100644
--- 
a/samza-core/src/main/java/org/apache/samza/operators/impl/EndOfStreamStates.java
+++ 
b/samza-core/src/main/java/org/apache/samza/operators/impl/EndOfStreamStates.java
@@ -51,9 +51,13 @@ class EndOfStreamStates {
 
     synchronized void update(String taskName) {
       if (taskName != null) {
+        // aggregate the eos messages
         tasks.add(taskName);
+        isEndOfStream = tasks.size() == expectedTotal;
+      } else {
+        // eos is coming from either source or aggregator task
+        isEndOfStream = true;
       }
-      isEndOfStream = tasks.size() == expectedTotal;
     }
 
     boolean isEndOfStream() {

http://git-wip-us.apache.org/repos/asf/samza/blob/5431350b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
----------------------------------------------------------------------
diff --git 
a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java 
b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
index 7219180..f644bd9 100644
--- a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
@@ -83,6 +83,7 @@ public abstract class OperatorImpl<M, RM> {
   // watermark states
   private WatermarkStates watermarkStates;
   private TaskContext taskContext;
+  private ControlMessageSender controlMessageSender;
 
   /**
    * Initialize this {@link OperatorImpl} and its user-defined functions.
@@ -114,6 +115,7 @@ public abstract class OperatorImpl<M, RM> {
     TaskContextImpl taskContext = (TaskContextImpl) context;
     this.eosStates = (EndOfStreamStates) 
taskContext.fetchObject(EndOfStreamStates.class.getName());
     this.watermarkStates = (WatermarkStates) 
taskContext.fetchObject(WatermarkStates.class.getName());
+    this.controlMessageSender = new 
ControlMessageSender(taskContext.getStreamMetadataCache());
 
     if (taskContext.getJobModel() != null) {
       ContainerModel containerModel = taskContext.getJobModel().getContainers()
@@ -265,6 +267,12 @@ public abstract class OperatorImpl<M, RM> {
     SystemStream stream = ssp.getSystemStream();
     if (eosStates.isEndOfStream(stream)) {
       LOG.info("Input {} reaches the end for task {}", stream.toString(), 
taskName.getTaskName());
+      if (eos.getTaskName() != null) {
+        // This is the aggregation task, which already received all the eos 
messages from upstream
+        // broadcast the end-of-stream to all the peer partitions
+        controlMessageSender.broadcastToOtherPartitions(new 
EndOfStreamMessage(), ssp, collector);
+      }
+      // populate the end-of-stream through the dag
       onEndOfStream(collector, coordinator);
 
       if (eosStates.allEndOfStream()) {
@@ -322,6 +330,12 @@ public abstract class OperatorImpl<M, RM> {
     long watermark = watermarkStates.getWatermark(ssp.getSystemStream());
     if (watermark != WatermarkStates.WATERMARK_NOT_EXIST) {
       LOG.debug("Got watermark {} from stream {}", watermark, 
ssp.getSystemStream());
+      if (watermarkMessage.getTaskName() != null) {
+        // This is the aggregation task, which already received all the 
watermark messages from upstream
+        // broadcast the watermark to all the peer partitions
+        controlMessageSender.broadcastToOtherPartitions(new 
WatermarkMessage(watermark), ssp, collector);
+      }
+      // populate the watermark through the dag
       onWatermark(watermark, collector, coordinator);
     }
   }

http://git-wip-us.apache.org/repos/asf/samza/blob/5431350b/samza-core/src/main/java/org/apache/samza/operators/impl/WatermarkStates.java
----------------------------------------------------------------------
diff --git 
a/samza-core/src/main/java/org/apache/samza/operators/impl/WatermarkStates.java 
b/samza-core/src/main/java/org/apache/samza/operators/impl/WatermarkStates.java
index 0295626..5cc66e2 100644
--- 
a/samza-core/src/main/java/org/apache/samza/operators/impl/WatermarkStates.java
+++ 
b/samza-core/src/main/java/org/apache/samza/operators/impl/WatermarkStates.java
@@ -63,12 +63,12 @@ class WatermarkStates {
         }
       }
 
-      /**
-       * Check whether we got all the watermarks.
-       * At a sources, the expectedTotal is 0.
-       * For any intermediate streams, the expectedTotal is the upstream task 
count.
-       */
-      if (timestamps.size() == expectedTotal) {
+      if (taskName == null) {
+        // we get watermark either from the source or from the aggregator task
+        watermarkTime = Math.max(watermarkTime, timestamp);
+      } else if (timestamps.size() == expectedTotal) {
+        // For any intermediate streams, the expectedTotal is the upstream 
task count.
+        // Check whether we got all the watermarks, and set the watermark to 
be the min.
         Optional<Long> min = timestamps.values().stream().min(Long::compare);
         watermarkTime = min.orElse(timestamp);
       }

http://git-wip-us.apache.org/repos/asf/samza/blob/5431350b/samza-core/src/test/java/org/apache/samza/operators/impl/TestControlMessageSender.java
----------------------------------------------------------------------
diff --git 
a/samza-core/src/test/java/org/apache/samza/operators/impl/TestControlMessageSender.java
 
b/samza-core/src/test/java/org/apache/samza/operators/impl/TestControlMessageSender.java
index d17d751..9ff9a4f 100644
--- 
a/samza-core/src/test/java/org/apache/samza/operators/impl/TestControlMessageSender.java
+++ 
b/samza-core/src/test/java/org/apache/samza/operators/impl/TestControlMessageSender.java
@@ -28,6 +28,7 @@ import org.apache.samza.system.OutgoingMessageEnvelope;
 import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.system.WatermarkMessage;
 import org.apache.samza.task.MessageCollector;
 import org.junit.Test;
@@ -68,6 +69,35 @@ public class TestControlMessageSender {
     ControlMessageSender sender = new ControlMessageSender(metadataCache);
     WatermarkMessage watermark = new 
WatermarkMessage(System.currentTimeMillis(), "task 0");
     sender.send(watermark, systemStream, collector);
-    assertEquals(partitions.size(), 4);
+    assertEquals(partitions.size(), 1);
+  }
+
+  @Test
+  public void testBroadcast() {
+    SystemStreamMetadata metadata = mock(SystemStreamMetadata.class);
+    Map<Partition, SystemStreamMetadata.SystemStreamPartitionMetadata> 
partitionMetadata = new HashMap<>();
+    partitionMetadata.put(new Partition(0), 
mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    partitionMetadata.put(new Partition(1), 
mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    partitionMetadata.put(new Partition(2), 
mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    partitionMetadata.put(new Partition(3), 
mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    
when(metadata.getSystemStreamPartitionMetadata()).thenReturn(partitionMetadata);
+    StreamMetadataCache metadataCache = mock(StreamMetadataCache.class);
+    when(metadataCache.getSystemStreamMetadata(anyObject(), 
anyBoolean())).thenReturn(metadata);
+
+    SystemStream systemStream = new SystemStream("test-system", "test-stream");
+    Set<Integer> partitions = new HashSet<>();
+    MessageCollector collector = mock(MessageCollector.class);
+    doAnswer(invocation -> {
+        OutgoingMessageEnvelope envelope = (OutgoingMessageEnvelope) 
invocation.getArguments()[0];
+        partitions.add((Integer) envelope.getPartitionKey());
+        assertEquals(envelope.getSystemStream(), systemStream);
+        return null;
+      }).when(collector).send(any());
+
+    ControlMessageSender sender = new ControlMessageSender(metadataCache);
+    WatermarkMessage watermark = new 
WatermarkMessage(System.currentTimeMillis(), "task 0");
+    SystemStreamPartition ssp = new SystemStreamPartition(systemStream, new 
Partition(0));
+    sender.broadcastToOtherPartitions(watermark, ssp, collector);
+    assertEquals(partitions.size(), 3);
   }
 }

Reply via email to