You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by bh...@apache.org on 2022/04/04 21:03:26 UTC

[beam] branch master updated: [BEAM-14157] GrpcWindmillServer: Use stream specific boolean to do client closed check (#17191)

This is an automated email from the ASF dual-hosted git repository.

bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 747e94b62d2 [BEAM-14157] GrpcWindmillServer: Use stream specific boolean to do client closed check (#17191)
747e94b62d2 is described below

commit 747e94b62d215e3456eea8ed5b7c68f9cc2c9242
Author: Arun Pandian <ar...@gmail.com>
AuthorDate: Mon Apr 4 14:03:18 2022 -0700

    [BEAM-14157] GrpcWindmillServer: Use stream specific boolean to do client closed check (#17191)
    
    * [BEAM-14157] GrpcWindmillServer: Use stream specific boolean to do client closed check
    
    This is a follow-up to #17162. An AbstractWindmillStream can have more than one grpc stream during its lifetime; new streams can be created after the client is closed in order to send pending requests. So it is not correct to check `if(clientClosed)` in `send()`; this PR adds a new grpc-stream-level boolean to do the closed check in `send()`.
    
    * [BEAM-14157] Add unit test testing CommitWorkStream retries around stream closing
    
    * [BEAM-14157] review comments
    
    * [BEAM-14157] review comments
    
    * [BEAM-14157] review comments
    
    * [BEAM-14157] fix test
    
    * [BEAM-14157] fix test
    
    Co-authored-by: Arun Pandian <pa...@google.com>
---
 .../worker/windmill/GrpcWindmillServer.java        |  11 +-
 .../worker/windmill/GrpcWindmillServerTest.java    | 272 +++++++++++++++------
 2 files changed, 209 insertions(+), 74 deletions(-)

diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java
index 6631ffa13e8..e914ef160de 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java
@@ -632,6 +632,8 @@ public class GrpcWindmillServer extends WindmillServerStub {
     // The following should be protected by synchronizing on this, except for
     // the atomics which may be read atomically for status pages.
     private StreamObserver<RequestT> requestObserver;
+    // Indicates if the current stream in requestObserver is closed by calling the close() method
+    private final AtomicBoolean streamClosed = new AtomicBoolean();
     private final AtomicLong startTimeMs = new AtomicLong();
     private final AtomicLong lastSendTimeMs = new AtomicLong();
     private final AtomicLong lastResponseTimeMs = new AtomicLong();
@@ -663,7 +665,7 @@ public class GrpcWindmillServer extends WindmillServerStub {
     protected final void send(RequestT request) {
       lastSendTimeMs.set(Instant.now().getMillis());
       synchronized (this) {
-        if (clientClosed.get()) {
+        if (streamClosed.get()) {
           throw new IllegalStateException("Send called on a client closed stream.");
         }
         requestObserver.onNext(request);
@@ -681,6 +683,7 @@ public class GrpcWindmillServer extends WindmillServerStub {
             startTimeMs.set(Instant.now().getMillis());
             lastResponseTimeMs.set(0);
             requestObserver = streamObserverFactory.from(clientFactory, new ResponseObserver());
+            streamClosed.set(false);
             onNewStream();
             if (clientClosed.get()) {
               close();
@@ -742,10 +745,11 @@ public class GrpcWindmillServer extends WindmillServerStub {
         writer.format(", %dms backoff remaining", sleepLeft);
       }
       writer.format(
-          ", current stream is %dms old, last send %dms, last response %dms",
+          ", current stream is %dms old, last send %dms, last response %dms, closed: %s",
           debugDuration(nowMs, startTimeMs.get()),
           debugDuration(nowMs, lastSendTimeMs.get()),
-          debugDuration(nowMs, lastResponseTimeMs.get()));
+          debugDuration(nowMs, lastResponseTimeMs.get()),
+          streamClosed.get());
     }
 
     // Don't require synchronization on stream, see the appendSummaryHtml comment.
@@ -838,6 +842,7 @@ public class GrpcWindmillServer extends WindmillServerStub {
       // Synchronization of close and onCompleted necessary for correct retry logic in onNewStream.
       clientClosed.set(true);
       requestObserver.onCompleted();
+      streamClosed.set(true);
     }
 
     @Override
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServerTest.java
index c5d7b0c0f32..64a31f36831 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServerTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServerTest.java
@@ -19,6 +19,7 @@ package org.apache.beam.runners.dataflow.worker.windmill;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 import java.io.InputStream;
 import java.io.SequenceInputStream;
@@ -28,11 +29,14 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
@@ -491,11 +495,96 @@ public class GrpcWindmillServerTest {
         .build();
   }
 
+  // This server receives WorkItemCommitRequests, and verifies they are equal to the provided
+  // commitRequest.
+  private StreamObserver<StreamingCommitWorkRequest> getTestCommitStreamObserver(
+      StreamObserver<StreamingCommitResponse> responseObserver,
+      Map<Long, WorkItemCommitRequest> commitRequests) {
+    return new StreamObserver<StreamingCommitWorkRequest>() {
+      boolean sawHeader = false;
+      InputStream buffer = null;
+      long remainingBytes = 0;
+      ResponseErrorInjector injector = new ResponseErrorInjector(responseObserver);
+
+      @Override
+      public void onNext(StreamingCommitWorkRequest request) {
+        maybeInjectError(responseObserver);
+
+        if (!sawHeader) {
+          errorCollector.checkThat(
+              request.getHeader(),
+              Matchers.equalTo(
+                  JobHeader.newBuilder()
+                      .setJobId("job")
+                      .setProjectId("project")
+                      .setWorkerId("worker")
+                      .build()));
+          sawHeader = true;
+          LOG.info("Received header");
+        } else {
+          boolean first = true;
+          LOG.info("Received request with {} chunks", request.getCommitChunkCount());
+          for (StreamingCommitRequestChunk chunk : request.getCommitChunkList()) {
+            assertTrue(chunk.getSerializedWorkItemCommit().size() <= STREAM_CHUNK_SIZE);
+            if (first || chunk.hasComputationId()) {
+              errorCollector.checkThat(chunk.getComputationId(), Matchers.equalTo("computation"));
+            }
+
+            if (remainingBytes != 0) {
+              errorCollector.checkThat(buffer, Matchers.notNullValue());
+              errorCollector.checkThat(
+                  remainingBytes,
+                  Matchers.is(
+                      chunk.getSerializedWorkItemCommit().size()
+                          + chunk.getRemainingBytesForWorkItem()));
+              buffer =
+                  new SequenceInputStream(buffer, chunk.getSerializedWorkItemCommit().newInput());
+            } else {
+              errorCollector.checkThat(buffer, Matchers.nullValue());
+              buffer = chunk.getSerializedWorkItemCommit().newInput();
+            }
+            remainingBytes = chunk.getRemainingBytesForWorkItem();
+            if (remainingBytes == 0) {
+              try {
+                WorkItemCommitRequest received = WorkItemCommitRequest.parseFrom(buffer);
+                errorCollector.checkThat(
+                    received, Matchers.equalTo(commitRequests.get(received.getWorkToken())));
+                try {
+                  responseObserver.onNext(
+                      StreamingCommitResponse.newBuilder()
+                          .addRequestId(chunk.getRequestId())
+                          .build());
+                } catch (IllegalStateException e) {
+                  // Stream is closed.
+                }
+              } catch (Exception e) {
+                errorCollector.addError(e);
+              }
+              buffer = null;
+            } else {
+              errorCollector.checkThat(first, Matchers.is(true));
+            }
+            first = false;
+          }
+        }
+      }
+
+      @Override
+      public void onError(Throwable throwable) {}
+
+      @Override
+      public void onCompleted() {
+        injector.cancel();
+        responseObserver.onCompleted();
+      }
+    };
+  }
+
   @Test
   public void testStreamingCommit() throws Exception {
     List<WorkItemCommitRequest> commitRequestList = new ArrayList<>();
     List<CountDownLatch> latches = new ArrayList<>();
-    Map<Long, WorkItemCommitRequest> commitRequests = new HashMap<>();
+    Map<Long, WorkItemCommitRequest> commitRequests = new ConcurrentHashMap<>();
     for (int i = 0; i < 500; ++i) {
       // Build some requests of varying size with a few big ones.
       WorkItemCommitRequest request = makeCommitRequest(i, i * (i < 480 ? 8 : 128));
@@ -505,92 +594,94 @@ public class GrpcWindmillServerTest {
     }
     Collections.shuffle(commitRequestList);
 
-    // This server receives WorkItemCommitRequests, and verifies they are equal to the above
-    // commitRequest.
     serviceRegistry.addService(
         new CloudWindmillServiceV1Alpha1ImplBase() {
           @Override
           public StreamObserver<StreamingCommitWorkRequest> commitWorkStream(
               StreamObserver<StreamingCommitResponse> responseObserver) {
-            return new StreamObserver<StreamingCommitWorkRequest>() {
-              boolean sawHeader = false;
-              InputStream buffer = null;
-              long remainingBytes = 0;
-              ResponseErrorInjector injector = new ResponseErrorInjector(responseObserver);
+            return getTestCommitStreamObserver(responseObserver, commitRequests);
+          }
+        });
 
-              @Override
-              public void onNext(StreamingCommitWorkRequest request) {
-                maybeInjectError(responseObserver);
+    // Make the commit requests, waiting for each of them to be verified and acknowledged.
+    CommitWorkStream stream = client.commitWorkStream();
+    for (int i = 0; i < commitRequestList.size(); ) {
+      final CountDownLatch latch = latches.get(i);
+      if (stream.commitWorkItem(
+          "computation",
+          commitRequestList.get(i),
+          (CommitStatus status) -> {
+            assertEquals(status, CommitStatus.OK);
+            latch.countDown();
+          })) {
+        i++;
+      } else {
+        stream.flush();
+      }
+    }
+    stream.flush();
+    stream.close();
+    for (CountDownLatch latch : latches) {
+      assertTrue(latch.await(1, TimeUnit.MINUTES));
+    }
+    assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS));
+  }
 
-                if (!sawHeader) {
-                  errorCollector.checkThat(
-                      request.getHeader(),
-                      Matchers.equalTo(
-                          JobHeader.newBuilder()
-                              .setJobId("job")
-                              .setProjectId("project")
-                              .setWorkerId("worker")
-                              .build()));
-                  sawHeader = true;
-                  LOG.info("Received header");
-                } else {
-                  boolean first = true;
-                  LOG.info("Received request with {} chunks", request.getCommitChunkCount());
-                  for (StreamingCommitRequestChunk chunk : request.getCommitChunkList()) {
-                    assertTrue(chunk.getSerializedWorkItemCommit().size() <= STREAM_CHUNK_SIZE);
-                    if (first || chunk.hasComputationId()) {
-                      errorCollector.checkThat(
-                          chunk.getComputationId(), Matchers.equalTo("computation"));
-                    }
+  @Test
+  // Tests stream retries on server errors before and after `close()`
+  public void testStreamingCommitClosedStream() throws Exception {
+    List<WorkItemCommitRequest> commitRequestList = new ArrayList<>();
+    List<CountDownLatch> latches = new ArrayList<>();
+    Map<Long, WorkItemCommitRequest> commitRequests = new ConcurrentHashMap<>();
+    AtomicBoolean shouldServerReturnError = new AtomicBoolean(true);
+    AtomicBoolean isClientClosed = new AtomicBoolean(false);
+    AtomicInteger errorsBeforeClose = new AtomicInteger();
+    AtomicInteger errorsAfterClose = new AtomicInteger();
+    for (int i = 0; i < 500; ++i) {
+      // Build some requests of varying size with a few big ones.
+      WorkItemCommitRequest request = makeCommitRequest(i, i * (i < 480 ? 8 : 128));
+      commitRequestList.add(request);
+      commitRequests.put((long) i, request);
+      latches.add(new CountDownLatch(1));
+    }
+    Collections.shuffle(commitRequestList);
 
-                    if (remainingBytes != 0) {
-                      errorCollector.checkThat(buffer, Matchers.notNullValue());
-                      errorCollector.checkThat(
-                          remainingBytes,
-                          Matchers.is(
-                              chunk.getSerializedWorkItemCommit().size()
-                                  + chunk.getRemainingBytesForWorkItem()));
-                      buffer =
-                          new SequenceInputStream(
-                              buffer, chunk.getSerializedWorkItemCommit().newInput());
-                    } else {
-                      errorCollector.checkThat(buffer, Matchers.nullValue());
-                      buffer = chunk.getSerializedWorkItemCommit().newInput();
-                    }
-                    remainingBytes = chunk.getRemainingBytesForWorkItem();
-                    if (remainingBytes == 0) {
-                      try {
-                        WorkItemCommitRequest received = WorkItemCommitRequest.parseFrom(buffer);
-                        errorCollector.checkThat(
-                            received,
-                            Matchers.equalTo(commitRequests.get(received.getWorkToken())));
-                        try {
-                          responseObserver.onNext(
-                              StreamingCommitResponse.newBuilder()
-                                  .addRequestId(chunk.getRequestId())
-                                  .build());
-                        } catch (IllegalStateException e) {
-                          // Stream is closed.
-                        }
-                      } catch (Exception e) {
-                        errorCollector.addError(e);
-                      }
-                      buffer = null;
+    // This server returns errors if shouldServerReturnError is true, else returns valid responses.
+    serviceRegistry.addService(
+        new CloudWindmillServiceV1Alpha1ImplBase() {
+          @Override
+          public StreamObserver<StreamingCommitWorkRequest> commitWorkStream(
+              StreamObserver<StreamingCommitResponse> responseObserver) {
+            StreamObserver<StreamingCommitWorkRequest> testCommitStreamObserver =
+                getTestCommitStreamObserver(responseObserver, commitRequests);
+            return new StreamObserver<StreamingCommitWorkRequest>() {
+              @Override
+              public void onNext(StreamingCommitWorkRequest request) {
+                if (shouldServerReturnError.get()) {
+                  try {
+                    responseObserver.onError(
+                        new RuntimeException("shouldServerReturnError = true"));
+                    if (isClientClosed.get()) {
+                      errorsAfterClose.incrementAndGet();
                     } else {
-                      errorCollector.checkThat(first, Matchers.is(true));
+                      errorsBeforeClose.incrementAndGet();
                     }
-                    first = false;
+                  } catch (IllegalStateException e) {
+                    // The stream is already closed.
                   }
+                } else {
+                  testCommitStreamObserver.onNext(request);
                 }
               }
 
               @Override
-              public void onError(Throwable throwable) {}
+              public void onError(Throwable throwable) {
+                testCommitStreamObserver.onError(throwable);
+              }
 
               @Override
               public void onCompleted() {
-                injector.cancel();
-                responseObserver.onCompleted();
+                testCommitStreamObserver.onCompleted();
               }
             };
           }
@@ -613,11 +704,50 @@ public class GrpcWindmillServerTest {
       }
     }
     stream.flush();
-    for (CountDownLatch latch : latches) {
-      assertTrue(latch.await(1, TimeUnit.MINUTES));
+
+    long deadline = System.currentTimeMillis() + 60_000; // 1 min
+    while (true) {
+      Thread.sleep(100);
+      int tmpErrorsBeforeClose = errorsBeforeClose.get();
+      // wait for at least 1 error before close
+      if (tmpErrorsBeforeClose > 0) {
+        break;
+      }
+      if (System.currentTimeMillis() > deadline) {
+        // Control should not reach here if the test is working as expected
+        fail(
+            String.format(
+                "Expected errors not sent by server errorsBeforeClose: %s"
+                    + " \n Should not reach here if the test is working as expected.",
+                tmpErrorsBeforeClose));
+      }
     }
 
     stream.close();
+    isClientClosed.set(true);
+
+    deadline = System.currentTimeMillis() + 60_000; // 1 min
+    while (true) {
+      Thread.sleep(100);
+      int tmpErrorsAfterClose = errorsAfterClose.get();
+      // wait for at least 1 error after close
+      if (tmpErrorsAfterClose > 0) {
+        break;
+      }
+      if (System.currentTimeMillis() > deadline) {
+        // Control should not reach here if the test is working as expected
+        fail(
+            String.format(
+                "Expected errors not sent by server errorsAfterClose: %s"
+                    + " \n Should not reach here if the test is working as expected.",
+                tmpErrorsAfterClose));
+      }
+    }
+
+    shouldServerReturnError.set(false);
+    for (CountDownLatch latch : latches) {
+      assertTrue(latch.await(1, TimeUnit.MINUTES));
+    }
     assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS));
   }