Posted to commits@beam.apache.org by av...@apache.org on 2017/06/22 12:33:40 UTC

[2/3] beam git commit: [BEAM-2359] Fix watermark broadcasting to executors in Spark runner

[BEAM-2359] Fix watermark broadcasting to executors in Spark runner


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/20820fa5
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/20820fa5
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/20820fa5

Branch: refs/heads/master
Commit: 20820fa5477ffcdd4a9ef2e9340353ed3c5691a9
Parents: b3099bb
Author: Aviem Zur <av...@gmail.com>
Authored: Mon Jun 12 17:04:00 2017 +0300
Committer: Aviem Zur <av...@gmail.com>
Committed: Thu Jun 22 14:51:02 2017 +0300

----------------------------------------------------------------------
 .../apache/beam/runners/spark/SparkRunner.java  |   2 +-
 .../beam/runners/spark/TestSparkRunner.java     |   2 +-
 .../SparkGroupAlsoByWindowViaWindowSet.java     |   6 +-
 .../spark/stateful/SparkTimerInternals.java     |  18 ++-
 .../spark/util/GlobalWatermarkHolder.java       | 127 ++++++++++++++-----
 .../spark/GlobalWatermarkHolderTest.java        |  18 +--
 6 files changed, 120 insertions(+), 53 deletions(-)
----------------------------------------------------------------------
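In short, this patch stops re-creating and destroying a Spark Broadcast variable for the watermarks. The driver now writes the per-source SparkWatermarks map as a single block (WATERMARKS_BLOCK_ID) into Spark's BlockManager and keeps a local copy for local mode, while executors fetch the block with getRemote and cache the result in a Guava LoadingCache whose entries expire after half a batch interval, so every micro-batch sees freshly propagated watermarks.

The caching pattern the new GlobalWatermarkHolder relies on looks roughly like the following. This is a minimal stand-alone sketch, not the Beam code itself; the Watermarks type, the cache key and the batch interval are placeholders, and the loader simply returns an empty map where the patch reads the BlockManager.

    import com.google.common.cache.CacheBuilder;
    import com.google.common.cache.CacheLoader;
    import com.google.common.cache.LoadingCache;
    import java.util.HashMap;
    import java.util.Map;
    import java.util.concurrent.TimeUnit;

    public class WatermarkCacheSketch {
      // Hypothetical stand-in for GlobalWatermarkHolder.SparkWatermarks.
      static class Watermarks {}

      public static void main(String[] args) throws Exception {
        long batchDurationMillis = 1000L; // assumed batch interval

        // Expire entries every half batch duration so each micro-batch re-reads
        // the shared state, mirroring expireAfterWrite(batchDuration / 2) in the patch.
        LoadingCache<String, Map<Integer, Watermarks>> cache =
            CacheBuilder.newBuilder()
                .expireAfterWrite(batchDurationMillis / 2, TimeUnit.MILLISECONDS)
                .build(
                    new CacheLoader<String, Map<Integer, Watermarks>>() {
                      @Override
                      public Map<Integer, Watermarks> load(String key) {
                        // The patch fetches the watermark block from the
                        // BlockManager here; this sketch just returns an empty map.
                        return new HashMap<>();
                      }
                    });

        Map<Integer, Watermarks> watermarks = cache.get("SINGLETON");
        System.out.println("known sources: " + watermarks.keySet());
      }
    }

The half-interval expiry is a trade-off: a shorter expiry would hit the BlockManager more often, while a longer one risks an executor running a whole batch on stale watermarks.
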


http://git-wip-us.apache.org/repos/asf/beam/blob/20820fa5/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
index d008718..595521f 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
@@ -171,7 +171,7 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> {
       }
 
       // register Watermarks listener to broadcast the advanced WMs.
-      jssc.addStreamingListener(new JavaStreamingListenerWrapper(new WatermarksListener(jssc)));
+      jssc.addStreamingListener(new JavaStreamingListenerWrapper(new WatermarksListener()));
 
       // The reason we call initAccumulators here even though it is called in
       // SparkRunnerStreamingContextFactory is because the factory is not called when resuming

http://git-wip-us.apache.org/repos/asf/beam/blob/20820fa5/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
index eccee57..a13a3b1 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
@@ -169,7 +169,7 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
     result.waitUntilFinish(Duration.millis(batchDurationMillis));
     do {
       SparkTimerInternals sparkTimerInternals =
-          SparkTimerInternals.global(GlobalWatermarkHolder.get());
+          SparkTimerInternals.global(GlobalWatermarkHolder.get(batchDurationMillis));
       sparkTimerInternals.advanceWatermark();
       globalWatermark = sparkTimerInternals.currentInputWatermarkTime();
       // let another batch-interval period of execution, just to reason about WM propagation.

http://git-wip-us.apache.org/repos/asf/beam/blob/20820fa5/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
index be4f3f6..1385e07 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
@@ -104,13 +104,15 @@ public class SparkGroupAlsoByWindowViaWindowSet {
 
   public static <K, InputT, W extends BoundedWindow>
       JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> groupAlsoByWindow(
-          JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> inputDStream,
+          final JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> inputDStream,
           final Coder<K> keyCoder,
           final Coder<WindowedValue<InputT>> wvCoder,
           final WindowingStrategy<?, W> windowingStrategy,
           final SparkRuntimeContext runtimeContext,
           final List<Integer> sourceIds) {
 
+    final long batchDurationMillis =
+        runtimeContext.getPipelineOptions().as(SparkPipelineOptions.class).getBatchIntervalMillis();
     final IterableCoder<WindowedValue<InputT>> itrWvCoder = IterableCoder.of(wvCoder);
     final Coder<InputT> iCoder = ((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder();
     final Coder<? extends BoundedWindow> wCoder =
@@ -239,7 +241,7 @@ public class SparkGroupAlsoByWindowViaWindowSet {
 
                       SparkStateInternals<K> stateInternals;
                       SparkTimerInternals timerInternals = SparkTimerInternals.forStreamFromSources(
-                          sourceIds, GlobalWatermarkHolder.get());
+                          sourceIds, GlobalWatermarkHolder.get(batchDurationMillis));
                       // get state(internals) per key.
                       if (prevStateAndTimersOpt.isEmpty()) {
                         // no previous state.

http://git-wip-us.apache.org/repos/asf/beam/blob/20820fa5/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java
index 107915f..a68da55 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java
@@ -34,7 +34,6 @@ import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.runners.spark.util.GlobalWatermarkHolder.SparkWatermarks;
 import org.apache.beam.sdk.state.TimeDomain;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.spark.broadcast.Broadcast;
 import org.joda.time.Instant;
 
 
@@ -58,10 +57,10 @@ public class SparkTimerInternals implements TimerInternals {
   /** Build the {@link TimerInternals} according to the feeding streams. */
   public static SparkTimerInternals forStreamFromSources(
       List<Integer> sourceIds,
-      @Nullable Broadcast<Map<Integer, SparkWatermarks>> broadcast) {
-    // if broadcast is invalid for the specific ids, use defaults.
-    if (broadcast == null || broadcast.getValue().isEmpty()
-        || Collections.disjoint(sourceIds, broadcast.getValue().keySet())) {
+      Map<Integer, SparkWatermarks> watermarks) {
+    // if watermarks are invalid for the specific ids, use defaults.
+    if (watermarks == null || watermarks.isEmpty()
+        || Collections.disjoint(sourceIds, watermarks.keySet())) {
       return new SparkTimerInternals(
           BoundedWindow.TIMESTAMP_MIN_VALUE, BoundedWindow.TIMESTAMP_MIN_VALUE, new Instant(0));
     }
@@ -71,7 +70,7 @@ public class SparkTimerInternals implements TimerInternals {
     // synchronized processing time should clearly be synchronized.
     Instant synchronizedProcessingTime = null;
     for (Integer sourceId: sourceIds) {
-      SparkWatermarks sparkWatermarks = broadcast.getValue().get(sourceId);
+      SparkWatermarks sparkWatermarks = watermarks.get(sourceId);
       if (sparkWatermarks != null) {
         // keep slowest WMs.
         slowestLowWatermark = slowestLowWatermark.isBefore(sparkWatermarks.getLowWatermark())
@@ -94,10 +93,9 @@ public class SparkTimerInternals implements TimerInternals {
   }
 
   /** Build a global {@link TimerInternals} for all feeding streams.*/
-  public static SparkTimerInternals global(
-      @Nullable Broadcast<Map<Integer, SparkWatermarks>> broadcast) {
-    return broadcast == null ? forStreamFromSources(Collections.<Integer>emptyList(), null)
-        : forStreamFromSources(Lists.newArrayList(broadcast.getValue().keySet()), broadcast);
+  public static SparkTimerInternals global(Map<Integer, SparkWatermarks> watermarks) {
+    return watermarks == null ? forStreamFromSources(Collections.<Integer>emptyList(), null)
+        : forStreamFromSources(Lists.newArrayList(watermarks.keySet()), watermarks);
   }
 
   Collection<TimerData> getTimers() {
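
With the Broadcast gone, SparkTimerInternals is now fed a plain Map<Integer, SparkWatermarks> keyed by source id and, per the "keep slowest WMs" comment above, combines the watermarks of the requested sources by keeping the slowest (earliest) value. The following is a generic, self-contained sketch of that kind of merge with made-up per-source instants; it illustrates the idea only and is not the Beam implementation.

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import org.joda.time.Instant;

    public class SlowestWatermarkSketch {
      public static void main(String[] args) {
        // Hypothetical per-source low watermarks; in Beam these come from
        // GlobalWatermarkHolder.SparkWatermarks.
        Map<Integer, Instant> lowWatermarks = new HashMap<>();
        lowWatermarks.put(1, new Instant(100L));
        lowWatermarks.put(2, new Instant(50L));

        List<Integer> sourceIds = Arrays.asList(1, 2);

        // Keep the slowest (earliest) watermark across the feeding sources.
        Instant slowest = new Instant(Long.MAX_VALUE);
        for (Integer sourceId : sourceIds) {
          Instant candidate = lowWatermarks.get(sourceId);
          if (candidate != null && candidate.isBefore(slowest)) {
            slowest = candidate;
          }
        }
        System.out.println("slowest low watermark: " + slowest); // 1970-01-01T00:00:00.050Z
      }
    }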

http://git-wip-us.apache.org/repos/asf/beam/blob/20820fa5/runners/spark/src/main/java/org/apache/beam/runners/spark/util/GlobalWatermarkHolder.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/GlobalWatermarkHolder.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/GlobalWatermarkHolder.java
index 8b384d8..2cb6f26 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/GlobalWatermarkHolder.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/GlobalWatermarkHolder.java
@@ -21,31 +21,43 @@ package org.apache.beam.runners.spark.util;
 import static com.google.common.base.Preconditions.checkState;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
+import com.google.common.collect.Maps;
 import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Queue;
 import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nonnull;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.SparkEnv;
 import org.apache.spark.broadcast.Broadcast;
-import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.BlockResult;
+import org.apache.spark.storage.BlockStore;
+import org.apache.spark.storage.StorageLevel;
 import org.apache.spark.streaming.api.java.JavaStreamingListener;
 import org.apache.spark.streaming.api.java.JavaStreamingListenerBatchCompleted;
 import org.joda.time.Instant;
-
+import scala.Option;
 
 /**
- * A {@link Broadcast} variable to hold the global watermarks for a micro-batch.
+ * A {@link BlockStore} variable to hold the global watermarks for a micro-batch.
  *
  * <p>For each source, holds a queue for the watermarks of each micro-batch that was read,
  * and advances the watermarks according to the queue (first-in-first-out).
  */
 public class GlobalWatermarkHolder {
-  // the broadcast is broadcasted to the workers.
-  private static volatile Broadcast<Map<Integer, SparkWatermarks>> broadcast = null;
-  // this should only live in the driver so transient.
-  private static final transient Map<Integer, Queue<SparkWatermarks>> sourceTimes = new HashMap<>();
+  private static final Map<Integer, Queue<SparkWatermarks>> sourceTimes = new HashMap<>();
+  private static final BlockId WATERMARKS_BLOCK_ID = BlockId.apply("broadcast_0WATERMARKS");
+
+  private static volatile Map<Integer, SparkWatermarks> driverWatermarks = null;
+  private static volatile LoadingCache<String, Map<Integer, SparkWatermarks>> watermarkCache = null;
 
   public static void add(int sourceId, SparkWatermarks sparkWatermarks) {
     Queue<SparkWatermarks> timesQueue = sourceTimes.get(sourceId);
@@ -71,22 +83,48 @@ public class GlobalWatermarkHolder {
    * Returns the {@link Broadcast} containing the {@link SparkWatermarks} mapped
    * to their sources.
    */
-  public static Broadcast<Map<Integer, SparkWatermarks>> get() {
-    return broadcast;
+  @SuppressWarnings("unchecked")
+  public static Map<Integer, SparkWatermarks> get(Long cacheInterval) {
+    if (driverWatermarks != null) {
+      // if we are executing in local mode simply return the local values.
+      return driverWatermarks;
+    } else {
+      if (watermarkCache == null) {
+        initWatermarkCache(cacheInterval);
+      }
+      try {
+        return watermarkCache.get("SINGLETON");
+      } catch (ExecutionException e) {
+        throw new RuntimeException(e);
+      }
+    }
+  }
+
+  private static synchronized void initWatermarkCache(Long batchDuration) {
+    if (watermarkCache == null) {
+      watermarkCache =
+          CacheBuilder.newBuilder()
+              // expire watermarks every half batch duration to ensure they update in every batch.
+              .expireAfterWrite(batchDuration / 2, TimeUnit.MILLISECONDS)
+              .build(new WatermarksLoader());
+    }
   }
 
   /**
    * Advances the watermarks to the next-in-line watermarks.
    * SparkWatermarks are monotonically increasing.
    */
-  public static void advance(JavaSparkContext jsc) {
-    synchronized (GlobalWatermarkHolder.class){
+  @SuppressWarnings("unchecked")
+  public static void advance() {
+    synchronized (GlobalWatermarkHolder.class) {
+      BlockManager blockManager = SparkEnv.get().blockManager();
+
       if (sourceTimes.isEmpty()) {
         return;
       }
 
       // update all sources' watermarks into the new broadcast.
-      Map<Integer, SparkWatermarks> newBroadcast = new HashMap<>();
+      Map<Integer, SparkWatermarks> newValues = new HashMap<>();
 
       for (Map.Entry<Integer, Queue<SparkWatermarks>> en: sourceTimes.entrySet()) {
         if (en.getValue().isEmpty()) {
@@ -99,8 +137,22 @@ public class GlobalWatermarkHolder {
         Instant currentLowWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE;
         Instant currentHighWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE;
         Instant currentSynchronizedProcessingTime = BoundedWindow.TIMESTAMP_MIN_VALUE;
-        if (broadcast != null && broadcast.getValue().containsKey(sourceId)) {
-          SparkWatermarks currentTimes = broadcast.getValue().get(sourceId);
+
+        Option<BlockResult> currentOption = blockManager.getRemote(WATERMARKS_BLOCK_ID);
+        Map<Integer, SparkWatermarks> current;
+        if (currentOption.isDefined()) {
+          current = (Map<Integer, SparkWatermarks>) currentOption.get().data().next();
+        } else {
+          current = Maps.newHashMap();
+          blockManager.putSingle(
+              WATERMARKS_BLOCK_ID,
+              current,
+              StorageLevel.MEMORY_ONLY(),
+              true);
+        }
+
+        if (current.containsKey(sourceId)) {
+          SparkWatermarks currentTimes = current.get(sourceId);
           currentLowWatermark = currentTimes.getLowWatermark();
           currentHighWatermark = currentTimes.getHighWatermark();
           currentSynchronizedProcessingTime = currentTimes.getSynchronizedProcessingTime();
@@ -119,20 +171,21 @@ public class GlobalWatermarkHolder {
                 nextLowWatermark, nextHighWatermark));
         checkState(nextSynchronizedProcessingTime.isAfter(currentSynchronizedProcessingTime),
             "Synchronized processing time must advance.");
-        newBroadcast.put(
+        newValues.put(
             sourceId,
             new SparkWatermarks(
                 nextLowWatermark, nextHighWatermark, nextSynchronizedProcessingTime));
       }
 
       // update the watermarks broadcast only if something has changed.
-      if (!newBroadcast.isEmpty()) {
-        if (broadcast != null) {
-          // for now this is blocking, we could make this asynchronous
-          // but it could slow down WM propagation.
-          broadcast.destroy();
-        }
-        broadcast = jsc.broadcast(newBroadcast);
+      if (!newValues.isEmpty()) {
+        driverWatermarks = newValues;
+        blockManager.removeBlock(WATERMARKS_BLOCK_ID, true);
+        blockManager.putSingle(
+            WATERMARKS_BLOCK_ID,
+            newValues,
+            StorageLevel.MEMORY_ONLY(),
+            true);
       }
     }
   }
@@ -140,7 +193,12 @@ public class GlobalWatermarkHolder {
   @VisibleForTesting
   public static synchronized void clear() {
     sourceTimes.clear();
-    broadcast = null;
+    driverWatermarks = null;
+    SparkEnv sparkEnv = SparkEnv.get();
+    if (sparkEnv != null) {
+      BlockManager blockManager = sparkEnv.blockManager();
+      blockManager.removeBlock(WATERMARKS_BLOCK_ID, true);
+    }
   }
 
   /**
@@ -185,15 +243,24 @@ public class GlobalWatermarkHolder {
 
   /** Advance the WMs onBatchCompleted event. */
   public static class WatermarksListener extends JavaStreamingListener {
-    private final JavaStreamingContext jssc;
-
-    public WatermarksListener(JavaStreamingContext jssc) {
-      this.jssc = jssc;
+    @Override
+    public void onBatchCompleted(JavaStreamingListenerBatchCompleted batchCompleted) {
+      GlobalWatermarkHolder.advance();
     }
+  }
+
+  private static class WatermarksLoader extends CacheLoader<String, Map<Integer, SparkWatermarks>> {
 
+    @SuppressWarnings("unchecked")
     @Override
-    public void onBatchCompleted(JavaStreamingListenerBatchCompleted batchCompleted) {
-      GlobalWatermarkHolder.advance(jssc.sparkContext());
+    public Map<Integer, SparkWatermarks> load(@Nonnull String key) throws Exception {
+      Option<BlockResult> blockResultOption =
+          SparkEnv.get().blockManager().getRemote(WATERMARKS_BLOCK_ID);
+      if (blockResultOption.isDefined()) {
+        return (Map<Integer, SparkWatermarks>) blockResultOption.get().data().next();
+      } else {
+        return Maps.newHashMap();
+      }
     }
   }
 }
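
Taken together, the flow the patch sets up is: the driver queues per-source watermarks with add(), the WatermarksListener calls advance() when a micro-batch completes (writing the merged map to the BlockManager and to driverWatermarks), and readers call get(batchIntervalMillis), which answers from driverWatermarks on the driver and from the LoadingCache-backed BlockManager lookup on executors. Below is a rough usage sketch based only on the calls visible in this diff; it assumes a live SparkEnv (advance() and remote get() resolve the BlockManager from it) and uses a made-up batch interval and timestamps.

    import org.apache.beam.runners.spark.util.GlobalWatermarkHolder;
    import org.apache.beam.runners.spark.util.GlobalWatermarkHolder.SparkWatermarks;
    import org.joda.time.Instant;

    public class WatermarkFlowSketch {
      public static void main(String[] args) {
        long batchIntervalMillis = 1000L; // assumed batch interval
        Instant now = Instant.now();

        // Driver side: queue the watermarks observed for source 1 in this micro-batch.
        GlobalWatermarkHolder.add(
            1, new SparkWatermarks(now, now.plus(batchIntervalMillis), now));

        // Driver side: the WatermarksListener does this from onBatchCompleted;
        // it merges the queued values and stores them in the BlockManager.
        GlobalWatermarkHolder.advance();

        // Reader side: on executors this goes through the LoadingCache and the
        // BlockManager block; on the driver it is served from driverWatermarks.
        SparkWatermarks watermarks = GlobalWatermarkHolder.get(batchIntervalMillis).get(1);
        System.out.println("high watermark for source 1: " + watermarks.getHighWatermark());
      }
    }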

http://git-wip-us.apache.org/repos/asf/beam/blob/20820fa5/runners/spark/src/test/java/org/apache/beam/runners/spark/GlobalWatermarkHolderTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/GlobalWatermarkHolderTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/GlobalWatermarkHolderTest.java
index 47a6e3f..1708123 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/GlobalWatermarkHolderTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/GlobalWatermarkHolderTest.java
@@ -65,17 +65,17 @@ public class GlobalWatermarkHolderTest {
             instant.plus(Duration.millis(5)),
             instant.plus(Duration.millis(5)),
             instant));
-    GlobalWatermarkHolder.advance(jsc);
+    GlobalWatermarkHolder.advance();
     // low < high.
     GlobalWatermarkHolder.add(1,
         new SparkWatermarks(
             instant.plus(Duration.millis(10)),
             instant.plus(Duration.millis(15)),
             instant.plus(Duration.millis(100))));
-    GlobalWatermarkHolder.advance(jsc);
+    GlobalWatermarkHolder.advance();
 
     // assert watermarks in Broadcast.
-    SparkWatermarks currentWatermarks = GlobalWatermarkHolder.get().getValue().get(1);
+    SparkWatermarks currentWatermarks = GlobalWatermarkHolder.get(0L).get(1);
     assertThat(currentWatermarks.getLowWatermark(), equalTo(instant.plus(Duration.millis(10))));
     assertThat(currentWatermarks.getHighWatermark(), equalTo(instant.plus(Duration.millis(15))));
     assertThat(currentWatermarks.getSynchronizedProcessingTime(),
@@ -93,7 +93,7 @@ public class GlobalWatermarkHolderTest {
             instant.plus(Duration.millis(25)),
             instant.plus(Duration.millis(20)),
             instant.plus(Duration.millis(200))));
-    GlobalWatermarkHolder.advance(jsc);
+    GlobalWatermarkHolder.advance();
   }
 
   @Test
@@ -106,7 +106,7 @@ public class GlobalWatermarkHolderTest {
             instant.plus(Duration.millis(5)),
             instant.plus(Duration.millis(10)),
             instant));
-    GlobalWatermarkHolder.advance(jsc);
+    GlobalWatermarkHolder.advance();
 
     thrown.expect(IllegalStateException.class);
     thrown.expectMessage("Synchronized processing time must advance.");
@@ -117,7 +117,7 @@ public class GlobalWatermarkHolderTest {
             instant.plus(Duration.millis(5)),
             instant.plus(Duration.millis(10)),
             instant));
-    GlobalWatermarkHolder.advance(jsc);
+    GlobalWatermarkHolder.advance();
   }
 
   @Test
@@ -136,15 +136,15 @@ public class GlobalWatermarkHolderTest {
             instant.plus(Duration.millis(6)),
             instant));
 
-    GlobalWatermarkHolder.advance(jsc);
+    GlobalWatermarkHolder.advance();
 
     // assert watermarks for source 1.
-    SparkWatermarks watermarksForSource1 = GlobalWatermarkHolder.get().getValue().get(1);
+    SparkWatermarks watermarksForSource1 = GlobalWatermarkHolder.get(0L).get(1);
     assertThat(watermarksForSource1.getLowWatermark(), equalTo(instant.plus(Duration.millis(5))));
     assertThat(watermarksForSource1.getHighWatermark(), equalTo(instant.plus(Duration.millis(10))));
 
     // assert watermarks for source 2.
-    SparkWatermarks watermarksForSource2 = GlobalWatermarkHolder.get().getValue().get(2);
+    SparkWatermarks watermarksForSource2 = GlobalWatermarkHolder.get(0L).get(2);
     assertThat(watermarksForSource2.getLowWatermark(), equalTo(instant.plus(Duration.millis(3))));
     assertThat(watermarksForSource2.getHighWatermark(), equalTo(instant.plus(Duration.millis(6))));
   }