You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tz...@apache.org on 2020/02/28 01:25:08 UTC
[flink-statefun] 03/05: [FLINK-16321] Add RequestReplyFunction
This is an automated email from the ASF dual-hosted git repository.
tzulitai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-statefun.git
commit 13dc42ab8d3652813dca2d30aa44a4534a9631e4
Author: Igal Shilman <ig...@gmail.com>
AuthorDate: Thu Feb 27 20:41:33 2020 +0100
[FLINK-16321] Add RequestReplyFunction
This commit factors the RequestReply protocol out of the
HttpFunction
---
.../flink/core/reqreply/RequestReplyClient.java | 28 +++
.../flink/core/reqreply/RequestReplyFunction.java | 255 +++++++++++++++++++++
.../core/reqreply/RequestReplyFunctionTest.java | 232 +++++++++++++++++++
3 files changed, 515 insertions(+)
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyClient.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyClient.java
new file mode 100644
index 0000000..21ad311
--- /dev/null
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyClient.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.statefun.flink.core.reqreply;
+
+import java.util.concurrent.CompletableFuture;
+import org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction;
+import org.apache.flink.statefun.flink.core.polyglot.generated.ToFunction;
+
+/**
+ * A client that delivers a {@link ToFunction} request to a remote function and makes the
+ * function's {@link FromFunction} reply available asynchronously.
+ */
+public interface RequestReplyClient {
+
+  /**
+   * Sends the given request to the remote function.
+   *
+   * @param toFunction the request to deliver.
+   * @return a future that holds the reply of the remote function. NOTE(review): callers presumably
+   *     expect failures to be reported through this future rather than thrown - confirm.
+   */
+  CompletableFuture<FromFunction> call(ToFunction toFunction);
+}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
new file mode 100644
index 0000000..b5417c7
--- /dev/null
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.statefun.flink.core.reqreply;
+
+import static org.apache.flink.statefun.flink.core.common.PolyglotUtil.polyglotAddressToSdkAddress;
+import static org.apache.flink.statefun.flink.core.common.PolyglotUtil.sdkAddressToPolyglotAddress;
+
+import com.google.protobuf.Any;
+import com.google.protobuf.ByteString;
+import java.util.List;
+import java.util.Objects;
+import java.util.concurrent.CompletableFuture;
+import org.apache.flink.statefun.flink.core.backpressure.AsyncWaiter;
+import org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction;
+import org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.InvocationResponse;
+import org.apache.flink.statefun.flink.core.polyglot.generated.ToFunction;
+import org.apache.flink.statefun.flink.core.polyglot.generated.ToFunction.Invocation;
+import org.apache.flink.statefun.flink.core.polyglot.generated.ToFunction.InvocationBatchRequest;
+import org.apache.flink.statefun.sdk.Address;
+import org.apache.flink.statefun.sdk.AsyncOperationResult;
+import org.apache.flink.statefun.sdk.Context;
+import org.apache.flink.statefun.sdk.StatefulFunction;
+import org.apache.flink.statefun.sdk.annotations.Persisted;
+import org.apache.flink.statefun.sdk.state.PersistedAppendingBuffer;
+import org.apache.flink.statefun.sdk.state.PersistedTable;
+import org.apache.flink.statefun.sdk.state.PersistedValue;
+
+/**
+ * A {@link StatefulFunction} that forwards the messages it receives to a remote function via a
+ * {@link RequestReplyClient}.
+ *
+ * <p>At most one request is kept in flight at a time. Messages that arrive while a request is in
+ * flight are accumulated in a persisted batch, which is sent as a single request once the
+ * in-flight request completes. Every request carries the current values of the registered states,
+ * and every response may send messages and mutate those states.
+ */
+public final class RequestReplyFunction implements StatefulFunction {
+
+  private final RequestReplyClient client;
+  private final List<String> registeredStateNames;
+  private final int maxNumBatchRequests;
+
+  /**
+   * Tracks the number of in-flight and batched requests.
+   *
+   * <p>The tracking state can have one of the following values:
+   *
+   * <ul>
+   *   <li>NULL - there is no in-flight request, and there is nothing in the backlog.
+   *   <li>0 - there is an in-flight request, but nothing in the backlog.
+   *   <li>{@code > 0} - there is an in-flight request, and that many items in the backlog.
+   * </ul>
+   */
+  @Persisted
+  private final PersistedValue<Integer> requestState =
+      PersistedValue.of("request-state", Integer.class);
+
+  /** The invocations that were received while a request was in flight. */
+  @Persisted
+  private final PersistedAppendingBuffer<ToFunction.Invocation> batch =
+      PersistedAppendingBuffer.of("batch", ToFunction.Invocation.class);
+
+  /** The values of the states that are shipped to the remote function, keyed by state name. */
+  @Persisted
+  private final PersistedTable<String, byte[]> managedStates =
+      PersistedTable.of("states", String.class, byte[].class);
+
+  /**
+   * Creates a new function.
+   *
+   * @param registeredStateNames the names of the states to attach to every request.
+   * @param maxNumBatchRequests the number of batched requests at which backpressure is requested;
+   *     a non-positive value disables the limit.
+   * @param client the client used to reach the remote function.
+   * @throws NullPointerException if {@code client} or {@code registeredStateNames} is null.
+   */
+  public RequestReplyFunction(
+      List<String> registeredStateNames, int maxNumBatchRequests, RequestReplyClient client) {
+    this.client = Objects.requireNonNull(client, "client");
+    this.registeredStateNames =
+        Objects.requireNonNull(registeredStateNames, "registeredStateNames");
+    this.maxNumBatchRequests = maxNumBatchRequests;
+  }
+
+  /**
+   * Dispatches the input: an {@link AsyncOperationResult} is treated as the completion of a
+   * previously sent request, anything else is cast to {@link Any} and treated as a new message
+   * for the remote function.
+   */
+  @Override
+  public void invoke(Context context, Object input) {
+    if (input instanceof AsyncOperationResult) {
+      @SuppressWarnings("unchecked")
+      AsyncOperationResult<ToFunction, FromFunction> result =
+          (AsyncOperationResult<ToFunction, FromFunction>) input;
+      onAsyncResult(context, result);
+      return;
+    }
+    onRequest(context, (Any) input);
+  }
+
+  /**
+   * Handles a new message: it is sent right away when nothing is in flight, and added to the
+   * batch otherwise.
+   */
+  private void onRequest(Context context, Any message) {
+    Invocation.Builder invocationBuilder = singleInvocationBuilder(context, message);
+    int inflightOrBatched = requestState.getOrDefault(-1);
+    if (inflightOrBatched < 0) {
+      // no inflight requests, and nothing in the batch.
+      // so we let this request go through, and change the state to indicate that:
+      // a) there is a request in flight.
+      // b) there is nothing in the batch.
+      requestState.set(0);
+      sendToFunction(context, invocationBuilder);
+      return;
+    }
+    // there is at least one request in flight (inflightOrBatched >= 0),
+    // so we add that request to the batch.
+    batch.append(invocationBuilder.build());
+    inflightOrBatched++;
+    requestState.set(inflightOrBatched);
+    if (isMaxNumBatchRequestsExceeded(inflightOrBatched)) {
+      // we are at capacity, can't add anything to the batch.
+      // we need to signal to the runtime that we are unable to process any new input
+      // and we must wait for our in flight asynchronous operation to complete before
+      // we are able to process more input.
+      ((AsyncWaiter) context).awaitAsyncOperationComplete();
+    }
+  }
+
+  /**
+   * Handles the completion of the in-flight request: applies the response and, if a batch was
+   * accumulated in the meantime, sends it as the next request.
+   *
+   * @throws IllegalStateException if the request has failed, or if no request was expected to be
+   *     in flight.
+   */
+  private void onAsyncResult(
+      Context context, AsyncOperationResult<ToFunction, FromFunction> asyncResult) {
+    if (asyncResult.unknown()) {
+      // the outcome of the request is unknown, so the very same request (it was registered as
+      // the metadata of the async operation) is sent again.
+      ToFunction pendingRequest = asyncResult.metadata();
+      sendToFunction(context, pendingRequest);
+      return;
+    }
+    InvocationResponse invocationResult = unpackInvocationOrThrow(context.self(), asyncResult);
+    handleInvocationResponse(context, invocationResult);
+
+    final int state = requestState.getOrDefault(-1);
+    if (state < 0) {
+      throw new IllegalStateException("Got an unexpected async result");
+    }
+    if (state == 0) {
+      // nothing was batched while the request was in flight.
+      requestState.clear();
+      return;
+    }
+    final InvocationBatchRequest.Builder nextBatch = getNextBatch();
+    // an async request was just completed, but while it was in flight we have
+    // accumulated a batch, we now proceed with:
+    // a) clearing the batch from our own persisted state (the batch moves to the async
+    // operation state)
+    // b) sending the accumulated batch to the remote function.
+    requestState.set(0);
+    batch.clear();
+    sendToFunction(context, nextBatch);
+  }
+
+  /**
+   * Extracts the invocation response out of a completed async operation.
+   *
+   * @return the invocation result of the reply, or an empty response if the reply has none.
+   * @throws IllegalStateException if the async operation has failed.
+   */
+  private InvocationResponse unpackInvocationOrThrow(
+      Address self, AsyncOperationResult<ToFunction, FromFunction> result) {
+    if (result.failure()) {
+      throw new IllegalStateException(
+          "Failure forwarding a message to a remote function " + self, result.throwable());
+    }
+    FromFunction fromFunction = result.value();
+    if (fromFunction.hasInvocationResult()) {
+      return fromFunction.getInvocationResult();
+    }
+    return InvocationResponse.getDefaultInstance();
+  }
+
+  /** Returns a request builder that holds a copy of all the currently batched invocations. */
+  private InvocationBatchRequest.Builder getNextBatch() {
+    InvocationBatchRequest.Builder builder = InvocationBatchRequest.newBuilder();
+    Iterable<Invocation> view = batch.view();
+    builder.addAllInvocations(view);
+    return builder;
+  }
+
+  /** Sends the outgoing messages of the response, and then applies its state mutations. */
+  private void handleInvocationResponse(Context context, InvocationResponse invocationResult) {
+    for (FromFunction.Invocation invokeCommand : invocationResult.getOutgoingMessagesList()) {
+      final Address to = polyglotAddressToSdkAddress(invokeCommand.getTarget());
+      final Any message = invokeCommand.getArgument();
+
+      context.send(to, message);
+    }
+    handleStateMutations(invocationResult);
+  }
+
+  // --------------------------------------------------------------------------------
+  // State Management
+  // --------------------------------------------------------------------------------
+
+  /**
+   * Adds an entry for every registered state to the request. The state value is only set for
+   * states that currently have a value.
+   */
+  private void addStates(ToFunction.InvocationBatchRequest.Builder batchBuilder) {
+    for (String stateName : registeredStateNames) {
+      ToFunction.PersistedValue.Builder valueBuilder =
+          ToFunction.PersistedValue.newBuilder().setStateName(stateName);
+
+      byte[] stateValue = managedStates.get(stateName);
+      if (stateValue != null) {
+        valueBuilder.setStateValue(ByteString.copyFrom(stateValue));
+      }
+      batchBuilder.addState(valueBuilder);
+    }
+  }
+
+  /** Applies the state mutations requested by the remote function to the managed states. */
+  private void handleStateMutations(InvocationResponse invocationResult) {
+    for (FromFunction.PersistedValueMutation mutate : invocationResult.getStateMutationsList()) {
+      final String stateName = mutate.getStateName();
+      switch (mutate.getMutationType()) {
+        case DELETE:
+          managedStates.remove(stateName);
+          break;
+        case MODIFY:
+          managedStates.set(stateName, mutate.getStateValue().toByteArray());
+          break;
+        case UNRECOGNIZED:
+          // a mutation type that is not known to this version is skipped.
+          break;
+        default:
+          throw new IllegalStateException("Unexpected value: " + mutate.getMutationType());
+      }
+    }
+  }
+
+  // --------------------------------------------------------------------------------
+  // Send Message to Remote Function
+  // --------------------------------------------------------------------------------
+
+  /**
+   * Returns an {@link Invocation.Builder} set with the input {@code message} and the caller
+   * information (if present).
+   */
+  private static Invocation.Builder singleInvocationBuilder(Context context, Any message) {
+    Invocation.Builder invocationBuilder = Invocation.newBuilder();
+    if (context.caller() != null) {
+      invocationBuilder.setCaller(sdkAddressToPolyglotAddress(context.caller()));
+    }
+    invocationBuilder.setArgument(message);
+    return invocationBuilder;
+  }
+
+  /**
+   * Sends an {@link InvocationBatchRequest} to the remote function consisting of a single
+   * invocation represented by {@code invocationBuilder}.
+   */
+  private void sendToFunction(Context context, Invocation.Builder invocationBuilder) {
+    InvocationBatchRequest.Builder batchBuilder = InvocationBatchRequest.newBuilder();
+    batchBuilder.addInvocations(invocationBuilder);
+    sendToFunction(context, batchBuilder);
+  }
+
+  /**
+   * Sends an {@link InvocationBatchRequest} to the remote function, after setting this function
+   * as the target and attaching the registered states.
+   */
+  private void sendToFunction(Context context, InvocationBatchRequest.Builder batchBuilder) {
+    batchBuilder.setTarget(sdkAddressToPolyglotAddress(context.self()));
+    addStates(batchBuilder);
+    ToFunction toFunction = ToFunction.newBuilder().setInvocation(batchBuilder).build();
+    sendToFunction(context, toFunction);
+  }
+
+  /**
+   * Hands the request to the client, and registers the pending reply as an async operation that
+   * has the request itself as its metadata.
+   */
+  private void sendToFunction(Context context, ToFunction toFunction) {
+    CompletableFuture<FromFunction> responseFuture = client.call(toFunction);
+    context.registerAsyncOperation(toFunction, responseFuture);
+  }
+
+  /**
+   * Returns true if a limit is configured ({@code maxNumBatchRequests > 0}) and the given number
+   * of batched requests has reached it.
+   */
+  private boolean isMaxNumBatchRequestsExceeded(final int currentNumBatchRequests) {
+    return maxNumBatchRequests > 0 && currentNumBatchRequests >= maxNumBatchRequests;
+  }
+}
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
new file mode 100644
index 0000000..0b5f4ef
--- /dev/null
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.statefun.flink.core.reqreply;
+
+import static org.apache.flink.statefun.flink.core.TestUtils.FUNCTION_1_ADDR;
+import static org.apache.flink.statefun.flink.core.common.PolyglotUtil.polyglotAddressToSdkAddress;
+import static org.hamcrest.CoreMatchers.is;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+
+import com.google.protobuf.Any;
+import com.google.protobuf.ByteString;
+import java.time.Duration;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Supplier;
+import org.apache.flink.statefun.flink.core.TestUtils;
+import org.apache.flink.statefun.flink.core.backpressure.AsyncWaiter;
+import org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction;
+import org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.InvocationResponse;
+import org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.PersistedValueMutation;
+import org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.PersistedValueMutation.MutationType;
+import org.apache.flink.statefun.flink.core.polyglot.generated.ToFunction;
+import org.apache.flink.statefun.flink.core.polyglot.generated.ToFunction.Invocation;
+import org.apache.flink.statefun.sdk.Address;
+import org.apache.flink.statefun.sdk.AsyncOperationResult;
+import org.apache.flink.statefun.sdk.AsyncOperationResult.Status;
+import org.apache.flink.statefun.sdk.Context;
+import org.apache.flink.statefun.sdk.FunctionType;
+import org.apache.flink.statefun.sdk.io.EgressIdentifier;
+import org.junit.Test;
+
+/** Tests the request batching, backpressure and state handling of {@link RequestReplyFunction}. */
+public class RequestReplyFunctionTest {
+  private static final FunctionType FN_TYPE = new FunctionType("foo", "bar");
+
+  private final FakeClient client = new FakeClient();
+  private final FakeContext context = new FakeContext();
+  private final List<String> states = Collections.singletonList("session");
+
+  private final RequestReplyFunction functionUnderTest =
+      new RequestReplyFunction(states, 10, client);
+
+  @Test
+  public void example() {
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+
+    assertTrue(client.wasSentToFunction.hasInvocation());
+    assertThat(client.capturedInvocationBatchSize(), is(1));
+  }
+
+  @Test
+  public void callerIsSet() {
+    context.caller = FUNCTION_1_ADDR;
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+
+    Invocation anInvocation = client.capturedInvocation(0);
+    Address caller = polyglotAddressToSdkAddress(anInvocation.getCaller());
+
+    assertThat(caller, is(FUNCTION_1_ADDR));
+  }
+
+  @Test
+  public void messageIsSet() {
+    Any any = Any.pack(TestUtils.DUMMY_PAYLOAD);
+
+    functionUnderTest.invoke(context, any);
+
+    assertThat(client.capturedInvocation(0).getArgument(), is(any));
+  }
+
+  @Test
+  public void batchIsAccumulatedWhileARequestIsInFlight() {
+    // send one message
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+    // the following invocations should be queued and sent as a batch
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+
+    // simulate a successful completion of the first operation
+    functionUnderTest.invoke(context, successfulAsyncOperation());
+
+    assertThat(client.capturedInvocationBatchSize(), is(2));
+  }
+
+  @Test
+  public void reachingABatchLimitTriggersBackpressure() {
+    RequestReplyFunction limitedFunction = new RequestReplyFunction(states, 2, client);
+
+    // the first message is sent right away
+    limitedFunction.invoke(context, Any.getDefaultInstance());
+    // the second message is queued, and the batch is still below the limit
+    limitedFunction.invoke(context, Any.getDefaultInstance());
+    assertThat(context.needsWaiting, is(false));
+
+    // the third message brings the batch to the limit, and that should request backpressure
+    limitedFunction.invoke(context, Any.getDefaultInstance());
+
+    assertThat(context.needsWaiting, is(true));
+  }
+
+  @Test
+  public void returnedMessageReleaseBackpressure() {
+    RequestReplyFunction limitedFunction = new RequestReplyFunction(states, 2, client);
+
+    // the following invocations should cause backpressure
+    limitedFunction.invoke(context, Any.getDefaultInstance());
+    limitedFunction.invoke(context, Any.getDefaultInstance());
+    limitedFunction.invoke(context, Any.getDefaultInstance());
+    limitedFunction.invoke(context, Any.getDefaultInstance());
+    assertThat(context.needsWaiting, is(true));
+
+    // complete one message, should send a batch of size 3
+    context.needsWaiting = false;
+    limitedFunction.invoke(context, successfulAsyncOperation());
+    assertThat(client.capturedInvocationBatchSize(), is(3));
+
+    // the next message should not cause backpressure.
+    limitedFunction.invoke(context, Any.getDefaultInstance());
+
+    assertThat(context.needsWaiting, is(false));
+  }
+
+  @Test
+  public void stateIsModified() {
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+
+    // A message returned from the function
+    // that asks to put "hello" into the session state.
+    FromFunction response =
+        FromFunction.newBuilder()
+            .setInvocationResult(
+                InvocationResponse.newBuilder()
+                    .addStateMutations(
+                        PersistedValueMutation.newBuilder()
+                            .setStateValue(ByteString.copyFromUtf8("hello"))
+                            .setMutationType(MutationType.MODIFY)
+                            .setStateName("session")))
+            .build();
+
+    functionUnderTest.invoke(context, successfulAsyncOperation(response));
+
+    // the state is attached to the next request that is sent to the function.
+    functionUnderTest.invoke(context, Any.getDefaultInstance());
+    assertThat(client.capturedState(0), is(ByteString.copyFromUtf8("hello")));
+  }
+
+  /** A successfully completed async operation that carries an empty reply. */
+  private static AsyncOperationResult<Object, FromFunction> successfulAsyncOperation() {
+    return successfulAsyncOperation(FromFunction.getDefaultInstance());
+  }
+
+  /**
+   * A successfully completed async operation that carries the given reply. The metadata is a
+   * placeholder object, since the function under test only reads it for results with an unknown
+   * status.
+   */
+  private static AsyncOperationResult<Object, FromFunction> successfulAsyncOperation(
+      FromFunction fromFunction) {
+    return new AsyncOperationResult<>(new Object(), Status.SUCCESS, fromFunction, null);
+  }
+
+  /** A client that records the last request it was given, and replies immediately. */
+  private static final class FakeClient implements RequestReplyClient {
+    ToFunction wasSentToFunction;
+    Supplier<FromFunction> fromFunction = FromFunction::getDefaultInstance;
+
+    @Override
+    public CompletableFuture<FromFunction> call(ToFunction toFunction) {
+      this.wasSentToFunction = toFunction;
+      try {
+        return CompletableFuture.completedFuture(this.fromFunction.get());
+      } catch (Throwable t) {
+        // a throwing supplier is reported through the returned future.
+        CompletableFuture<FromFunction> failed = new CompletableFuture<>();
+        failed.completeExceptionally(t);
+        return failed;
+      }
+    }
+
+    /** return the n-th invocation sent as part of the current batch. */
+    Invocation capturedInvocation(int n) {
+      return wasSentToFunction.getInvocation().getInvocations(n);
+    }
+
+    /** return the value of the n-th state attached to the last request. */
+    ByteString capturedState(int n) {
+      return wasSentToFunction.getInvocation().getState(n).getStateValue();
+    }
+
+    /** return the number of invocations in the last request. */
+    int capturedInvocationBatchSize() {
+      return wasSentToFunction.getInvocation().getInvocationsCount();
+    }
+  }
+
+  /** A context that records backpressure requests, and ignores everything that is sent. */
+  private static final class FakeContext implements Context, AsyncWaiter {
+
+    Address caller;
+    boolean needsWaiting;
+
+    @Override
+    public void awaitAsyncOperationComplete() {
+      needsWaiting = true;
+    }
+
+    @Override
+    public Address self() {
+      return new Address(FN_TYPE, "0");
+    }
+
+    @Override
+    public Address caller() {
+      return caller;
+    }
+
+    @Override
+    public void send(Address to, Object message) {}
+
+    @Override
+    public <T> void send(EgressIdentifier<T> egress, T message) {}
+
+    @Override
+    public void sendAfter(Duration delay, Address to, Object message) {}
+
+    @Override
+    public <M, T> void registerAsyncOperation(M metadata, CompletableFuture<T> future) {}
+  }
+}