You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by bh...@apache.org on 2022/04/04 21:03:26 UTC
[beam] branch master updated: [BEAM-14157] GrpcWindmillServer: Use stream specific boolean to do client closed check (#17191)
This is an automated email from the ASF dual-hosted git repository.
bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 747e94b62d2 [BEAM-14157] GrpcWindmillServer: Use stream specific boolean to do client closed check (#17191)
747e94b62d2 is described below
commit 747e94b62d215e3456eea8ed5b7c68f9cc2c9242
Author: Arun Pandian <ar...@gmail.com>
AuthorDate: Mon Apr 4 14:03:18 2022 -0700
[BEAM-14157] GrpcWindmillServer: Use stream specific boolean to do client closed check (#17191)
* [BEAM-14157] GrpcWindmillServer: Use stream specific boolean to do client closed check
This is a follow-up to #17162. An AbstractWindmillStream can have more than one grpc stream during its lifetime; new streams can be created after the client is closed in order to send pending requests. So it is not correct to check `if(clientClosed)` in `send()`. This PR adds a new grpc-stream-level boolean to do the closed check in `send()`.
* [BEAM-14157] Add unit test testing CommitWorkStream retries around stream closing
* [BEAM-14157] review comments
* [BEAM-14157] review comments
* [BEAM-14157] review comments
* [BEAM-14157] fix test
* [BEAM-14157] fix test
Co-authored-by: Arun Pandian <pa...@google.com>
---
.../worker/windmill/GrpcWindmillServer.java | 11 +-
.../worker/windmill/GrpcWindmillServerTest.java | 272 +++++++++++++++------
2 files changed, 209 insertions(+), 74 deletions(-)
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java
index 6631ffa13e8..e914ef160de 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java
@@ -632,6 +632,8 @@ public class GrpcWindmillServer extends WindmillServerStub {
// The following should be protected by synchronizing on this, except for
// the atomics which may be read atomically for status pages.
private StreamObserver<RequestT> requestObserver;
+ // Indicates if the current stream in requestObserver is closed by calling close() method
+ private final AtomicBoolean streamClosed = new AtomicBoolean();
private final AtomicLong startTimeMs = new AtomicLong();
private final AtomicLong lastSendTimeMs = new AtomicLong();
private final AtomicLong lastResponseTimeMs = new AtomicLong();
@@ -663,7 +665,7 @@ public class GrpcWindmillServer extends WindmillServerStub {
protected final void send(RequestT request) {
lastSendTimeMs.set(Instant.now().getMillis());
synchronized (this) {
- if (clientClosed.get()) {
+ if (streamClosed.get()) {
throw new IllegalStateException("Send called on a client closed stream.");
}
requestObserver.onNext(request);
@@ -681,6 +683,7 @@ public class GrpcWindmillServer extends WindmillServerStub {
startTimeMs.set(Instant.now().getMillis());
lastResponseTimeMs.set(0);
requestObserver = streamObserverFactory.from(clientFactory, new ResponseObserver());
+ streamClosed.set(false);
onNewStream();
if (clientClosed.get()) {
close();
@@ -742,10 +745,11 @@ public class GrpcWindmillServer extends WindmillServerStub {
writer.format(", %dms backoff remaining", sleepLeft);
}
writer.format(
- ", current stream is %dms old, last send %dms, last response %dms",
+ ", current stream is %dms old, last send %dms, last response %dms, closed: %s",
debugDuration(nowMs, startTimeMs.get()),
debugDuration(nowMs, lastSendTimeMs.get()),
- debugDuration(nowMs, lastResponseTimeMs.get()));
+ debugDuration(nowMs, lastResponseTimeMs.get()),
+ streamClosed.get());
}
// Don't require synchronization on stream, see the appendSummaryHtml comment.
@@ -838,6 +842,7 @@ public class GrpcWindmillServer extends WindmillServerStub {
// Synchronization of close and onCompleted necessary for correct retry logic in onNewStream.
clientClosed.set(true);
requestObserver.onCompleted();
+ streamClosed.set(true);
}
@Override
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServerTest.java
index c5d7b0c0f32..64a31f36831 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServerTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServerTest.java
@@ -19,6 +19,7 @@ package org.apache.beam.runners.dataflow.worker.windmill;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
import java.io.InputStream;
import java.io.SequenceInputStream;
@@ -28,11 +29,14 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
@@ -491,11 +495,96 @@ public class GrpcWindmillServerTest {
.build();
}
+ // This server receives WorkItemCommitRequests, and verifies they are equal to the provided
+ // commitRequest.
+ private StreamObserver<StreamingCommitWorkRequest> getTestCommitStreamObserver(
+ StreamObserver<StreamingCommitResponse> responseObserver,
+ Map<Long, WorkItemCommitRequest> commitRequests) {
+ return new StreamObserver<StreamingCommitWorkRequest>() {
+ boolean sawHeader = false;
+ InputStream buffer = null;
+ long remainingBytes = 0;
+ ResponseErrorInjector injector = new ResponseErrorInjector(responseObserver);
+
+ @Override
+ public void onNext(StreamingCommitWorkRequest request) {
+ maybeInjectError(responseObserver);
+
+ if (!sawHeader) {
+ errorCollector.checkThat(
+ request.getHeader(),
+ Matchers.equalTo(
+ JobHeader.newBuilder()
+ .setJobId("job")
+ .setProjectId("project")
+ .setWorkerId("worker")
+ .build()));
+ sawHeader = true;
+ LOG.info("Received header");
+ } else {
+ boolean first = true;
+ LOG.info("Received request with {} chunks", request.getCommitChunkCount());
+ for (StreamingCommitRequestChunk chunk : request.getCommitChunkList()) {
+ assertTrue(chunk.getSerializedWorkItemCommit().size() <= STREAM_CHUNK_SIZE);
+ if (first || chunk.hasComputationId()) {
+ errorCollector.checkThat(chunk.getComputationId(), Matchers.equalTo("computation"));
+ }
+
+ if (remainingBytes != 0) {
+ errorCollector.checkThat(buffer, Matchers.notNullValue());
+ errorCollector.checkThat(
+ remainingBytes,
+ Matchers.is(
+ chunk.getSerializedWorkItemCommit().size()
+ + chunk.getRemainingBytesForWorkItem()));
+ buffer =
+ new SequenceInputStream(buffer, chunk.getSerializedWorkItemCommit().newInput());
+ } else {
+ errorCollector.checkThat(buffer, Matchers.nullValue());
+ buffer = chunk.getSerializedWorkItemCommit().newInput();
+ }
+ remainingBytes = chunk.getRemainingBytesForWorkItem();
+ if (remainingBytes == 0) {
+ try {
+ WorkItemCommitRequest received = WorkItemCommitRequest.parseFrom(buffer);
+ errorCollector.checkThat(
+ received, Matchers.equalTo(commitRequests.get(received.getWorkToken())));
+ try {
+ responseObserver.onNext(
+ StreamingCommitResponse.newBuilder()
+ .addRequestId(chunk.getRequestId())
+ .build());
+ } catch (IllegalStateException e) {
+ // Stream is closed.
+ }
+ } catch (Exception e) {
+ errorCollector.addError(e);
+ }
+ buffer = null;
+ } else {
+ errorCollector.checkThat(first, Matchers.is(true));
+ }
+ first = false;
+ }
+ }
+ }
+
+ @Override
+ public void onError(Throwable throwable) {}
+
+ @Override
+ public void onCompleted() {
+ injector.cancel();
+ responseObserver.onCompleted();
+ }
+ };
+ }
+
@Test
public void testStreamingCommit() throws Exception {
List<WorkItemCommitRequest> commitRequestList = new ArrayList<>();
List<CountDownLatch> latches = new ArrayList<>();
- Map<Long, WorkItemCommitRequest> commitRequests = new HashMap<>();
+ Map<Long, WorkItemCommitRequest> commitRequests = new ConcurrentHashMap<>();
for (int i = 0; i < 500; ++i) {
// Build some requests of varying size with a few big ones.
WorkItemCommitRequest request = makeCommitRequest(i, i * (i < 480 ? 8 : 128));
@@ -505,92 +594,94 @@ public class GrpcWindmillServerTest {
}
Collections.shuffle(commitRequestList);
- // This server receives WorkItemCommitRequests, and verifies they are equal to the above
- // commitRequest.
serviceRegistry.addService(
new CloudWindmillServiceV1Alpha1ImplBase() {
@Override
public StreamObserver<StreamingCommitWorkRequest> commitWorkStream(
StreamObserver<StreamingCommitResponse> responseObserver) {
- return new StreamObserver<StreamingCommitWorkRequest>() {
- boolean sawHeader = false;
- InputStream buffer = null;
- long remainingBytes = 0;
- ResponseErrorInjector injector = new ResponseErrorInjector(responseObserver);
+ return getTestCommitStreamObserver(responseObserver, commitRequests);
+ }
+ });
- @Override
- public void onNext(StreamingCommitWorkRequest request) {
- maybeInjectError(responseObserver);
+ // Make the commit requests, waiting for each of them to be verified and acknowledged.
+ CommitWorkStream stream = client.commitWorkStream();
+ for (int i = 0; i < commitRequestList.size(); ) {
+ final CountDownLatch latch = latches.get(i);
+ if (stream.commitWorkItem(
+ "computation",
+ commitRequestList.get(i),
+ (CommitStatus status) -> {
+ assertEquals(status, CommitStatus.OK);
+ latch.countDown();
+ })) {
+ i++;
+ } else {
+ stream.flush();
+ }
+ }
+ stream.flush();
+ stream.close();
+ for (CountDownLatch latch : latches) {
+ assertTrue(latch.await(1, TimeUnit.MINUTES));
+ }
+ assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS));
+ }
- if (!sawHeader) {
- errorCollector.checkThat(
- request.getHeader(),
- Matchers.equalTo(
- JobHeader.newBuilder()
- .setJobId("job")
- .setProjectId("project")
- .setWorkerId("worker")
- .build()));
- sawHeader = true;
- LOG.info("Received header");
- } else {
- boolean first = true;
- LOG.info("Received request with {} chunks", request.getCommitChunkCount());
- for (StreamingCommitRequestChunk chunk : request.getCommitChunkList()) {
- assertTrue(chunk.getSerializedWorkItemCommit().size() <= STREAM_CHUNK_SIZE);
- if (first || chunk.hasComputationId()) {
- errorCollector.checkThat(
- chunk.getComputationId(), Matchers.equalTo("computation"));
- }
+ @Test
+ // Tests stream retries on server errors before and after `close()`
+ public void testStreamingCommitClosedStream() throws Exception {
+ List<WorkItemCommitRequest> commitRequestList = new ArrayList<>();
+ List<CountDownLatch> latches = new ArrayList<>();
+ Map<Long, WorkItemCommitRequest> commitRequests = new ConcurrentHashMap<>();
+ AtomicBoolean shouldServerReturnError = new AtomicBoolean(true);
+ AtomicBoolean isClientClosed = new AtomicBoolean(false);
+ AtomicInteger errorsBeforeClose = new AtomicInteger();
+ AtomicInteger errorsAfterClose = new AtomicInteger();
+ for (int i = 0; i < 500; ++i) {
+ // Build some requests of varying size with a few big ones.
+ WorkItemCommitRequest request = makeCommitRequest(i, i * (i < 480 ? 8 : 128));
+ commitRequestList.add(request);
+ commitRequests.put((long) i, request);
+ latches.add(new CountDownLatch(1));
+ }
+ Collections.shuffle(commitRequestList);
- if (remainingBytes != 0) {
- errorCollector.checkThat(buffer, Matchers.notNullValue());
- errorCollector.checkThat(
- remainingBytes,
- Matchers.is(
- chunk.getSerializedWorkItemCommit().size()
- + chunk.getRemainingBytesForWorkItem()));
- buffer =
- new SequenceInputStream(
- buffer, chunk.getSerializedWorkItemCommit().newInput());
- } else {
- errorCollector.checkThat(buffer, Matchers.nullValue());
- buffer = chunk.getSerializedWorkItemCommit().newInput();
- }
- remainingBytes = chunk.getRemainingBytesForWorkItem();
- if (remainingBytes == 0) {
- try {
- WorkItemCommitRequest received = WorkItemCommitRequest.parseFrom(buffer);
- errorCollector.checkThat(
- received,
- Matchers.equalTo(commitRequests.get(received.getWorkToken())));
- try {
- responseObserver.onNext(
- StreamingCommitResponse.newBuilder()
- .addRequestId(chunk.getRequestId())
- .build());
- } catch (IllegalStateException e) {
- // Stream is closed.
- }
- } catch (Exception e) {
- errorCollector.addError(e);
- }
- buffer = null;
+ // This server returns errors if shouldServerReturnError is true, else returns valid responses.
+ serviceRegistry.addService(
+ new CloudWindmillServiceV1Alpha1ImplBase() {
+ @Override
+ public StreamObserver<StreamingCommitWorkRequest> commitWorkStream(
+ StreamObserver<StreamingCommitResponse> responseObserver) {
+ StreamObserver<StreamingCommitWorkRequest> testCommitStreamObserver =
+ getTestCommitStreamObserver(responseObserver, commitRequests);
+ return new StreamObserver<StreamingCommitWorkRequest>() {
+ @Override
+ public void onNext(StreamingCommitWorkRequest request) {
+ if (shouldServerReturnError.get()) {
+ try {
+ responseObserver.onError(
+ new RuntimeException("shouldServerReturnError = true"));
+ if (isClientClosed.get()) {
+ errorsAfterClose.incrementAndGet();
} else {
- errorCollector.checkThat(first, Matchers.is(true));
+ errorsBeforeClose.incrementAndGet();
}
- first = false;
+ } catch (IllegalStateException e) {
+ // The stream is already closed.
}
+ } else {
+ testCommitStreamObserver.onNext(request);
}
}
@Override
- public void onError(Throwable throwable) {}
+ public void onError(Throwable throwable) {
+ testCommitStreamObserver.onError(throwable);
+ }
@Override
public void onCompleted() {
- injector.cancel();
- responseObserver.onCompleted();
+ testCommitStreamObserver.onCompleted();
}
};
}
@@ -613,11 +704,50 @@ public class GrpcWindmillServerTest {
}
}
stream.flush();
- for (CountDownLatch latch : latches) {
- assertTrue(latch.await(1, TimeUnit.MINUTES));
+
+ long deadline = System.currentTimeMillis() + 60_000; // 1 min
+ while (true) {
+ Thread.sleep(100);
+ int tmpErrorsBeforeClose = errorsBeforeClose.get();
+ // wait for at least 1 errors before close
+ if (tmpErrorsBeforeClose > 0) {
+ break;
+ }
+ if (System.currentTimeMillis() > deadline) {
+ // Control should not reach here if the test is working as expected
+ fail(
+ String.format(
+ "Expected errors not sent by server errorsBeforeClose: %s"
+ + " \n Should not reach here if the test is working as expected.",
+ tmpErrorsBeforeClose));
+ }
}
stream.close();
+ isClientClosed.set(true);
+
+ deadline = System.currentTimeMillis() + 60_000; // 1 min
+ while (true) {
+ Thread.sleep(100);
+ int tmpErrorsAfterClose = errorsAfterClose.get();
+ // wait for at least 1 errors after close
+ if (tmpErrorsAfterClose > 0) {
+ break;
+ }
+ if (System.currentTimeMillis() > deadline) {
+ // Control should not reach here if the test is working as expected
+ fail(
+ String.format(
+ "Expected errors not sent by server errorsAfterClose: %s"
+ + " \n Should not reach here if the test is working as expected.",
+ tmpErrorsAfterClose));
+ }
+ }
+
+ shouldServerReturnError.set(false);
+ for (CountDownLatch latch : latches) {
+ assertTrue(latch.await(1, TimeUnit.MINUTES));
+ }
assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS));
}