You are viewing a plain-text version of this content. (The link to the canonical version was lost in the plain-text conversion.)
Posted to commits@beam.apache.org by lc...@apache.org on 2022/04/27 18:28:40 UTC

[beam] branch master updated: [BEAM-13015, BEAM-14184] Address unbounded number of messages being written to DirectStreamObserver before isReady is checked (#17358)

This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new adfa113a640 [BEAM-13015, BEAM-14184] Address unbounded number of messages being written to DirectStreamObserver before isReady is checked (#17358)
adfa113a640 is described below

commit adfa113a6402b76571d746f8357879bd66fff4d7
Author: Luke Cwik <lc...@google.com>
AuthorDate: Wed Apr 27 11:28:30 2022 -0700

    [BEAM-13015, BEAM-14184] Address unbounded number of messages being written to DirectStreamObserver before isReady is checked (#17358)
    
    * [BEAM-13015] Address unbounded number of messages being written to DirectStreamObserver before isReady is checked
    
    * fixup! Address PR comments
    
    * fixup! Address PR comments
---
 .../beam/sdk/fn/stream/DirectStreamObserver.java   | 82 ++++++++++++----------
 .../sdk/fn/stream/DirectStreamObserverTest.java    | 53 +++++++++++++-
 2 files changed, 98 insertions(+), 37 deletions(-)

diff --git a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DirectStreamObserver.java b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DirectStreamObserver.java
index 934b286c2d0..42b816ed67d 100644
--- a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DirectStreamObserver.java
+++ b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DirectStreamObserver.java
@@ -20,7 +20,6 @@ package org.apache.beam.sdk.fn.stream;
 import java.util.concurrent.Phaser;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
-import java.util.concurrent.atomic.AtomicInteger;
 import javax.annotation.concurrent.ThreadSafe;
 import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.stub.CallStreamObserver;
 import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.stub.StreamObserver;
@@ -43,9 +42,17 @@ public final class DirectStreamObserver<T> implements StreamObserver<T> {
 
   private final Phaser phaser;
   private final CallStreamObserver<T> outboundObserver;
+
+  /**
+   * Controls the number of messages that will be sent before isReady is invoked for the following
+   * message. For example, maxMessagesBeforeCheck = 0, would mean to check isReady for each message
+   * while maxMessagesBeforeCheck = 10, would mean that you are willing to send 10 messages and then
+   * check isReady before the 11th message is sent.
+   */
   private final int maxMessagesBeforeCheck;
 
-  private AtomicInteger numMessages = new AtomicInteger();
+  private final Object lock = new Object();
+  private int numMessages = -1;
 
   public DirectStreamObserver(Phaser phaser, CallStreamObserver<T> outboundObserver) {
     this(phaser, outboundObserver, DEFAULT_MAX_MESSAGES_BEFORE_CHECK);
@@ -60,55 +67,58 @@ public final class DirectStreamObserver<T> implements StreamObserver<T> {
 
   @Override
   public void onNext(T value) {
-    if (maxMessagesBeforeCheck <= 1
-        || numMessages.incrementAndGet() % maxMessagesBeforeCheck == 0) {
-      int waitTime = 1;
-      int totalTimeWaited = 0;
-      int phase = phaser.getPhase();
-      while (!outboundObserver.isReady()) {
-        try {
-          phaser.awaitAdvanceInterruptibly(phase, waitTime, TimeUnit.SECONDS);
-        } catch (TimeoutException e) {
-          totalTimeWaited += waitTime;
-          waitTime = waitTime * 2;
-        } catch (InterruptedException e) {
-          Thread.currentThread().interrupt();
-          throw new RuntimeException(e);
+    synchronized (lock) {
+      if (++numMessages >= maxMessagesBeforeCheck) {
+        numMessages = 0;
+        int waitTime = 1;
+        int totalTimeWaited = 0;
+        int phase = phaser.getPhase();
+        // Record the initial phase in case we are in the inbound gRPC thread where the phase won't
+        // advance.
+        int initialPhase = phase;
+        while (!outboundObserver.isReady()) {
+          try {
+            phase = phaser.awaitAdvanceInterruptibly(phase, waitTime, TimeUnit.SECONDS);
+          } catch (TimeoutException e) {
+            totalTimeWaited += waitTime;
+            waitTime = waitTime * 2;
+          } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            throw new RuntimeException(e);
+          }
         }
-      }
-      if (totalTimeWaited > 0) {
-        // If the phase didn't change, this means that the installed onReady callback had not
-        // been invoked.
-        if (phase == phaser.getPhase()) {
-          LOG.info(
-              "Output channel stalled for {}s, outbound thread {}. See: "
-                  + "https://issues.apache.org/jira/browse/BEAM-4280 for the history for "
-                  + "this issue.",
-              totalTimeWaited,
-              Thread.currentThread().getName());
-        } else {
-          LOG.debug(
-              "Output channel stalled for {}s, outbound thread {}.",
-              totalTimeWaited,
-              Thread.currentThread().getName());
+        if (totalTimeWaited > 0) {
+          // If the phase didn't change, this means that the installed onReady callback had not
+          // been invoked.
+          if (initialPhase == phase) {
+            LOG.info(
+                "Output channel stalled for {}s, outbound thread {}. See: "
+                    + "https://issues.apache.org/jira/browse/BEAM-4280 for the history for "
+                    + "this issue.",
+                totalTimeWaited,
+                Thread.currentThread().getName());
+          } else {
+            LOG.debug(
+                "Output channel stalled for {}s, outbound thread {}.",
+                totalTimeWaited,
+                Thread.currentThread().getName());
+          }
         }
       }
-    }
-    synchronized (outboundObserver) {
       outboundObserver.onNext(value);
     }
   }
 
   @Override
   public void onError(Throwable t) {
-    synchronized (outboundObserver) {
+    synchronized (lock) {
       outboundObserver.onError(t);
     }
   }
 
   @Override
   public void onCompleted() {
-    synchronized (outboundObserver) {
+    synchronized (lock) {
       outboundObserver.onCompleted();
     }
   }
diff --git a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DirectStreamObserverTest.java b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DirectStreamObserverTest.java
index 2edc93c9c58..6043277216c 100644
--- a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DirectStreamObserverTest.java
+++ b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DirectStreamObserverTest.java
@@ -17,20 +17,25 @@
  */
 package org.apache.beam.sdk.fn.stream;
 
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.collection.IsCollectionWithSize.hasSize;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.List;
 import java.util.concurrent.Callable;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.beam.sdk.fn.test.TestExecutors;
 import org.apache.beam.sdk.fn.test.TestExecutors.TestExecutorService;
 import org.apache.beam.sdk.fn.test.TestStreams;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles;
 import org.junit.Rule;
@@ -77,7 +82,7 @@ public class DirectStreamObserverTest {
     executor.invokeAll(tasks);
     streamObserver.onCompleted();
 
-    // Check that order was maintained.
+    // Check that order was maintained per writer.
     int[] prefixesIndex = new int[prefixes.size()];
     assertEquals(50, onNextValues.size());
     for (String onNextValue : onNextValues) {
@@ -168,4 +173,50 @@ public class DirectStreamObserverTest {
     }
     streamObserver.onCompleted();
   }
+
+  @Test
+  public void testMessageCheckInterval() throws Exception {
+    final AtomicInteger index = new AtomicInteger();
+    ArrayListMultimap<Integer, String> values = ArrayListMultimap.create();
+    final DirectStreamObserver<String> streamObserver =
+        new DirectStreamObserver<>(
+            new AdvancingPhaser(1),
+            TestStreams.withOnNext((String t) -> assertTrue(values.put(index.get(), t)))
+                .withIsReady(
+                    () -> {
+                      index.incrementAndGet();
+                      return true;
+                    })
+                .build(),
+            10);
+
+    List<String> prefixes = ImmutableList.of("0", "1", "2", "3", "4");
+    List<Future<String>> results = new ArrayList<>();
+    for (final String prefix : prefixes) {
+      results.add(
+          executor.submit(
+              () -> {
+                for (int i = 0; i < 10; i++) {
+                  streamObserver.onNext(prefix + i);
+                }
+                return prefix;
+              }));
+    }
+    for (Future<?> result : results) {
+      result.get();
+    }
+    assertEquals(50, values.size());
+    for (Collection<String> valuesPerMessageCheck : values.asMap().values()) {
+      assertThat(valuesPerMessageCheck, hasSize(10));
+    }
+
+    // Check that order was maintained per writer.
+    int[] prefixesIndex = new int[prefixes.size()];
+    for (String onNextValue : values.values()) {
+      int prefix = Integer.parseInt(onNextValue.substring(0, 1));
+      int suffix = Integer.parseInt(onNextValue.substring(1, 2));
+      assertEquals(prefixesIndex[prefix], suffix);
+      prefixesIndex[prefix] += 1;
+    }
+  }
 }