You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@samza.apache.org by xi...@apache.org on 2018/03/28 17:25:25 UTC

samza git commit: SAMZA-1627: Watermark broadcast enhancements

Repository: samza
Updated Branches:
  refs/heads/master aff805d07 -> 5431350b7


SAMZA-1627: Watermark broadcast enhancements

Currently each upstream task needs to broadcast to every single partition of an intermediate stream in order to aggregate watermarks in the consumers. A better approach is to have a single downstream consumer perform the aggregation and then broadcast the result to all the partitions. This is safe because we can show that the broadcast watermark message comes after all the upstream tasks have finished producing the events whose event times are before the watermark. This reduces the total message count from O(n^2) to O(n).

Author: xinyuiscool <xi...@linkedin.com>

Reviewers: Boris S <sb...@gmail.com>

Closes #456 from xinyuiscool/SAMZA-1627


Project: http://git-wip-us.apache.org/repos/asf/samza/repo
Commit: http://git-wip-us.apache.org/repos/asf/samza/commit/5431350b
Tree: http://git-wip-us.apache.org/repos/asf/samza/tree/5431350b
Diff: http://git-wip-us.apache.org/repos/asf/samza/diff/5431350b

Branch: refs/heads/master
Commit: 5431350b7390704e395c947834b01a5f2e76d906
Parents: aff805d
Author: xinyuiscool <xi...@linkedin.com>
Authored: Wed Mar 28 10:25:15 2018 -0700
Committer: xiliu <xi...@linkedin.com>
Committed: Wed Mar 28 10:25:15 2018 -0700

----------------------------------------------------------------------
 .../operators/impl/ControlMessageSender.java    | 38 ++++++++++++++------
 .../samza/operators/impl/EndOfStreamStates.java |  6 +++-
 .../samza/operators/impl/OperatorImpl.java      | 14 ++++++++
 .../samza/operators/impl/WatermarkStates.java   | 12 +++----
 .../impl/TestControlMessageSender.java          | 32 ++++++++++++++++-
 5 files changed, 84 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/samza/blob/5431350b/samza-core/src/main/java/org/apache/samza/operators/impl/ControlMessageSender.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/ControlMessageSender.java b/samza-core/src/main/java/org/apache/samza/operators/impl/ControlMessageSender.java
index 4afca92..d4782b0 100644
--- a/samza-core/src/main/java/org/apache/samza/operators/impl/ControlMessageSender.java
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/ControlMessageSender.java
@@ -26,6 +26,7 @@ import org.apache.samza.system.OutgoingMessageEnvelope;
 import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.task.MessageCollector;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -35,7 +36,7 @@ import java.util.concurrent.ConcurrentHashMap;
 
 
 /**
- * This is a helper class to broadcast control messages to each partition of an intermediate stream
+ * This is a helper class to send control messages to an intermediate stream
  */
 class ControlMessageSender {
   private static final Logger LOG = LoggerFactory.getLogger(ControlMessageSender.class);
@@ -48,20 +49,37 @@ class ControlMessageSender {
   }
 
   void send(ControlMessage message, SystemStream systemStream, MessageCollector collector) {
-    Integer partitionCount = PARTITION_COUNT_CACHE.computeIfAbsent(systemStream, ss -> {
+    int partitionCount = getPartitionCount(systemStream);
+    // We pick a partition based on topic hashcode to aggregate the control messages from upstream tasks
+    // After aggregation the task will broadcast the results to other partitions
+    int aggregatePartition = systemStream.getStream().hashCode() % partitionCount;
+
+    LOG.debug(String.format("Send %s message from task %s to %s partition %s for aggregation",
+        MessageType.of(message).name(), message.getTaskName(), systemStream, aggregatePartition));
+
+    OutgoingMessageEnvelope envelopeOut = new OutgoingMessageEnvelope(systemStream, aggregatePartition, null, message);
+    collector.send(envelopeOut);
+  }
+
+  void broadcastToOtherPartitions(ControlMessage message, SystemStreamPartition ssp, MessageCollector collector) {
+    SystemStream systemStream = ssp.getSystemStream();
+    int partitionCount = getPartitionCount(systemStream);
+    int currentPartition = ssp.getPartition().getPartitionId();
+    for (int i = 0; i < partitionCount; i++) {
+      if (i != currentPartition) {
+        OutgoingMessageEnvelope envelopeOut = new OutgoingMessageEnvelope(systemStream, i, null, message);
+        collector.send(envelopeOut);
+      }
+    }
+  }
+
+  private int getPartitionCount(SystemStream systemStream) {
+    return PARTITION_COUNT_CACHE.computeIfAbsent(systemStream, ss -> {
         SystemStreamMetadata metadata = metadataCache.getSystemStreamMetadata(ss, true);
         if (metadata == null) {
           throw new SamzaException("Unable to find metadata for stream " + systemStream);
         }
         return metadata.getSystemStreamPartitionMetadata().size();
       });
-
-    LOG.debug(String.format("Broadcast %s message from task %s to %s with %s partition",
-        MessageType.of(message).name(), message.getTaskName(), systemStream, partitionCount));
-
-    for (int i = 0; i < partitionCount; i++) {
-      OutgoingMessageEnvelope envelopeOut = new OutgoingMessageEnvelope(systemStream, i, null, message);
-      collector.send(envelopeOut);
-    }
   }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/5431350b/samza-core/src/main/java/org/apache/samza/operators/impl/EndOfStreamStates.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/EndOfStreamStates.java b/samza-core/src/main/java/org/apache/samza/operators/impl/EndOfStreamStates.java
index a69b234..8c9db61 100644
--- a/samza-core/src/main/java/org/apache/samza/operators/impl/EndOfStreamStates.java
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/EndOfStreamStates.java
@@ -51,9 +51,13 @@ class EndOfStreamStates {
 
     synchronized void update(String taskName) {
       if (taskName != null) {
+        // aggregate the eos messages
         tasks.add(taskName);
+        isEndOfStream = tasks.size() == expectedTotal;
+      } else {
+        // eos is coming from either source or aggregator task
+        isEndOfStream = true;
       }
-      isEndOfStream = tasks.size() == expectedTotal;
     }
 
     boolean isEndOfStream() {

http://git-wip-us.apache.org/repos/asf/samza/blob/5431350b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
index 7219180..f644bd9 100644
--- a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
@@ -83,6 +83,7 @@ public abstract class OperatorImpl<M, RM> {
   // watermark states
   private WatermarkStates watermarkStates;
   private TaskContext taskContext;
+  private ControlMessageSender controlMessageSender;
 
   /**
    * Initialize this {@link OperatorImpl} and its user-defined functions.
@@ -114,6 +115,7 @@ public abstract class OperatorImpl<M, RM> {
     TaskContextImpl taskContext = (TaskContextImpl) context;
     this.eosStates = (EndOfStreamStates) taskContext.fetchObject(EndOfStreamStates.class.getName());
     this.watermarkStates = (WatermarkStates) taskContext.fetchObject(WatermarkStates.class.getName());
+    this.controlMessageSender = new ControlMessageSender(taskContext.getStreamMetadataCache());
 
     if (taskContext.getJobModel() != null) {
       ContainerModel containerModel = taskContext.getJobModel().getContainers()
@@ -265,6 +267,12 @@ public abstract class OperatorImpl<M, RM> {
     SystemStream stream = ssp.getSystemStream();
     if (eosStates.isEndOfStream(stream)) {
       LOG.info("Input {} reaches the end for task {}", stream.toString(), taskName.getTaskName());
+      if (eos.getTaskName() != null) {
+        // This is the aggregation task, which already received all the eos messages from upstream
+        // broadcast the end-of-stream to all the peer partitions
+        controlMessageSender.broadcastToOtherPartitions(new EndOfStreamMessage(), ssp, collector);
+      }
+      // populate the end-of-stream through the dag
       onEndOfStream(collector, coordinator);
 
       if (eosStates.allEndOfStream()) {
@@ -322,6 +330,12 @@ public abstract class OperatorImpl<M, RM> {
     long watermark = watermarkStates.getWatermark(ssp.getSystemStream());
     if (watermark != WatermarkStates.WATERMARK_NOT_EXIST) {
       LOG.debug("Got watermark {} from stream {}", watermark, ssp.getSystemStream());
+      if (watermarkMessage.getTaskName() != null) {
+        // This is the aggregation task, which already received all the watermark messages from upstream
+        // broadcast the watermark to all the peer partitions
+        controlMessageSender.broadcastToOtherPartitions(new WatermarkMessage(watermark), ssp, collector);
+      }
+      // populate the watermark through the dag
       onWatermark(watermark, collector, coordinator);
     }
   }

http://git-wip-us.apache.org/repos/asf/samza/blob/5431350b/samza-core/src/main/java/org/apache/samza/operators/impl/WatermarkStates.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/WatermarkStates.java b/samza-core/src/main/java/org/apache/samza/operators/impl/WatermarkStates.java
index 0295626..5cc66e2 100644
--- a/samza-core/src/main/java/org/apache/samza/operators/impl/WatermarkStates.java
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/WatermarkStates.java
@@ -63,12 +63,12 @@ class WatermarkStates {
         }
       }
 
-      /**
-       * Check whether we got all the watermarks.
-       * At a sources, the expectedTotal is 0.
-       * For any intermediate streams, the expectedTotal is the upstream task count.
-       */
-      if (timestamps.size() == expectedTotal) {
+      if (taskName == null) {
+        // we get watermark either from the source or from the aggregator task
+        watermarkTime = Math.max(watermarkTime, timestamp);
+      } else if (timestamps.size() == expectedTotal) {
+        // For any intermediate streams, the expectedTotal is the upstream task count.
+        // Check whether we got all the watermarks, and set the watermark to be the min.
         Optional<Long> min = timestamps.values().stream().min(Long::compare);
         watermarkTime = min.orElse(timestamp);
       }

http://git-wip-us.apache.org/repos/asf/samza/blob/5431350b/samza-core/src/test/java/org/apache/samza/operators/impl/TestControlMessageSender.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/operators/impl/TestControlMessageSender.java b/samza-core/src/test/java/org/apache/samza/operators/impl/TestControlMessageSender.java
index d17d751..9ff9a4f 100644
--- a/samza-core/src/test/java/org/apache/samza/operators/impl/TestControlMessageSender.java
+++ b/samza-core/src/test/java/org/apache/samza/operators/impl/TestControlMessageSender.java
@@ -28,6 +28,7 @@ import org.apache.samza.system.OutgoingMessageEnvelope;
 import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.system.WatermarkMessage;
 import org.apache.samza.task.MessageCollector;
 import org.junit.Test;
@@ -68,6 +69,35 @@ public class TestControlMessageSender {
     ControlMessageSender sender = new ControlMessageSender(metadataCache);
     WatermarkMessage watermark = new WatermarkMessage(System.currentTimeMillis(), "task 0");
     sender.send(watermark, systemStream, collector);
-    assertEquals(partitions.size(), 4);
+    assertEquals(partitions.size(), 1);
+  }
+
+  @Test
+  public void testBroadcast() {
+    SystemStreamMetadata metadata = mock(SystemStreamMetadata.class);
+    Map<Partition, SystemStreamMetadata.SystemStreamPartitionMetadata> partitionMetadata = new HashMap<>();
+    partitionMetadata.put(new Partition(0), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    partitionMetadata.put(new Partition(1), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    partitionMetadata.put(new Partition(2), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    partitionMetadata.put(new Partition(3), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    when(metadata.getSystemStreamPartitionMetadata()).thenReturn(partitionMetadata);
+    StreamMetadataCache metadataCache = mock(StreamMetadataCache.class);
+    when(metadataCache.getSystemStreamMetadata(anyObject(), anyBoolean())).thenReturn(metadata);
+
+    SystemStream systemStream = new SystemStream("test-system", "test-stream");
+    Set<Integer> partitions = new HashSet<>();
+    MessageCollector collector = mock(MessageCollector.class);
+    doAnswer(invocation -> {
+        OutgoingMessageEnvelope envelope = (OutgoingMessageEnvelope) invocation.getArguments()[0];
+        partitions.add((Integer) envelope.getPartitionKey());
+        assertEquals(envelope.getSystemStream(), systemStream);
+        return null;
+      }).when(collector).send(any());
+
+    ControlMessageSender sender = new ControlMessageSender(metadataCache);
+    WatermarkMessage watermark = new WatermarkMessage(System.currentTimeMillis(), "task 0");
+    SystemStreamPartition ssp = new SystemStreamPartition(systemStream, new Partition(0));
+    sender.broadcastToOtherPartitions(watermark, ssp, collector);
+    assertEquals(partitions.size(), 3);
   }
 }