You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2022/04/04 21:32:50 UTC

[beam] branch master updated: [BEAM-13519] Solve race issues when the server responds with an error before the GrpcStateClient finishes being constructed. (#17240)

This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 0262ee53c60 [BEAM-13519] Solve race issues when the server responds with an error before the GrpcStateClient finishes being constructed. (#17240)
0262ee53c60 is described below

commit 0262ee53c6018d929a8a40fdf66735cc7e934951
Author: Luke Cwik <lc...@google.com>
AuthorDate: Mon Apr 4 14:32:41 2022 -0700

    [BEAM-13519] Solve race issues when the server responds with an error before the GrpcStateClient finishes being constructed. (#17240)
    
    * [BEAM-13519] Solve race issues when the server responds with an error before the GrpcStateClient finishes being constructed.
    
    The issue was that the InboundObserver can be invoked before outboundObserverFactory#outboundObserverFor returns, meaning that
    the server can be waiting on cache.remove while cache.computeIfAbsent is still being invoked at the same time.
    
    Another issue was that the outstandingRequests map could be updated with another request within GrpcStateClient during closeAndCleanUp, meaning that the CompletableFuture would never be completed exceptionally.
    
    Passes 1000 times locally now without getting stuck or failing.
---
 .../harness/state/BeamFnStateGrpcClientCache.java  | 105 ++++++++++++++-------
 .../state/BeamFnStateGrpcClientCacheTest.java      |  83 ++++++++--------
 2 files changed, 117 insertions(+), 71 deletions(-)

diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCache.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCache.java
index d028ef61d45..e272a98902a 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCache.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCache.java
@@ -18,10 +18,9 @@
 package org.apache.beam.fn.harness.state;
 
 import java.io.IOException;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentMap;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnStateGrpc;
@@ -45,7 +44,7 @@ import org.slf4j.LoggerFactory;
 public class BeamFnStateGrpcClientCache {
   private static final Logger LOG = LoggerFactory.getLogger(BeamFnStateGrpcClientCache.class);
 
-  private final ConcurrentMap<ApiServiceDescriptor, BeamFnStateClient> cache;
+  private final Map<ApiServiceDescriptor, BeamFnStateClient> cache;
   private final ManagedChannelFactory channelFactory;
   private final OutboundObserverFactory outboundObserverFactory;
   private final IdGenerator idGenerator;
@@ -59,7 +58,7 @@ public class BeamFnStateGrpcClientCache {
     // This showed a 1-2% improvement in the ProcessBundleBenchmark#testState* benchmarks.
     this.channelFactory = channelFactory.withDirectExecutor();
     this.outboundObserverFactory = outboundObserverFactory;
-    this.cache = new ConcurrentHashMap<>();
+    this.cache = new HashMap<>();
   }
 
   /**
@@ -67,30 +66,53 @@ public class BeamFnStateGrpcClientCache {
    * {@link ApiServiceDescriptor} currently has a {@link BeamFnStateClient} bound to the same
    * channel.
    */
-  public BeamFnStateClient forApiServiceDescriptor(ApiServiceDescriptor apiServiceDescriptor)
-      throws IOException {
-    return cache.computeIfAbsent(apiServiceDescriptor, this::createBeamFnStateClient);
-  }
-
-  private BeamFnStateClient createBeamFnStateClient(ApiServiceDescriptor apiServiceDescriptor) {
-    return new GrpcStateClient(apiServiceDescriptor);
+  public synchronized BeamFnStateClient forApiServiceDescriptor(
+      ApiServiceDescriptor apiServiceDescriptor) throws IOException {
+    // We are specifically synchronized so that we only create one GrpcStateClient at a time,
+    // preventing a race where multiple GrpcStateClient objects might be constructed at the same
+    // time for the same ApiServiceDescriptor.
+    BeamFnStateClient rval;
+    synchronized (cache) {
+      rval = cache.get(apiServiceDescriptor);
+    }
+    if (rval == null) {
+      // We can't be synchronized on cache while constructing the GrpcStateClient since if the
+      // connection fails, onError may be invoked from the gRPC thread which will invoke
+      // closeAndCleanUp that clears the cache.
+      rval = new GrpcStateClient(apiServiceDescriptor);
+      synchronized (cache) {
+        cache.put(apiServiceDescriptor, rval);
+      }
+    }
+    return rval;
   }
 
   /** A {@link BeamFnStateClient} for a given {@link ApiServiceDescriptor}. */
   private class GrpcStateClient implements BeamFnStateClient {
+    private final Object lock = new Object();
     private final ApiServiceDescriptor apiServiceDescriptor;
-    private final ConcurrentMap<String, CompletableFuture<StateResponse>> outstandingRequests;
+    private final Map<String, CompletableFuture<StateResponse>> outstandingRequests;
     private final StreamObserver<StateRequest> outboundObserver;
     private final ManagedChannel channel;
-    private volatile RuntimeException closed;
+    private RuntimeException closed;
+    private boolean errorDuringConstruction;
 
     private GrpcStateClient(ApiServiceDescriptor apiServiceDescriptor) {
       this.apiServiceDescriptor = apiServiceDescriptor;
-      this.outstandingRequests = new ConcurrentHashMap<>();
+      this.outstandingRequests = new HashMap<>();
       this.channel = channelFactory.forDescriptor(apiServiceDescriptor);
+      this.errorDuringConstruction = false;
       this.outboundObserver =
           outboundObserverFactory.outboundObserverFor(
               BeamFnStateGrpc.newStub(channel)::state, new InboundObserver());
+      // The InboundObserver is passed to outboundObserverFor before this constructor completes,
+      // so it may invoke closeAndCleanUp while outboundObserver is still unassigned. In that case
+      // closeAndCleanUp sets errorDuringConstruction and we invoke onCompleted here instead.
+      synchronized (lock) {
+        if (errorDuringConstruction) {
+          outboundObserver.onCompleted();
+        }
+      }
     }
 
     @Override
@@ -98,7 +120,13 @@ public class BeamFnStateGrpcClientCache {
       requestBuilder.setId(idGenerator.getId());
       StateRequest request = requestBuilder.build();
       CompletableFuture<StateResponse> response = new CompletableFuture<>();
-      outstandingRequests.put(request.getId(), response);
+      synchronized (lock) {
+        if (closed != null) {
+          response.completeExceptionally(closed);
+          return response;
+        }
+        outstandingRequests.put(request.getId(), response);
+      }
 
       // If the server closes, gRPC will throw an error if onNext is called.
       LOG.debug("Sending StateRequest {}", request);
@@ -106,27 +134,33 @@ public class BeamFnStateGrpcClientCache {
       return response;
     }
 
-    private synchronized void closeAndCleanUp(RuntimeException cause) {
-      if (closed != null) {
-        return;
-      }
-      cache.remove(apiServiceDescriptor);
-      closed = cause;
-
-      // Make a copy of the map to make the view of the outstanding requests consistent.
-      Map<String, CompletableFuture<StateResponse>> outstandingRequestsCopy =
-          new ConcurrentHashMap<>(outstandingRequests);
+    private void closeAndCleanUp(RuntimeException cause) {
+      synchronized (lock) {
+        if (closed != null) {
+          return;
+        }
+        closed = cause;
 
-      if (outstandingRequestsCopy.isEmpty()) {
-        outboundObserver.onCompleted();
-        return;
-      }
+        synchronized (cache) {
+          cache.remove(apiServiceDescriptor);
+        }
 
-      outstandingRequests.clear();
-      LOG.error("BeamFnState failed, clearing outstanding requests {}", outstandingRequestsCopy);
+        if (!outstandingRequests.isEmpty()) {
+          LOG.error("BeamFnState failed, clearing outstanding requests {}", outstandingRequests);
+          for (CompletableFuture<StateResponse> entry : outstandingRequests.values()) {
+            entry.completeExceptionally(cause);
+          }
+          outstandingRequests.clear();
+        }
 
-      for (CompletableFuture<StateResponse> entry : outstandingRequestsCopy.values()) {
-        entry.completeExceptionally(cause);
+        // outboundObserver may still be null here since InboundObserver may call closeAndCleanUp
+        // before the GrpcStateClient constructor has assigned it. In this case we defer invoking
+        // onCompleted to the GrpcStateClient constructor.
+        if (outboundObserver == null) {
+          errorDuringConstruction = true;
+        } else {
+          outboundObserver.onCompleted();
+        }
       }
     }
 
@@ -143,7 +177,10 @@ public class BeamFnStateGrpcClientCache {
       @Override
       public void onNext(StateResponse value) {
         LOG.debug("Received StateResponse {}", value);
-        CompletableFuture<StateResponse> responseFuture = outstandingRequests.remove(value.getId());
+        CompletableFuture<StateResponse> responseFuture;
+        synchronized (lock) {
+          responseFuture = outstandingRequests.remove(value.getId());
+        }
         if (responseFuture == null) {
           LOG.warn("Dropped unknown StateResponse {}", value);
           return;
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCacheTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCacheTest.java
index 1615a59cb9a..a729755fc12 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCacheTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCacheTest.java
@@ -28,14 +28,19 @@ import java.util.UUID;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 import java.util.concurrent.LinkedBlockingQueue;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnStateGrpc;
+import org.apache.beam.model.fnexecution.v1.BeamFnStateGrpc.BeamFnStateImplBase;
 import org.apache.beam.model.pipeline.v1.Endpoints;
 import org.apache.beam.sdk.fn.IdGenerators;
 import org.apache.beam.sdk.fn.channel.ManagedChannelFactory;
 import org.apache.beam.sdk.fn.stream.OutboundObserverFactory;
+import org.apache.beam.sdk.fn.test.TestExecutors;
+import org.apache.beam.sdk.fn.test.TestExecutors.TestExecutorService;
 import org.apache.beam.sdk.fn.test.TestStreams;
 import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.Server;
 import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.Status;
@@ -46,7 +51,7 @@ import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.stub.StreamObserver;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles;
 import org.junit.After;
 import org.junit.Before;
-import org.junit.Ignore;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -59,6 +64,8 @@ public class BeamFnStateGrpcClientCacheTest {
   private static final String TEST_ERROR = "TEST ERROR";
   private static final String SERVER_ERROR = "SERVER ERROR";
 
+  @Rule public TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool);
+
   private Endpoints.ApiServiceDescriptor apiServiceDescriptor;
   private Server testServer;
   private BeamFnStateGrpcClientCache clientCache;
@@ -103,7 +110,6 @@ public class BeamFnStateGrpcClientCacheTest {
   }
 
   @Test
-  @Ignore("(BEAM-13519) Java precommit timing out")
   public void testCachingOfClient() throws Exception {
     Endpoints.ApiServiceDescriptor otherApiServiceDescriptor =
         Endpoints.ApiServiceDescriptor.newBuilder()
@@ -112,18 +118,17 @@ public class BeamFnStateGrpcClientCacheTest {
     Server testServer2 =
         InProcessServerBuilder.forName(otherApiServiceDescriptor.getUrl())
             .addService(
-                new BeamFnStateGrpc.BeamFnStateImplBase() {
+                new BeamFnStateImplBase() {
                   @Override
                   public StreamObserver<StateRequest> state(
                       StreamObserver<StateResponse> outboundObserver) {
-                    throw new IllegalStateException("Unexpected in test.");
+                    throw new RuntimeException();
                   }
                 })
             .build();
     testServer2.start();
 
     try {
-
       assertSame(
           clientCache.forApiServiceDescriptor(apiServiceDescriptor),
           clientCache.forApiServiceDescriptor(apiServiceDescriptor));
@@ -164,25 +169,27 @@ public class BeamFnStateGrpcClientCacheTest {
   }
 
   @Test
+  // The checker erroneously flags that the CompletableFuture is not being resolved since it is the
+  // result of Executor#submit.
+  @SuppressWarnings("FutureReturnValueIgnored")
   public void testServerErrorCausesPendingAndFutureCallsToFail() throws Exception {
     BeamFnStateClient client = clientCache.forApiServiceDescriptor(apiServiceDescriptor);
 
-    CompletableFuture<StateResponse> inflight =
-        client.handle(StateRequest.newBuilder().setInstructionId(SUCCESS));
-
-    // Wait for the client to connect.
-    StreamObserver<StateResponse> outboundServerObserver = outboundServerObservers.take();
-    // Send an error from the server.
-    outboundServerObserver.onError(
-        new StatusRuntimeException(Status.INTERNAL.withDescription(SERVER_ERROR)));
-
-    try {
-      inflight.get();
-      fail("Expected unsuccessful response due to server error");
-    } catch (ExecutionException e) {
-      assertThat(e.toString(), containsString(SERVER_ERROR));
-    }
-
+    Future<CompletableFuture<StateResponse>> stateResponse =
+        executor.submit(() -> client.handle(StateRequest.newBuilder().setInstructionId(SUCCESS)));
+    Future<Void> serverResponse =
+        executor.submit(
+            () -> {
+              // Wait for the client to connect.
+              StreamObserver<StateResponse> outboundServerObserver = outboundServerObservers.take();
+              // Send an error from the server.
+              outboundServerObserver.onError(
+                  new StatusRuntimeException(Status.INTERNAL.withDescription(SERVER_ERROR)));
+              return null;
+            });
+
+    CompletableFuture<StateResponse> inflight = stateResponse.get();
+    serverResponse.get();
     try {
       inflight.get();
       fail("Expected unsuccessful response due to server error");
@@ -192,27 +199,29 @@ public class BeamFnStateGrpcClientCacheTest {
   }
 
   @Test
+  // The checker erroneously flags that the CompletableFuture is not being resolved since it is the
+  // result of Executor#submit.
+  @SuppressWarnings("FutureReturnValueIgnored")
   public void testServerCompletionCausesPendingAndFutureCallsToFail() throws Exception {
     BeamFnStateClient client = clientCache.forApiServiceDescriptor(apiServiceDescriptor);
 
-    CompletableFuture<StateResponse> inflight =
-        client.handle(StateRequest.newBuilder().setInstructionId(SUCCESS));
-
-    // Wait for the client to connect.
-    StreamObserver<StateResponse> outboundServerObserver = outboundServerObservers.take();
-    // Send that the server is done.
-    outboundServerObserver.onCompleted();
-
+    Future<CompletableFuture<StateResponse>> stateResponse =
+        executor.submit(() -> client.handle(StateRequest.newBuilder().setInstructionId(SUCCESS)));
+    Future<Void> serverResponse =
+        executor.submit(
+            () -> {
+              // Wait for the client to connect.
+              StreamObserver<StateResponse> outboundServerObserver = outboundServerObservers.take();
+              // Send that the server is done.
+              outboundServerObserver.onCompleted();
+              return null;
+            });
+
+    CompletableFuture<StateResponse> inflight = stateResponse.get();
+    serverResponse.get();
     try {
       inflight.get();
-      fail("Expected unsuccessful response due to server completion");
-    } catch (ExecutionException e) {
-      assertThat(e.toString(), containsString("Server hanged up"));
-    }
-
-    try {
-      inflight.get();
-      fail("Expected unsuccessful response due to server completion");
+      fail("Expected unsuccessful response due to server error");
     } catch (ExecutionException e) {
       assertThat(e.toString(), containsString("Server hanged up"));
     }