You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tz...@apache.org on 2020/02/28 01:25:08 UTC

[flink-statefun] 03/05: [FLINK-16321] Add RequestReplyFunction

This is an automated email from the ASF dual-hosted git repository.

tzulitai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-statefun.git

commit 13dc42ab8d3652813dca2d30aa44a4534a9631e4
Author: Igal Shilman <ig...@gmail.com>
AuthorDate: Thu Feb 27 20:41:33 2020 +0100

    [FLINK-16321] Add RequestReplyFunction
    
    This commit factors the RequestReply protocol out of the
    HttpFunction
---
 .../flink/core/reqreply/RequestReplyClient.java    |  28 +++
 .../flink/core/reqreply/RequestReplyFunction.java  | 255 +++++++++++++++++++++
 .../core/reqreply/RequestReplyFunctionTest.java    | 232 +++++++++++++++++++
 3 files changed, 515 insertions(+)

diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyClient.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyClient.java
new file mode 100644
index 0000000..21ad311
--- /dev/null
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyClient.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.statefun.flink.core.reqreply;
+
+import java.util.concurrent.CompletableFuture;
+import org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction;
+import org.apache.flink.statefun.flink.core.polyglot.generated.ToFunction;
+
+public interface RequestReplyClient {
+
+  CompletableFuture<FromFunction> call(ToFunction toFunction);
+}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
new file mode 100644
index 0000000..b5417c7
--- /dev/null
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.statefun.flink.core.reqreply;
+
+import static org.apache.flink.statefun.flink.core.common.PolyglotUtil.polyglotAddressToSdkAddress;
+import static org.apache.flink.statefun.flink.core.common.PolyglotUtil.sdkAddressToPolyglotAddress;
+
+import com.google.protobuf.Any;
+import com.google.protobuf.ByteString;
+import java.util.List;
+import java.util.Objects;
+import java.util.concurrent.CompletableFuture;
+import org.apache.flink.statefun.flink.core.backpressure.AsyncWaiter;
+import org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction;
+import org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.InvocationResponse;
+import org.apache.flink.statefun.flink.core.polyglot.generated.ToFunction;
+import org.apache.flink.statefun.flink.core.polyglot.generated.ToFunction.Invocation;
+import org.apache.flink.statefun.flink.core.polyglot.generated.ToFunction.InvocationBatchRequest;
+import org.apache.flink.statefun.sdk.Address;
+import org.apache.flink.statefun.sdk.AsyncOperationResult;
+import org.apache.flink.statefun.sdk.Context;
+import org.apache.flink.statefun.sdk.StatefulFunction;
+import org.apache.flink.statefun.sdk.annotations.Persisted;
+import org.apache.flink.statefun.sdk.state.PersistedAppendingBuffer;
+import org.apache.flink.statefun.sdk.state.PersistedTable;
+import org.apache.flink.statefun.sdk.state.PersistedValue;
+
+public final class RequestReplyFunction implements StatefulFunction {
+
+  private final RequestReplyClient client;
+  private final List<String> registeredStateNames;
+  private final int maxNumBatchRequests;
+
+  /**
+   * A request state keeps track of the number of inflight & batched requests.
+   *
+   * <p>A tracking state can have one of the following values:
+   *
+   * <ul>
+   *   <li>NULL - there is no inflight request, and there is nothing in the backlog.
+   *   <li>0 - there's an inflight request, but nothing in the backlog.
+   *   <li>{@code > 0} - there is an inflight request, and that many items in the backlog.
+   * </ul>
+   */
+  @Persisted
+  private final PersistedValue<Integer> requestState =
+      PersistedValue.of("request-state", Integer.class);
+
+  @Persisted
+  private final PersistedAppendingBuffer<ToFunction.Invocation> batch =
+      PersistedAppendingBuffer.of("batch", ToFunction.Invocation.class);
+
+  @Persisted
+  private final PersistedTable<String, byte[]> managedStates =
+      PersistedTable.of("states", String.class, byte[].class);
+
+  public RequestReplyFunction(
+      List<String> registeredStateNames, int maxNumBatchRequests, RequestReplyClient client) {
+    this.client = Objects.requireNonNull(client);
+    this.registeredStateNames = Objects.requireNonNull(registeredStateNames);
+    this.maxNumBatchRequests = maxNumBatchRequests;
+  }
+
+  @Override
+  public void invoke(Context context, Object input) {
+    if (!(input instanceof AsyncOperationResult)) {
+      onRequest(context, (Any) input);
+      return;
+    }
+    @SuppressWarnings("unchecked")
+    AsyncOperationResult<ToFunction, FromFunction> result =
+        (AsyncOperationResult<ToFunction, FromFunction>) input;
+    onAsyncResult(context, result);
+  }
+
+  private void onRequest(Context context, Any message) {
+    Invocation.Builder invocationBuilder = singeInvocationBuilder(context, message);
+    int inflightOrBatched = requestState.getOrDefault(-1);
+    if (inflightOrBatched < 0) {
+      // no inflight requests, and nothing in the batch.
+      // so we let this request go through, and change the state to indicate that:
+      // a) there is a request in flight.
+      // b) there is nothing in the batch.
+      requestState.set(0);
+      sendToFunction(context, invocationBuilder);
+      return;
+    }
+    // there is at least one request in flight (inflightOrBatched >= 0),
+    // so we add that request to the batch.
+    batch.append(invocationBuilder.build());
+    inflightOrBatched++;
+    requestState.set(inflightOrBatched);
+    if (isMaxNumBatchRequestsExceeded(inflightOrBatched)) {
+      // we are at capacity, can't add anything to the batch.
+      // we need to signal to the runtime that we are unable to process any new input
+      // and we must wait for our in flight asynchronous operation to complete before
+      // we are able to process more input.
+      ((AsyncWaiter) context).awaitAsyncOperationComplete();
+    }
+  }
+
+  private void onAsyncResult(
+      Context context, AsyncOperationResult<ToFunction, FromFunction> asyncResult) {
+    if (asyncResult.unknown()) {
+      ToFunction batch = asyncResult.metadata();
+      sendToFunction(context, batch);
+      return;
+    }
+    InvocationResponse invocationResult = unpackInvocationOrThrow(context.self(), asyncResult);
+    handleInvocationResponse(context, invocationResult);
+
+    final int state = requestState.getOrDefault(-1);
+    if (state < 0) {
+      throw new IllegalStateException("Got an unexpected async result");
+    } else if (state == 0) {
+      requestState.clear();
+    } else {
+      final InvocationBatchRequest.Builder nextBatch = getNextBatch();
+      // an async request was just completed, but while it was in flight we have
+      // accumulated a batch, we now proceed with:
+      // a) clearing the batch from our own persisted state (the batch moves to the async operation
+      // state)
+      // b) sending the accumulated batch to the remote function.
+      requestState.set(0);
+      batch.clear();
+      sendToFunction(context, nextBatch);
+    }
+  }
+
+  private InvocationResponse unpackInvocationOrThrow(
+      Address self, AsyncOperationResult<ToFunction, FromFunction> result) {
+    if (result.failure()) {
+      throw new IllegalStateException(
+          "Failure forwarding a message to a remote function " + self, result.throwable());
+    }
+    FromFunction fromFunction = result.value();
+    if (fromFunction.hasInvocationResult()) {
+      return fromFunction.getInvocationResult();
+    }
+    return InvocationResponse.getDefaultInstance();
+  }
+
+  private InvocationBatchRequest.Builder getNextBatch() {
+    InvocationBatchRequest.Builder builder = InvocationBatchRequest.newBuilder();
+    Iterable<Invocation> view = batch.view();
+    builder.addAllInvocations(view);
+    return builder;
+  }
+
+  private void handleInvocationResponse(Context context, InvocationResponse invocationResult) {
+    for (FromFunction.Invocation invokeCommand : invocationResult.getOutgoingMessagesList()) {
+      final org.apache.flink.statefun.sdk.Address to =
+          polyglotAddressToSdkAddress(invokeCommand.getTarget());
+      final Any message = invokeCommand.getArgument();
+
+      context.send(to, message);
+    }
+    handleStateMutations(invocationResult);
+  }
+
+  // --------------------------------------------------------------------------------
+  // State Management
+  // --------------------------------------------------------------------------------
+
+  private void addStates(ToFunction.InvocationBatchRequest.Builder batchBuilder) {
+    for (String stateName : registeredStateNames) {
+      ToFunction.PersistedValue.Builder valueBuilder =
+          ToFunction.PersistedValue.newBuilder().setStateName(stateName);
+
+      byte[] stateValue = managedStates.get(stateName);
+      if (stateValue != null) {
+        valueBuilder.setStateValue(ByteString.copyFrom(stateValue));
+      }
+      batchBuilder.addState(valueBuilder);
+    }
+  }
+
+  private void handleStateMutations(InvocationResponse invocationResult) {
+    for (FromFunction.PersistedValueMutation mutate : invocationResult.getStateMutationsList()) {
+      final String stateName = mutate.getStateName();
+      switch (mutate.getMutationType()) {
+        case DELETE:
+          managedStates.remove(stateName);
+          break;
+        case MODIFY:
+          managedStates.set(stateName, mutate.getStateValue().toByteArray());
+          break;
+        case UNRECOGNIZED:
+          break;
+        default:
+          throw new IllegalStateException("Unexpected value: " + mutate.getMutationType());
+      }
+    }
+  }
+
+  // --------------------------------------------------------------------------------
+  // Send Message to Remote Function
+  // --------------------------------------------------------------------------------
+  /**
+   * Returns an {@link Invocation.Builder} set with the input {@code message} and the caller
+   * information (if present).
+   */
+  private static Invocation.Builder singeInvocationBuilder(Context context, Any message) {
+    Invocation.Builder invocationBuilder = Invocation.newBuilder();
+    if (context.caller() != null) {
+      invocationBuilder.setCaller(sdkAddressToPolyglotAddress(context.caller()));
+    }
+    invocationBuilder.setArgument(message);
+    return invocationBuilder;
+  }
+
+  /**
+   * Sends an {@link InvocationBatchRequest} to the remote function consisting of a single
+   * invocation represented by {@code invocationBuilder}.
+   */
+  private void sendToFunction(Context context, Invocation.Builder invocationBuilder) {
+    InvocationBatchRequest.Builder batchBuilder = InvocationBatchRequest.newBuilder();
+    batchBuilder.addInvocations(invocationBuilder);
+    sendToFunction(context, batchBuilder);
+  }
+
+  /** Sends a {@link InvocationBatchRequest} to the remote function. */
+  private void sendToFunction(Context context, InvocationBatchRequest.Builder batchBuilder) {
+    batchBuilder.setTarget(sdkAddressToPolyglotAddress(context.self()));
+    addStates(batchBuilder);
+    ToFunction toFunction = ToFunction.newBuilder().setInvocation(batchBuilder).build();
+    sendToFunction(context, toFunction);
+  }
+
+  private void sendToFunction(Context context, ToFunction toFunction) {
+
+    CompletableFuture<FromFunction> responseFuture = client.call(toFunction);
+    context.registerAsyncOperation(toFunction, responseFuture);
+  }
+
+  private boolean isMaxNumBatchRequestsExceeded(final int currentNumBatchRequests) {
+    return maxNumBatchRequests > 0 && currentNumBatchRequests >= maxNumBatchRequests;
+  }
+}
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
new file mode 100644
index 0000000..0b5f4ef
--- /dev/null
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.statefun.flink.core.reqreply;
+
+import static org.apache.flink.statefun.flink.core.TestUtils.FUNCTION_1_ADDR;
+import static org.apache.flink.statefun.flink.core.common.PolyglotUtil.polyglotAddressToSdkAddress;
+import static org.hamcrest.CoreMatchers.is;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+
+import com.google.protobuf.Any;
+import com.google.protobuf.ByteString;
+import java.time.Duration;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Supplier;
+import org.apache.flink.statefun.flink.core.TestUtils;
+import org.apache.flink.statefun.flink.core.backpressure.AsyncWaiter;
+import org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction;
+import org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.InvocationResponse;
+import org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.PersistedValueMutation;
+import org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.PersistedValueMutation.MutationType;
+import org.apache.flink.statefun.flink.core.polyglot.generated.ToFunction;
+import org.apache.flink.statefun.flink.core.polyglot.generated.ToFunction.Invocation;
+import org.apache.flink.statefun.sdk.Address;
+import org.apache.flink.statefun.sdk.AsyncOperationResult;
+import org.apache.flink.statefun.sdk.AsyncOperationResult.Status;
+import org.apache.flink.statefun.sdk.Context;
+import org.apache.flink.statefun.sdk.FunctionType;
+import org.apache.flink.statefun.sdk.io.EgressIdentifier;
+import org.junit.Test;
+
+public class RequestReplyFunctionTest {
+  private static final FunctionType FN_TYPE = new FunctionType("foo", "bar");
+
+  private final FakeClient client = new FakeClient();
+  private final FakeContext context = new FakeContext();
+  private final List<String> states = Collections.singletonList("session");
+
+  private final RequestReplyFunction functionUnderTest =
+      new RequestReplyFunction(states, 10, client);
+
+  @Test
+  public void example() {
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+
+    assertTrue(client.wasSentToFunction.hasInvocation());
+    assertThat(client.capturedInvocationBatchSize(), is(1));
+  }
+
+  @Test
+  public void callerIsSet() {
+    context.caller = FUNCTION_1_ADDR;
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+
+    Invocation anInvocation = client.capturedInvocation(0);
+    Address caller = polyglotAddressToSdkAddress(anInvocation.getCaller());
+
+    assertThat(caller, is(FUNCTION_1_ADDR));
+  }
+
+  @Test
+  public void messageIsSet() {
+    Any any = Any.pack(TestUtils.DUMMY_PAYLOAD);
+
+    functionUnderTest.invoke(context, any);
+
+    assertThat(client.capturedInvocation(0).getArgument(), is(any));
+  }
+
+  @Test
+  public void batchIsAccumulatedWhileARequestIsInFlight() {
+    // send one message
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+    // the following invocations should be queued and sent as a batch
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+
+    // simulate a successful completion of the first operation
+    functionUnderTest.invoke(context, successfulAsyncOperation());
+
+    assertThat(client.capturedInvocationBatchSize(), is(2));
+  }
+
+  @Test
+  public void reachingABatchLimitTriggersBackpressure() {
+    RequestReplyFunction functionUnderTest = new RequestReplyFunction(states, 2, client);
+
+    // send one message
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+    // the following invocations should be queued
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+
+    // the following invocations should request backpressure
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+
+    assertThat(context.needsWaiting, is(true));
+  }
+
+  @Test
+  public void returnedMessageReleaseBackpressure() {
+    RequestReplyFunction functionUnderTest = new RequestReplyFunction(states, 2, client);
+
+    // the following invocations should cause backpressure
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+
+    // complete one message, should send a batch of size 3
+    context.needsWaiting = false;
+    functionUnderTest.invoke(context, successfulAsyncOperation());
+
+    // the next message should not cause backpressure.
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+
+    assertThat(context.needsWaiting, is(false));
+  }
+
+  @Test
+  public void stateIsModified() {
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+
+    // A message returned from the function
+    // that asks to put "hello" into the session state.
+    FromFunction response =
+        FromFunction.newBuilder()
+            .setInvocationResult(
+                InvocationResponse.newBuilder()
+                    .addStateMutations(
+                        PersistedValueMutation.newBuilder()
+                            .setStateValue(ByteString.copyFromUtf8("hello"))
+                            .setMutationType(MutationType.MODIFY)
+                            .setStateName("session")))
+            .build();
+
+    functionUnderTest.invoke(context, successfulAsyncOperation(response));
+
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+    assertThat(client.capturedState(0), is(ByteString.copyFromUtf8("hello")));
+  }
+
+  private static AsyncOperationResult<Object, FromFunction> successfulAsyncOperation() {
+    return new AsyncOperationResult<>(
+        new Object(), Status.SUCCESS, FromFunction.getDefaultInstance(), null);
+  }
+
+  private static AsyncOperationResult<Object, FromFunction> successfulAsyncOperation(
+      FromFunction fromFunction) {
+    return new AsyncOperationResult<>(new Object(), Status.SUCCESS, fromFunction, null);
+  }
+
+  private static final class FakeClient implements RequestReplyClient {
+    ToFunction wasSentToFunction;
+    Supplier<FromFunction> fromFunction = FromFunction::getDefaultInstance;
+
+    @Override
+    public CompletableFuture<FromFunction> call(ToFunction toFunction) {
+      this.wasSentToFunction = toFunction;
+      try {
+        return CompletableFuture.completedFuture(this.fromFunction.get());
+      } catch (Throwable t) {
+        CompletableFuture<FromFunction> failed = new CompletableFuture<>();
+        failed.completeExceptionally(t);
+        return failed;
+      }
+    }
+
+    /** return the n-th invocation sent as part of the current batch. */
+    Invocation capturedInvocation(int n) {
+      return wasSentToFunction.getInvocation().getInvocations(n);
+    }
+
+    ByteString capturedState(int n) {
+      return wasSentToFunction.getInvocation().getState(n).getStateValue();
+    }
+
+    public int capturedInvocationBatchSize() {
+      return wasSentToFunction.getInvocation().getInvocationsCount();
+    }
+  }
+
+  private static final class FakeContext implements Context, AsyncWaiter {
+
+    Address caller;
+    boolean needsWaiting;
+
+    @Override
+    public void awaitAsyncOperationComplete() {
+      needsWaiting = true;
+    }
+
+    @Override
+    public Address self() {
+      return new Address(FN_TYPE, "0");
+    }
+
+    @Override
+    public Address caller() {
+      return caller;
+    }
+
+    @Override
+    public void send(Address to, Object message) {}
+
+    @Override
+    public <T> void send(EgressIdentifier<T> egress, T message) {}
+
+    @Override
+    public void sendAfter(Duration delay, Address to, Object message) {}
+
+    @Override
+    public <M, T> void registerAsyncOperation(M metadata, CompletableFuture<T> future) {}
+  }
+}