You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tz...@apache.org on 2021/02/01 05:49:48 UTC
[flink-statefun] 03/03: [FLINK-21171] Wire in TypedValue throughout
the runtime as state values and message payloads
This is an automated email from the ASF dual-hosted git repository.
tzulitai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-statefun.git
commit 69c658ec361682fb3bef50bd810f0646d7332a0c
Author: Tzu-Li (Gordon) Tai <tz...@apache.org>
AuthorDate: Thu Jan 28 16:40:16 2021 +0800
[FLINK-21171] Wire in TypedValue throughout the runtime as state values and message payloads
This closes #195.
---
statefun-e2e-tests/statefun-smoke-e2e/pom.xml | 59 +++++++++++++
.../statefun/e2e/smoke/CommandFlinkSource.java | 15 ++--
.../statefun/e2e/smoke/CommandInterpreter.java | 28 ++++---
.../flink/statefun/e2e/smoke/CommandRouter.java | 12 +--
.../apache/flink/statefun/e2e/smoke/Constants.java | 12 +--
.../apache/flink/statefun/e2e/smoke/Module.java | 9 +-
.../flink/statefun/e2e/smoke/ProtobufUtils.java | 34 --------
.../statefun/e2e/smoke/CommandInterpreterTest.java | 4 +-
.../flink/statefun/e2e/smoke/HarnessTest.java | 4 +-
.../flink/statefun/e2e/smoke/SmokeRunner.java | 4 +-
.../org/apache/flink/statefun/e2e/smoke/Utils.java | 15 ++--
.../run-example.py | 26 ++++--
statefun-flink/statefun-flink-common/pom.xml | 57 +++++++++++++
.../flink/common/types/TypedValueUtil.java | 55 ++++++++++++
.../flink/core/jsonmodule/EgressJsonEntity.java | 6 +-
.../protorouter/AutoRoutableProtobufRouter.java | 15 ++--
.../reqreply/PersistedRemoteFunctionValues.java | 37 +++++----
.../flink/core/reqreply/RequestReplyFunction.java | 16 ++--
.../flink/core/jsonmodule/JsonModuleTest.java | 5 +-
.../PersistedRemoteFunctionValuesTest.java | 51 +++++++++---
.../core/reqreply/RequestReplyFunctionTest.java | 97 +++++++++++++---------
statefun-flink/statefun-flink-io-bundle/pom.xml | 10 +++
.../io/kafka/GenericKafkaEgressSerializer.java | 15 ++--
.../flink/io/kafka/GenericKafkaSinkProvider.java | 6 +-
.../polyglot/GenericKinesisEgressSerializer.java | 13 +--
.../polyglot/GenericKinesisSinkProvider.java | 6 +-
.../io/kafka/GenericKafkaSinkProviderTest.java | 4 +-
.../io/kinesis/GenericKinesisSinkProviderTest.java | 4 +-
statefun-python-sdk/statefun/core.py | 7 ++
statefun-python-sdk/statefun/request_reply.py | 27 +++---
statefun-python-sdk/statefun/typed_value_utils.py | 49 +++++++++++
statefun-python-sdk/tests/request_reply_test.py | 34 ++++++--
.../src/main/protobuf/sdk/request-reply.proto | 21 +++--
33 files changed, 537 insertions(+), 220 deletions(-)
diff --git a/statefun-e2e-tests/statefun-smoke-e2e/pom.xml b/statefun-e2e-tests/statefun-smoke-e2e/pom.xml
index 71bb3c3..26318c2 100644
--- a/statefun-e2e-tests/statefun-smoke-e2e/pom.xml
+++ b/statefun-e2e-tests/statefun-smoke-e2e/pom.xml
@@ -30,6 +30,7 @@ under the License.
<properties>
<testcontainers.version>1.12.5</testcontainers.version>
<commons-math3.version>3.5</commons-math3.version>
+ <additional-sources.dir>target/additional-sources</additional-sources.dir>
</properties>
<dependencies>
@@ -41,6 +42,11 @@ under the License.
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
+ <artifactId>statefun-sdk-protos</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.flink</groupId>
<artifactId>statefun-flink-io</artifactId>
<version>${project.version}</version>
</dependency>
@@ -132,10 +138,63 @@ under the License.
<build>
<plugins>
+ <!--
+ The following plugin is executed in the generate-sources phase,
+ and is responsible for extracting the additional *.proto files packaged
+ in statefun-sdk-protos.jar.
+ -->
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-dependency-plugin</artifactId>
+ <executions>
+ <execution>
+ <id>unpack</id>
+ <phase>generate-sources</phase>
+ <goals>
+ <goal>unpack</goal>
+ </goals>
+ <configuration>
+ <artifactItems>
+ <artifactItem>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>statefun-sdk-protos</artifactId>
+ <version>${project.version}</version>
+ <type>jar</type>
+ <outputDirectory>${additional-sources.dir}</outputDirectory>
+ <includes>sdk/*.proto</includes>
+ </artifactItem>
+ </artifactItems>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <!--
+ The following plugin invokes protoc to generate Java classes from the *.proto
+ definitions located in: (1) src/main/protobuf and (2) ${additional-sources.dir}.
+ -->
<plugin>
<groupId>com.github.os72</groupId>
<artifactId>protoc-jar-maven-plugin</artifactId>
<version>${protoc-jar-maven-plugin.version}</version>
+ <executions>
+ <execution>
+ <id>generate-protobuf-sources</id>
+ <phase>generate-sources</phase>
+ <goals>
+ <goal>run</goal>
+ </goals>
+ <configuration>
+ <includeStdTypes>true</includeStdTypes>
+ <protocVersion>${protobuf.version}</protocVersion>
+ <cleanOutputFolder>true</cleanOutputFolder>
+ <inputDirectories>
+ <inputDirectory>src/main/protobuf</inputDirectory>
+ <inputDirectory>${additional-sources.dir}</inputDirectory>
+ </inputDirectories>
+ <outputDirectory>${basedir}/target/generated-sources/protoc-jar</outputDirectory>
+ </configuration>
+ </execution>
+ </executions>
</plugin>
</plugins>
</build>
diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandFlinkSource.java b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandFlinkSource.java
index ea4ed39..374d9e8 100644
--- a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandFlinkSource.java
+++ b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandFlinkSource.java
@@ -20,7 +20,6 @@ package org.apache.flink.statefun.e2e.smoke;
import static org.apache.flink.statefun.e2e.smoke.generated.Command.Verify;
import static org.apache.flink.statefun.e2e.smoke.generated.Command.newBuilder;
-import com.google.protobuf.Any;
import java.util.Iterator;
import java.util.Objects;
import java.util.OptionalInt;
@@ -38,6 +37,8 @@ import org.apache.flink.statefun.e2e.smoke.generated.Command;
import org.apache.flink.statefun.e2e.smoke.generated.Commands;
import org.apache.flink.statefun.e2e.smoke.generated.SourceCommand;
import org.apache.flink.statefun.e2e.smoke.generated.SourceSnapshot;
+import org.apache.flink.statefun.flink.common.types.TypedValueUtil;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
import org.slf4j.Logger;
@@ -54,7 +55,7 @@ import org.slf4j.LoggerFactory;
* to {@code verification} step. At this step, it would keep sending (every 2 seconds) a {@link
* Verify} command to every function indefinitely.
*/
-final class CommandFlinkSource extends RichSourceFunction<Any>
+final class CommandFlinkSource extends RichSourceFunction<TypedValue>
implements CheckpointedFunction, CheckpointListener {
private static final Logger LOG = LoggerFactory.getLogger(CommandFlinkSource.class);
@@ -132,7 +133,7 @@ final class CommandFlinkSource extends RichSourceFunction<Any>
// ------------------------------------------------------------------------------------------------------------
@Override
- public void run(SourceContext<Any> ctx) {
+ public void run(SourceContext<TypedValue> ctx) {
generate(ctx);
do {
verify(ctx);
@@ -145,7 +146,7 @@ final class CommandFlinkSource extends RichSourceFunction<Any>
} while (true);
}
- private void generate(SourceContext<Any> ctx) {
+ private void generate(SourceContext<TypedValue> ctx) {
final int startPosition = this.commandsSentSoFar;
final OptionalInt kaboomIndex =
computeFailureIndex(startPosition, failuresSoFar, moduleParameters.getMaxFailures());
@@ -170,13 +171,13 @@ final class CommandFlinkSource extends RichSourceFunction<Any>
return;
}
functionStateTracker.apply(command);
- ctx.collect(Any.pack(command));
+ ctx.collect(TypedValueUtil.packProtobufMessage(command));
this.commandsSentSoFar = i;
}
}
}
- private void verify(SourceContext<Any> ctx) {
+ private void verify(SourceContext<TypedValue> ctx) {
FunctionStateTracker functionStateTracker = this.functionStateTracker;
for (int i = 0; i < moduleParameters.getNumberOfFunctionInstances(); i++) {
@@ -190,7 +191,7 @@ final class CommandFlinkSource extends RichSourceFunction<Any>
.setCommands(Commands.newBuilder().addCommand(verify))
.build();
synchronized (ctx.getCheckpointLock()) {
- ctx.collect(Any.pack(command));
+ ctx.collect(TypedValueUtil.packProtobufMessage(command));
}
}
}
diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreter.java b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreter.java
index 343c8f2..036e6e0 100644
--- a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreter.java
+++ b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreter.java
@@ -17,9 +17,10 @@
*/
package org.apache.flink.statefun.e2e.smoke;
-import static org.apache.flink.statefun.e2e.smoke.ProtobufUtils.unpack;
+import static org.apache.flink.statefun.flink.common.types.TypedValueUtil.isProtobufTypeOf;
+import static org.apache.flink.statefun.flink.common.types.TypedValueUtil.packProtobufMessage;
+import static org.apache.flink.statefun.flink.common.types.TypedValueUtil.unpackProtobufMessage;
-import com.google.protobuf.Any;
import java.time.Duration;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
@@ -30,6 +31,7 @@ import org.apache.flink.statefun.e2e.smoke.generated.VerificationResult;
import org.apache.flink.statefun.sdk.AsyncOperationResult;
import org.apache.flink.statefun.sdk.Context;
import org.apache.flink.statefun.sdk.FunctionType;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.apache.flink.statefun.sdk.state.PersistedValue;
public final class CommandInterpreter {
@@ -50,18 +52,18 @@ public final class CommandInterpreter {
interpret(state, context, res.metadata());
return;
}
- if (!(message instanceof Any)) {
+ if (!(message instanceof TypedValue)) {
throw new IllegalArgumentException("wtf " + message);
}
- Any any = (Any) message;
- if (any.is(SourceCommand.class)) {
- SourceCommand sourceCommand = unpack(any, SourceCommand.class);
+ TypedValue typedValue = (TypedValue) message;
+ if (isProtobufTypeOf(typedValue, SourceCommand.getDescriptor())) {
+ SourceCommand sourceCommand = unpackProtobufMessage(typedValue, SourceCommand.parser());
interpret(state, context, sourceCommand.getCommands());
- } else if (any.is(Commands.class)) {
- Commands commands = unpack(any, Commands.class);
+ } else if (isProtobufTypeOf(typedValue, Commands.getDescriptor())) {
+ Commands commands = unpackProtobufMessage(typedValue, Commands.parser());
interpret(state, context, commands);
} else {
- throw new IllegalArgumentException("Unknown message type " + any.getTypeUrl());
+ throw new IllegalArgumentException("Unknown message type " + typedValue.getTypename());
}
}
@@ -96,14 +98,14 @@ public final class CommandInterpreter {
.setActual(actual)
.setExpected(expected)
.build();
- context.send(Constants.VERIFICATION_RESULT, Any.pack(verificationResult));
+ context.send(Constants.VERIFICATION_RESULT, packProtobufMessage(verificationResult));
}
private void sendEgress(
@SuppressWarnings("unused") PersistedValue<Long> state,
Context context,
@SuppressWarnings("unused") Command.SendEgress sendEgress) {
- context.send(Constants.OUT, Any.getDefaultInstance());
+ context.send(Constants.OUT, TypedValue.getDefaultInstance());
}
private void sendAfter(
@@ -112,14 +114,14 @@ public final class CommandInterpreter {
Command.SendAfter send) {
FunctionType functionType = Constants.FN_TYPE;
String id = ids.idOf(send.getTarget());
- context.sendAfter(sendAfterDelay, functionType, id, Any.pack(send.getCommands()));
+ context.sendAfter(sendAfterDelay, functionType, id, packProtobufMessage(send.getCommands()));
}
private void send(
@SuppressWarnings("unused") PersistedValue<Long> state, Context context, Command.Send send) {
FunctionType functionType = Constants.FN_TYPE;
String id = ids.idOf(send.getTarget());
- context.send(functionType, id, Any.pack(send.getCommands()));
+ context.send(functionType, id, packProtobufMessage(send.getCommands()));
}
private void registerAsyncOps(
diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandRouter.java b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandRouter.java
index e08ae8d..00af145 100644
--- a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandRouter.java
+++ b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandRouter.java
@@ -17,13 +17,14 @@
*/
package org.apache.flink.statefun.e2e.smoke;
-import com.google.protobuf.Any;
import java.util.Objects;
import org.apache.flink.statefun.e2e.smoke.generated.SourceCommand;
+import org.apache.flink.statefun.flink.common.types.TypedValueUtil;
import org.apache.flink.statefun.sdk.FunctionType;
import org.apache.flink.statefun.sdk.io.Router;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
-public class CommandRouter implements Router<Any> {
+public class CommandRouter implements Router<TypedValue> {
private final Ids ids;
public CommandRouter(Ids ids) {
@@ -31,10 +32,11 @@ public class CommandRouter implements Router<Any> {
}
@Override
- public void route(Any any, Downstream<Any> downstream) {
- SourceCommand sourceCommand = ProtobufUtils.unpack(any, SourceCommand.class);
+ public void route(TypedValue command, Downstream<TypedValue> downstream) {
+ SourceCommand sourceCommand =
+ TypedValueUtil.unpackProtobufMessage(command, SourceCommand.parser());
FunctionType type = Constants.FN_TYPE;
String id = ids.idOf(sourceCommand.getTarget());
- downstream.forward(type, id, any);
+ downstream.forward(type, id, command);
}
}
diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/Constants.java b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/Constants.java
index f5cf262..8f1c222 100644
--- a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/Constants.java
+++ b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/Constants.java
@@ -17,19 +17,21 @@
*/
package org.apache.flink.statefun.e2e.smoke;
-import com.google.protobuf.Any;
import org.apache.flink.statefun.sdk.FunctionType;
import org.apache.flink.statefun.sdk.io.EgressIdentifier;
import org.apache.flink.statefun.sdk.io.IngressIdentifier;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
public class Constants {
- public static final IngressIdentifier<Any> IN = new IngressIdentifier<>(Any.class, "", "source");
+ public static final IngressIdentifier<TypedValue> IN =
+ new IngressIdentifier<>(TypedValue.class, "", "source");
- public static final EgressIdentifier<Any> OUT = new EgressIdentifier<>("", "sink", Any.class);
+ public static final EgressIdentifier<TypedValue> OUT =
+ new EgressIdentifier<>("", "sink", TypedValue.class);
public static final FunctionType FN_TYPE = new FunctionType("v", "f1");
- public static final EgressIdentifier<Any> VERIFICATION_RESULT =
- new EgressIdentifier<>("", "verification", Any.class);
+ public static final EgressIdentifier<TypedValue> VERIFICATION_RESULT =
+ new EgressIdentifier<>("", "verification", TypedValue.class);
}
diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/Module.java b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/Module.java
index 21db25b..2673ac5 100644
--- a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/Module.java
+++ b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/Module.java
@@ -20,13 +20,13 @@ package org.apache.flink.statefun.e2e.smoke;
import static org.apache.flink.statefun.e2e.smoke.Constants.IN;
import com.google.auto.service.AutoService;
-import com.google.protobuf.Any;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Map;
import org.apache.flink.api.common.serialization.SerializationSchema;
import org.apache.flink.statefun.flink.io.datastream.SinkFunctionSpec;
import org.apache.flink.statefun.flink.io.datastream.SourceFunctionSpec;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.apache.flink.statefun.sdk.spi.StatefulFunctionModule;
import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
import org.apache.flink.streaming.api.functions.sink.SocketClientSink;
@@ -51,7 +51,7 @@ public class Module implements StatefulFunctionModule {
FunctionProvider provider = new FunctionProvider(ids);
binder.bindFunctionProvider(Constants.FN_TYPE, provider);
- SocketClientSink<Any> client =
+ SocketClientSink<TypedValue> client =
new SocketClientSink<>(
moduleParameters.getVerificationServerHost(),
moduleParameters.getVerificationServerPort(),
@@ -62,10 +62,11 @@ public class Module implements StatefulFunctionModule {
binder.bindEgress(new SinkFunctionSpec<>(Constants.VERIFICATION_RESULT, client));
}
- private static final class VerificationResultSerializer implements SerializationSchema<Any> {
+ private static final class VerificationResultSerializer
+ implements SerializationSchema<TypedValue> {
@Override
- public byte[] serialize(Any element) {
+ public byte[] serialize(TypedValue element) {
try {
ByteArrayOutputStream out = new ByteArrayOutputStream(element.getSerializedSize() + 8);
element.writeDelimitedTo(out);
diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/ProtobufUtils.java b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/ProtobufUtils.java
deleted file mode 100644
index 25aec2a..0000000
--- a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/ProtobufUtils.java
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.statefun.e2e.smoke;
-
-import com.google.protobuf.Any;
-import com.google.protobuf.InvalidProtocolBufferException;
-import com.google.protobuf.Message;
-
-final class ProtobufUtils {
- private ProtobufUtils() {}
-
- public static <T extends Message> T unpack(Any any, Class<T> messageType) {
- try {
- return any.unpack(messageType);
- } catch (InvalidProtocolBufferException e) {
- throw new IllegalStateException(e);
- }
- }
-}
diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java
index 1010666..226f418 100644
--- a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java
+++ b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java
@@ -21,10 +21,10 @@ import static org.apache.flink.statefun.e2e.smoke.Utils.aStateModificationComman
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
-import com.google.protobuf.Any;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import org.apache.flink.statefun.e2e.smoke.generated.SourceCommand;
+import org.apache.flink.statefun.flink.common.types.TypedValueUtil;
import org.apache.flink.statefun.sdk.Address;
import org.apache.flink.statefun.sdk.Context;
import org.apache.flink.statefun.sdk.io.EgressIdentifier;
@@ -41,7 +41,7 @@ public class CommandInterpreterTest {
Context context = new MockContext();
SourceCommand sourceCommand = aStateModificationCommand();
- interpreter.interpret(state, context, Any.pack(sourceCommand));
+ interpreter.interpret(state, context, TypedValueUtil.packProtobufMessage(sourceCommand));
assertThat(state.get(), is(1L));
}
diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/HarnessTest.java b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/HarnessTest.java
index 88864f8..382eefe 100644
--- a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/HarnessTest.java
+++ b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/HarnessTest.java
@@ -21,8 +21,8 @@ package org.apache.flink.statefun.e2e.smoke;
import static org.apache.flink.statefun.e2e.smoke.Utils.awaitVerificationSuccess;
import static org.apache.flink.statefun.e2e.smoke.Utils.startProtobufServer;
-import com.google.protobuf.Any;
import org.apache.flink.statefun.flink.harness.Harness;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.junit.Ignore;
import org.junit.Test;
import org.slf4j.Logger;
@@ -51,7 +51,7 @@ public class HarnessTest {
harness.withConfiguration("state.checkpoints.dir", "file:///tmp/checkpoints");
// start the Protobuf server
- SimpleProtobufServer.StartedServer<Any> started = startProtobufServer();
+ SimpleProtobufServer.StartedServer<TypedValue> started = startProtobufServer();
// configure test parameters.
ModuleParameters parameters = new ModuleParameters();
diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/SmokeRunner.java b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/SmokeRunner.java
index 9f2065e..55c857c 100644
--- a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/SmokeRunner.java
+++ b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/SmokeRunner.java
@@ -21,8 +21,8 @@ package org.apache.flink.statefun.e2e.smoke;
import static org.apache.flink.statefun.e2e.smoke.Utils.awaitVerificationSuccess;
import static org.apache.flink.statefun.e2e.smoke.Utils.startProtobufServer;
-import com.google.protobuf.Any;
import org.apache.flink.statefun.e2e.common.StatefulFunctionsAppContainers;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.apache.flink.util.function.ThrowingRunnable;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;
@@ -34,7 +34,7 @@ public final class SmokeRunner {
private static final Logger LOG = LoggerFactory.getLogger(SmokeRunner.class);
public static void run(ModuleParameters parameters) throws Throwable {
- SimpleProtobufServer.StartedServer<Any> server = startProtobufServer();
+ SimpleProtobufServer.StartedServer<TypedValue> server = startProtobufServer();
parameters.setVerificationServerHost("host.testcontainers.internal");
parameters.setVerificationServerPort(server.port());
diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/Utils.java b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/Utils.java
index 85f527d..ffbd57c 100644
--- a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/Utils.java
+++ b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/Utils.java
@@ -17,7 +17,6 @@
*/
package org.apache.flink.statefun.e2e.smoke;
-import com.google.protobuf.Any;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Supplier;
@@ -25,6 +24,8 @@ import org.apache.flink.statefun.e2e.smoke.generated.Command;
import org.apache.flink.statefun.e2e.smoke.generated.Commands;
import org.apache.flink.statefun.e2e.smoke.generated.SourceCommand;
import org.apache.flink.statefun.e2e.smoke.generated.VerificationResult;
+import org.apache.flink.statefun.flink.common.types.TypedValueUtil;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
class Utils {
@@ -60,11 +61,13 @@ class Utils {
}
/** Blocks the currently executing thread until enough successful verification results supply. */
- static void awaitVerificationSuccess(Supplier<Any> results, final int numberOfFunctionInstances) {
+ static void awaitVerificationSuccess(
+ Supplier<TypedValue> results, final int numberOfFunctionInstances) {
Set<Integer> successfullyVerified = new HashSet<>();
while (successfullyVerified.size() != numberOfFunctionInstances) {
- Any any = results.get();
- VerificationResult result = ProtobufUtils.unpack(any, VerificationResult.class);
+ TypedValue typedValue = results.get();
+ VerificationResult result =
+ TypedValueUtil.unpackProtobufMessage(typedValue, VerificationResult.parser());
if (result.getActual() == result.getExpected()) {
successfullyVerified.add(result.getId());
} else if (result.getActual() > result.getExpected()) {
@@ -80,8 +83,8 @@ class Utils {
}
/** starts a simple Protobuf TCP server that accepts {@link com.google.protobuf.Any}. */
- static SimpleProtobufServer.StartedServer<Any> startProtobufServer() {
- SimpleProtobufServer<Any> server = new SimpleProtobufServer<>(Any.parser());
+ static SimpleProtobufServer.StartedServer<TypedValue> startProtobufServer() {
+ SimpleProtobufServer<TypedValue> server = new SimpleProtobufServer<>(TypedValue.parser());
return server.start();
}
}
diff --git a/statefun-examples/statefun-python-walkthrough-example/run-example.py b/statefun-examples/statefun-python-walkthrough-example/run-example.py
index 3795e8f..8cee3b5 100644
--- a/statefun-examples/statefun-python-walkthrough-example/run-example.py
+++ b/statefun-examples/statefun-python-walkthrough-example/run-example.py
@@ -22,7 +22,7 @@ import requests
from google.protobuf.json_format import MessageToDict
from google.protobuf.any_pb2 import Any
-from statefun.request_reply_pb2 import ToFunction, FromFunction
+from statefun.request_reply_pb2 import ToFunction, FromFunction, TypedValue
from walkthrough_pb2 import Hello, AnotherHello, Counter
@@ -41,9 +41,7 @@ class InvocationBuilder(object):
state = self.to_function.invocation.state.add()
state.state_name = name
if value:
- any = Any()
- any.Pack(value)
- state.state_value = any.SerializeToString()
+ state.state_value.CopyFrom(self.to_typed_value_any_state(value))
return self
def with_invocation(self, arg, caller=None):
@@ -51,13 +49,31 @@ class InvocationBuilder(object):
if caller:
(ns, type, id) = caller
InvocationBuilder.set_address(ns, type, id, invocation.caller)
- invocation.argument.Pack(arg)
+ invocation.argument.CopyFrom(self.to_typed_value(arg))
return self
def SerializeToString(self):
return self.to_function.SerializeToString()
@staticmethod
+ def to_typed_value(proto_msg):
+ any = Any()
+ any.Pack(proto_msg)
+ typed_value = TypedValue()
+ typed_value.typename = any.type_url
+ typed_value.value = any.value
+ return typed_value
+
+ @staticmethod
+ def to_typed_value_any_state(proto_msg):
+ any = Any()
+ any.Pack(proto_msg)
+ typed_value = TypedValue()
+ typed_value.typename = "type.googleapis.com/google.protobuf.Any"
+ typed_value.value = any.SerializeToString()
+ return typed_value
+
+ @staticmethod
def set_address(namespace, type, id, address):
address.namespace = namespace
address.type = type
diff --git a/statefun-flink/statefun-flink-common/pom.xml b/statefun-flink/statefun-flink-common/pom.xml
index f4ef3f5..8063972 100644
--- a/statefun-flink/statefun-flink-common/pom.xml
+++ b/statefun-flink/statefun-flink-common/pom.xml
@@ -29,6 +29,10 @@ under the License.
<artifactId>statefun-flink-common</artifactId>
+ <properties>
+ <additional-sources.dir>target/additional-sources</additional-sources.dir>
+ </properties>
+
<dependencies>
<!-- flink runtime -->
<dependency>
@@ -84,10 +88,63 @@ under the License.
<build>
<plugins>
+ <!--
+ The following plugin is executed in the generate-sources phase,
+ and is responsible for extracting the additional *.proto files packaged
+ in statefun-sdk-protos.jar.
+ -->
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-dependency-plugin</artifactId>
+ <executions>
+ <execution>
+ <id>unpack</id>
+ <phase>generate-sources</phase>
+ <goals>
+ <goal>unpack</goal>
+ </goals>
+ <configuration>
+ <artifactItems>
+ <artifactItem>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>statefun-sdk-protos</artifactId>
+ <version>${project.version}</version>
+ <type>jar</type>
+ <outputDirectory>${additional-sources.dir}</outputDirectory>
+ <includes>sdk/*.proto</includes>
+ </artifactItem>
+ </artifactItems>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <!--
+ The following plugin invokes protoc to generate Java classes from the *.proto
+ definitions located in: (1) src/main/protobuf and (2) ${additional-sources.dir}.
+ -->
<plugin>
<groupId>com.github.os72</groupId>
<artifactId>protoc-jar-maven-plugin</artifactId>
<version>${protoc-jar-maven-plugin.version}</version>
+ <executions>
+ <execution>
+ <id>generate-protobuf-sources</id>
+ <phase>generate-sources</phase>
+ <goals>
+ <goal>run</goal>
+ </goals>
+ <configuration>
+ <includeStdTypes>true</includeStdTypes>
+ <protocVersion>${protobuf.version}</protocVersion>
+ <cleanOutputFolder>true</cleanOutputFolder>
+ <inputDirectories>
+ <inputDirectory>src/main/protobuf</inputDirectory>
+ <inputDirectory>${additional-sources.dir}</inputDirectory>
+ </inputDirectories>
+ <outputDirectory>${basedir}/target/generated-sources/protoc-jar</outputDirectory>
+ </configuration>
+ </execution>
+ </executions>
</plugin>
</plugins>
</build>
diff --git a/statefun-flink/statefun-flink-common/src/main/java/org/apache/flink/statefun/flink/common/types/TypedValueUtil.java b/statefun-flink/statefun-flink-common/src/main/java/org/apache/flink/statefun/flink/common/types/TypedValueUtil.java
new file mode 100644
index 0000000..38f9808
--- /dev/null
+++ b/statefun-flink/statefun-flink-common/src/main/java/org/apache/flink/statefun/flink/common/types/TypedValueUtil.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.statefun.flink.common.types;
+
+import com.google.protobuf.Descriptors;
+import com.google.protobuf.InvalidProtocolBufferException;
+import com.google.protobuf.Message;
+import com.google.protobuf.Parser;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
+
+public final class TypedValueUtil {
+
+ private TypedValueUtil() {}
+
+ public static boolean isProtobufTypeOf(
+ TypedValue typedValue, Descriptors.Descriptor messageDescriptor) {
+ return typedValue.getTypename().equals(protobufTypeUrl(messageDescriptor));
+ }
+
+ public static TypedValue packProtobufMessage(Message protobufMessage) {
+ return TypedValue.newBuilder()
+ .setTypename(protobufTypeUrl(protobufMessage.getDescriptorForType()))
+ .setValue(protobufMessage.toByteString())
+ .build();
+ }
+
+ public static <PB extends Message> PB unpackProtobufMessage(
+ TypedValue typedValue, Parser<PB> protobufMessageParser) {
+ try {
+ return protobufMessageParser.parseFrom(typedValue.getValue());
+ } catch (InvalidProtocolBufferException e) {
+ throw new IllegalStateException("Malformed payload of type " + typedValue.getTypename(), e);
+ }
+ }
+
+ private static String protobufTypeUrl(Descriptors.Descriptor messageDescriptor) {
+ return "type.googleapis.com/" + messageDescriptor.getFullName();
+ }
+}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/jsonmodule/EgressJsonEntity.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/jsonmodule/EgressJsonEntity.java
index 813b740..d3040b7 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/jsonmodule/EgressJsonEntity.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/jsonmodule/EgressJsonEntity.java
@@ -18,7 +18,6 @@
package org.apache.flink.statefun.flink.core.jsonmodule;
-import com.google.protobuf.Any;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonPointer;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode;
import org.apache.flink.statefun.flink.common.json.NamespaceNamePair;
@@ -26,6 +25,7 @@ import org.apache.flink.statefun.flink.common.json.Selectors;
import org.apache.flink.statefun.flink.io.spi.JsonEgressSpec;
import org.apache.flink.statefun.sdk.EgressType;
import org.apache.flink.statefun.sdk.io.EgressIdentifier;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.apache.flink.statefun.sdk.spi.StatefulFunctionModule.Binder;
final class EgressJsonEntity implements JsonEntity {
@@ -55,9 +55,9 @@ final class EgressJsonEntity implements JsonEntity {
return new EgressType(nn.namespace(), nn.name());
}
- private static EgressIdentifier<Any> egressId(JsonNode spec) {
+ private static EgressIdentifier<TypedValue> egressId(JsonNode spec) {
String egressId = Selectors.textAt(spec, MetaPointers.ID);
NamespaceNamePair nn = NamespaceNamePair.from(egressId);
- return new EgressIdentifier<>(nn.namespace(), nn.name(), Any.class);
+ return new EgressIdentifier<>(nn.namespace(), nn.name(), TypedValue.class);
}
}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/protorouter/AutoRoutableProtobufRouter.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/protorouter/AutoRoutableProtobufRouter.java
index eb37fe8..d5e0347 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/protorouter/AutoRoutableProtobufRouter.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/protorouter/AutoRoutableProtobufRouter.java
@@ -18,7 +18,6 @@
package org.apache.flink.statefun.flink.core.protorouter;
-import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.Message;
import org.apache.flink.statefun.flink.io.generated.AutoRoutable;
@@ -26,15 +25,21 @@ import org.apache.flink.statefun.flink.io.generated.RoutingConfig;
import org.apache.flink.statefun.flink.io.generated.TargetFunctionType;
import org.apache.flink.statefun.sdk.FunctionType;
import org.apache.flink.statefun.sdk.io.Router;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
/**
* A {@link Router} that recognizes messages of type {@link AutoRoutable}.
*
* <p>For each incoming {@code AutoRoutable}, this router forwards the wrapped payload to the
- * configured target addresses as a Protobuf {@link Any} message.
+ * configured target addresses as a {@link TypedValue} message.
*/
public final class AutoRoutableProtobufRouter implements Router<Message> {
+ /**
+ * Note: while the input and output types of this method are both {@link Message}, we actually
+ * do a conversion here. The input {@link Message} is an {@link AutoRoutable}, which gets
+ * converted to a {@link TypedValue} after extracting the target address and actual payload.
+ */
@Override
public void route(Message message, Downstream<Message> downstream) {
final AutoRoutable routable = asAutoRoutable(message);
@@ -43,7 +48,7 @@ public final class AutoRoutableProtobufRouter implements Router<Message> {
downstream.forward(
sdkFunctionType(targetFunction),
routable.getId(),
- anyPayload(config.getTypeUrl(), routable.getPayloadBytes()));
+ typedValuePayload(config.getTypeUrl(), routable.getPayloadBytes()));
}
}
@@ -60,7 +65,7 @@ public final class AutoRoutableProtobufRouter implements Router<Message> {
return new FunctionType(targetFunctionType.getNamespace(), targetFunctionType.getType());
}
- private static Any anyPayload(String typeUrl, ByteString payloadBytes) {
- return Any.newBuilder().setTypeUrl(typeUrl).setValue(payloadBytes).build();
+ private static TypedValue typedValuePayload(String typeUrl, ByteString payloadBytes) {
+ return TypedValue.newBuilder().setTypename(typeUrl).setValue(payloadBytes).build();
}
}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValues.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValues.java
index 42cffbe..c47c2ac 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValues.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValues.java
@@ -31,6 +31,7 @@ import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.PersistedVa
import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.PersistedValueSpec;
import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction;
import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction.InvocationBatchRequest;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.apache.flink.statefun.sdk.state.Expiration;
import org.apache.flink.statefun.sdk.state.PersistedStateRegistry;
import org.apache.flink.statefun.sdk.state.RemotePersistedValue;
@@ -48,9 +49,15 @@ public final class PersistedRemoteFunctionValues {
final ToFunction.PersistedValue.Builder valueBuilder =
ToFunction.PersistedValue.newBuilder().setStateName(managedStateEntry.getKey());
- final byte[] stateValue = managedStateEntry.getValue().get();
- if (stateValue != null) {
- valueBuilder.setStateValue(ByteString.copyFrom(stateValue));
+ final RemotePersistedValue registeredHandle = managedStateEntry.getValue();
+ final byte[] stateBytes = registeredHandle.get();
+ if (stateBytes != null) {
+ final TypedValue stateValue =
+ TypedValue.newBuilder()
+ .setValue(ByteString.copyFrom(stateBytes))
+ .setTypename(registeredHandle.type().toString())
+ .build();
+ valueBuilder.setStateValue(stateValue);
}
batchBuilder.addState(valueBuilder);
}
@@ -67,7 +74,11 @@ public final class PersistedRemoteFunctionValues {
}
case MODIFY:
{
- getStateHandleOrThrow(stateName).set(mutate.getStateValue().toByteArray());
+ final RemotePersistedValue registeredHandle = getStateHandleOrThrow(stateName);
+ final TypedValue newStateValue = mutate.getStateValue();
+
+ validateType(registeredHandle, newStateValue.getTypename());
+ registeredHandle.set(newStateValue.getValue().toByteArray());
break;
}
case UNRECOGNIZED:
@@ -102,7 +113,7 @@ public final class PersistedRemoteFunctionValues {
if (stateHandle == null) {
registerValueState(protocolPersistedValueSpec);
} else {
- validateType(stateHandle, protocolPersistedValueSpec);
+ validateType(stateHandle, protocolPersistedValueSpec.getTypeTypename());
}
}
@@ -112,7 +123,7 @@ public final class PersistedRemoteFunctionValues {
final RemotePersistedValue remoteValueState =
RemotePersistedValue.of(
stateName,
- sdkStateType(protocolPersistedValueSpec),
+ sdkStateType(protocolPersistedValueSpec.getTypeTypename()),
sdkTtlExpiration(protocolPersistedValueSpec.getExpirationSpec()));
managedStates.put(stateName, remoteValueState);
@@ -125,23 +136,21 @@ public final class PersistedRemoteFunctionValues {
}
private void validateType(
- RemotePersistedValue previousStateHandle, PersistedValueSpec protocolPersistedValueSpec) {
- final TypeName newStateType = sdkStateType(protocolPersistedValueSpec);
+ RemotePersistedValue previousStateHandle, String protocolTypenameString) {
+ final TypeName newStateType = sdkStateType(protocolTypenameString);
if (!newStateType.equals(previousStateHandle.type())) {
throw new RemoteFunctionStateException(
- protocolPersistedValueSpec.getStateName(),
+ previousStateHandle.name(),
new RemoteValueTypeMismatchException(previousStateHandle.type(), newStateType));
}
}
- private static TypeName sdkStateType(PersistedValueSpec protocolPersistedValueSpec) {
- final String typeStringPair = protocolPersistedValueSpec.getTypeTypename();
-
+ private static TypeName sdkStateType(String protocolTypenameString) {
// TODO type field may be empty in current master only because SDKs are not yet updated;
// TODO once SDKs are updated, we should expect that the type is always specified
- return protocolPersistedValueSpec.getTypeTypename().isEmpty()
+ return protocolTypenameString.isEmpty()
? UNSET_STATE_TYPE
- : TypeName.parseFrom(typeStringPair);
+ : TypeName.parseFrom(protocolTypenameString);
}
private static Expiration sdkTtlExpiration(ExpirationSpec protocolExpirationSpec) {
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
index 51db78c..a577bb0 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
@@ -21,7 +21,6 @@ package org.apache.flink.statefun.flink.core.reqreply;
import static org.apache.flink.statefun.flink.core.common.PolyglotUtil.polyglotAddressToSdkAddress;
import static org.apache.flink.statefun.flink.core.common.PolyglotUtil.sdkAddressToPolyglotAddress;
-import com.google.protobuf.Any;
import java.time.Duration;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
@@ -41,6 +40,7 @@ import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.InvocationR
import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction;
import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction.Invocation;
import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction.InvocationBatchRequest;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.apache.flink.statefun.sdk.state.PersistedAppendingBuffer;
import org.apache.flink.statefun.sdk.state.PersistedValue;
import org.apache.flink.types.Either;
@@ -87,7 +87,7 @@ public final class RequestReplyFunction implements StatefulFunction {
public void invoke(Context context, Object input) {
InternalContext castedContext = (InternalContext) context;
if (!(input instanceof AsyncOperationResult)) {
- onRequest(castedContext, (Any) input);
+ onRequest(castedContext, (TypedValue) input);
return;
}
@SuppressWarnings("unchecked")
@@ -96,7 +96,7 @@ public final class RequestReplyFunction implements StatefulFunction {
onAsyncResult(castedContext, result);
}
- private void onRequest(InternalContext context, Any message) {
+ private void onRequest(InternalContext context, TypedValue message) {
Invocation.Builder invocationBuilder = singeInvocationBuilder(context, message);
int inflightOrBatched = requestState.getOrDefault(-1);
if (inflightOrBatched < 0) {
@@ -208,9 +208,9 @@ public final class RequestReplyFunction implements StatefulFunction {
private void handleEgressMessages(Context context, InvocationResponse invocationResult) {
for (EgressMessage egressMessage : invocationResult.getOutgoingEgressesList()) {
- EgressIdentifier<Any> id =
+ EgressIdentifier<TypedValue> id =
new EgressIdentifier<>(
- egressMessage.getEgressNamespace(), egressMessage.getEgressType(), Any.class);
+ egressMessage.getEgressNamespace(), egressMessage.getEgressType(), TypedValue.class);
context.send(id, egressMessage.getArgument());
}
}
@@ -218,7 +218,7 @@ public final class RequestReplyFunction implements StatefulFunction {
private void handleOutgoingMessages(Context context, InvocationResponse invocationResult) {
for (FromFunction.Invocation invokeCommand : invocationResult.getOutgoingMessagesList()) {
final Address to = polyglotAddressToSdkAddress(invokeCommand.getTarget());
- final Any message = invokeCommand.getArgument();
+ final TypedValue message = invokeCommand.getArgument();
context.send(to, message);
}
@@ -228,7 +228,7 @@ public final class RequestReplyFunction implements StatefulFunction {
for (FromFunction.DelayedInvocation delayedInvokeCommand :
invocationResult.getDelayedInvocationsList()) {
final Address to = polyglotAddressToSdkAddress(delayedInvokeCommand.getTarget());
- final Any message = delayedInvokeCommand.getArgument();
+ final TypedValue message = delayedInvokeCommand.getArgument();
final long delay = delayedInvokeCommand.getDelayInMs();
context.sendAfter(Duration.ofMillis(delay), to, message);
@@ -242,7 +242,7 @@ public final class RequestReplyFunction implements StatefulFunction {
* Returns an {@link Invocation.Builder} set with the input {@code message} and the caller
* information (is present).
*/
- private static Invocation.Builder singeInvocationBuilder(Context context, Any message) {
+ private static Invocation.Builder singeInvocationBuilder(Context context, TypedValue message) {
Invocation.Builder invocationBuilder = Invocation.newBuilder();
if (context.caller() != null) {
invocationBuilder.setCaller(sdkAddressToPolyglotAddress(context.caller()));
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/jsonmodule/JsonModuleTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/jsonmodule/JsonModuleTest.java
index 92014a9..cc928b0 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/jsonmodule/JsonModuleTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/jsonmodule/JsonModuleTest.java
@@ -24,7 +24,6 @@ import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.Assert.assertThat;
-import com.google.protobuf.Any;
import com.google.protobuf.Message;
import java.net.URL;
import java.util.Collections;
@@ -35,6 +34,7 @@ import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
import org.apache.flink.statefun.sdk.FunctionType;
import org.apache.flink.statefun.sdk.io.EgressIdentifier;
import org.apache.flink.statefun.sdk.io.IngressIdentifier;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.apache.flink.statefun.sdk.spi.StatefulFunctionModule;
import org.junit.Test;
@@ -97,7 +97,8 @@ public class JsonModuleTest {
module.configure(Collections.emptyMap(), universe);
assertThat(
- universe.egress(), hasKey(new EgressIdentifier<>("com.mycomp.foo", "bar", Any.class)));
+ universe.egress(),
+ hasKey(new EgressIdentifier<>("com.mycomp.foo", "bar", TypedValue.class)));
}
private static StatefulFunctionModule fromPath(String path) {
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValuesTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValuesTest.java
index b5f2927..81ab98a 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValuesTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValuesTest.java
@@ -31,6 +31,7 @@ import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.PersistedVa
import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.PersistedValueSpec;
import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction.InvocationBatchRequest;
import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction.PersistedValue;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.junit.Test;
public class PersistedRemoteFunctionValuesTest {
@@ -50,8 +51,11 @@ public class PersistedRemoteFunctionValuesTest {
// --- update state values
values.updateStateValues(
Arrays.asList(
- protocolPersistedValueModifyMutation("state-1", ByteString.copyFromUtf8("data-1")),
- protocolPersistedValueModifyMutation("state-2", ByteString.copyFromUtf8("data-2"))));
+ protocolPersistedValueModifyMutation(
+ "state-1", protocolTypedValue(TEST_STATE_TYPE, ByteString.copyFromUtf8("data-1"))),
+ protocolPersistedValueModifyMutation(
+ "state-2",
+ protocolTypedValue(TEST_STATE_TYPE, ByteString.copyFromUtf8("data-2")))));
final InvocationBatchRequest.Builder builder = InvocationBatchRequest.newBuilder();
values.attachStateValues(builder);
@@ -61,8 +65,11 @@ public class PersistedRemoteFunctionValuesTest {
assertThat(
builder.getStateList(),
hasItems(
- protocolPersistedValue("state-1", ByteString.copyFromUtf8("data-1")),
- protocolPersistedValue("state-2", ByteString.copyFromUtf8("data-2"))));
+ protocolPersistedValue(
+ "state-1", protocolTypedValue(TEST_STATE_TYPE, ByteString.copyFromUtf8("data-1"))),
+ protocolPersistedValue(
+ "state-2",
+ protocolTypedValue(TEST_STATE_TYPE, ByteString.copyFromUtf8("data-2")))));
}
@Test
@@ -82,7 +89,8 @@ public class PersistedRemoteFunctionValuesTest {
values.updateStateValues(
Collections.singletonList(
protocolPersistedValueModifyMutation(
- "non-registered-state", ByteString.copyFromUtf8("data"))));
+ "non-registered-state",
+ protocolTypedValue(TEST_STATE_TYPE, ByteString.copyFromUtf8("data")))));
}
@Test
@@ -109,7 +117,8 @@ public class PersistedRemoteFunctionValuesTest {
// modify and then delete state value
values.updateStateValues(
Collections.singletonList(
- protocolPersistedValueModifyMutation("state", ByteString.copyFromUtf8("data"))));
+ protocolPersistedValueModifyMutation(
+ "state", protocolTypedValue(TEST_STATE_TYPE, ByteString.copyFromUtf8("data")))));
values.updateStateValues(
Collections.singletonList(protocolPersistedValueDeleteMutation("state")));
@@ -128,7 +137,8 @@ public class PersistedRemoteFunctionValuesTest {
Collections.singletonList(protocolPersistedValueSpec("state", TEST_STATE_TYPE)));
values.updateStateValues(
Collections.singletonList(
- protocolPersistedValueModifyMutation("state", ByteString.copyFromUtf8("data"))));
+ protocolPersistedValueModifyMutation(
+ "state", protocolTypedValue(TEST_STATE_TYPE, ByteString.copyFromUtf8("data")))));
// duplicate registration under the same state name
values.registerStates(
@@ -140,7 +150,9 @@ public class PersistedRemoteFunctionValuesTest {
assertThat(builder.getStateList().size(), is(1));
assertThat(
builder.getStateList(),
- hasItems(protocolPersistedValue("state", ByteString.copyFromUtf8("data"))));
+ hasItems(
+ protocolPersistedValue(
+ "state", protocolTypedValue(TEST_STATE_TYPE, ByteString.copyFromUtf8("data")))));
}
@Test(expected = RemoteFunctionStateException.class)
@@ -155,6 +167,25 @@ public class PersistedRemoteFunctionValuesTest {
protocolPersistedValueSpec("state", TypeName.parseFrom("com.foo.bar/type-2"))));
}
+ @Test(expected = RemoteFunctionStateException.class)
+ public void mutatingStateValueWithMismatchingType() {
+ final PersistedRemoteFunctionValues values = new PersistedRemoteFunctionValues();
+
+ values.registerStates(
+ Collections.singletonList(
+ protocolPersistedValueSpec("state", TypeName.parseFrom("com.foo.bar/type-1"))));
+ values.updateStateValues(
+ Collections.singletonList(
+ protocolPersistedValueModifyMutation(
+ "state",
+ protocolTypedValue(
+ TypeName.parseFrom("com.foo.bar/type-2"), ByteString.copyFromUtf8("data")))));
+ }
+
+ private static TypedValue protocolTypedValue(TypeName typename, ByteString value) {
+ return TypedValue.newBuilder().setTypename(typename.toString()).setValue(value).build();
+ }
+
private static PersistedValueSpec protocolPersistedValueSpec(String stateName, TypeName type) {
return PersistedValueSpec.newBuilder()
.setStateName(stateName)
@@ -163,7 +194,7 @@ public class PersistedRemoteFunctionValuesTest {
}
private static PersistedValueMutation protocolPersistedValueModifyMutation(
- String stateName, ByteString modifyValue) {
+ String stateName, TypedValue modifyValue) {
return PersistedValueMutation.newBuilder()
.setStateName(stateName)
.setMutationType(PersistedValueMutation.MutationType.MODIFY)
@@ -178,7 +209,7 @@ public class PersistedRemoteFunctionValuesTest {
.build();
}
- private static PersistedValue protocolPersistedValue(String stateName, ByteString stateValue) {
+ private static PersistedValue protocolPersistedValue(String stateName, TypedValue stateValue) {
final PersistedValue.Builder builder = PersistedValue.newBuilder();
builder.setStateName(stateName);
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
index 9b5d9c9..5b3a053 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
@@ -26,7 +26,6 @@ import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
-import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import java.time.Duration;
import java.util.AbstractMap.SimpleImmutableEntry;
@@ -38,7 +37,6 @@ import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;
import java.util.stream.Collectors;
-import org.apache.flink.statefun.flink.core.TestUtils;
import org.apache.flink.statefun.flink.core.backpressure.InternalContext;
import org.apache.flink.statefun.flink.core.metrics.FunctionTypeMetrics;
import org.apache.flink.statefun.flink.core.metrics.RemoteInvocationMetrics;
@@ -58,6 +56,7 @@ import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.PersistedVa
import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.PersistedValueSpec;
import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction;
import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction.Invocation;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.junit.Test;
public class RequestReplyFunctionTest {
@@ -67,11 +66,12 @@ public class RequestReplyFunctionTest {
private final FakeContext context = new FakeContext();
private final RequestReplyFunction functionUnderTest =
- new RequestReplyFunction(testInitialRegisteredState("session"), 10, client);
+ new RequestReplyFunction(
+ testInitialRegisteredState("session", "com.foo.bar/myType"), 10, client);
@Test
public void example() {
- functionUnderTest.invoke(context, Any.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
assertTrue(client.wasSentToFunction.hasInvocation());
assertThat(client.capturedInvocationBatchSize(), is(1));
@@ -80,7 +80,7 @@ public class RequestReplyFunctionTest {
@Test
public void callerIsSet() {
context.caller = FUNCTION_1_ADDR;
- functionUnderTest.invoke(context, Any.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
Invocation anInvocation = client.capturedInvocation(0);
Address caller = polyglotAddressToSdkAddress(anInvocation.getCaller());
@@ -90,20 +90,24 @@ public class RequestReplyFunctionTest {
@Test
public void messageIsSet() {
- Any any = Any.pack(TestUtils.DUMMY_PAYLOAD);
+ TypedValue argument =
+ TypedValue.newBuilder()
+ .setTypename("io.statefun.foo/bar")
+ .setValue(ByteString.copyFromUtf8("Hello!"))
+ .build();
- functionUnderTest.invoke(context, any);
+ functionUnderTest.invoke(context, argument);
- assertThat(client.capturedInvocation(0).getArgument(), is(any));
+ assertThat(client.capturedInvocation(0).getArgument(), is(argument));
}
@Test
public void batchIsAccumulatedWhileARequestIsInFlight() {
// send one message
- functionUnderTest.invoke(context, Any.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
// the following invocations should be queued and sent as a batch
- functionUnderTest.invoke(context, Any.getDefaultInstance());
- functionUnderTest.invoke(context, Any.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
// simulate a successful completion of the first operation
functionUnderTest.invoke(context, successfulAsyncOperation());
@@ -116,13 +120,13 @@ public class RequestReplyFunctionTest {
RequestReplyFunction functionUnderTest = new RequestReplyFunction(2, client);
// send one message
- functionUnderTest.invoke(context, Any.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
// the following invocations should be queued
- functionUnderTest.invoke(context, Any.getDefaultInstance());
- functionUnderTest.invoke(context, Any.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
// the following invocations should request backpressure
- functionUnderTest.invoke(context, Any.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
assertThat(context.needsWaiting, is(true));
}
@@ -132,24 +136,24 @@ public class RequestReplyFunctionTest {
RequestReplyFunction functionUnderTest = new RequestReplyFunction(2, client);
// the following invocations should cause backpressure
- functionUnderTest.invoke(context, Any.getDefaultInstance());
- functionUnderTest.invoke(context, Any.getDefaultInstance());
- functionUnderTest.invoke(context, Any.getDefaultInstance());
- functionUnderTest.invoke(context, Any.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
// complete one message, should send a batch of size 3
context.needsWaiting = false;
functionUnderTest.invoke(context, successfulAsyncOperation());
// the next message should not cause backpressure.
- functionUnderTest.invoke(context, Any.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
assertThat(context.needsWaiting, is(false));
}
@Test
public void stateIsModified() {
- functionUnderTest.invoke(context, Any.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
// A message returned from the function
// that asks to put "hello" into the session state.
@@ -159,20 +163,23 @@ public class RequestReplyFunctionTest {
InvocationResponse.newBuilder()
.addStateMutations(
PersistedValueMutation.newBuilder()
- .setStateValue(ByteString.copyFromUtf8("hello"))
+ .setStateValue(
+ TypedValue.newBuilder()
+ .setTypename("com.foo.bar/myType")
+ .setValue(ByteString.copyFromUtf8("hello")))
.setMutationType(MutationType.MODIFY)
.setStateName("session")))
.build();
functionUnderTest.invoke(context, successfulAsyncOperation(response));
- functionUnderTest.invoke(context, Any.getDefaultInstance());
- assertThat(client.capturedState(0), is(ByteString.copyFromUtf8("hello")));
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
+ assertThat(client.capturedState(0).getValue(), is(ByteString.copyFromUtf8("hello")));
}
@Test
public void delayedMessages() {
- functionUnderTest.invoke(context, Any.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
FromFunction response =
FromFunction.newBuilder()
@@ -180,7 +187,7 @@ public class RequestReplyFunctionTest {
InvocationResponse.newBuilder()
.addDelayedInvocations(
DelayedInvocation.newBuilder()
- .setArgument(Any.getDefaultInstance())
+ .setArgument(TypedValue.getDefaultInstance())
.setDelayInMs(1)
.build()))
.build();
@@ -193,7 +200,7 @@ public class RequestReplyFunctionTest {
@Test
public void egressIsSent() {
- functionUnderTest.invoke(context, Any.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
FromFunction response =
FromFunction.newBuilder()
@@ -201,7 +208,7 @@ public class RequestReplyFunctionTest {
InvocationResponse.newBuilder()
.addOutgoingEgresses(
EgressMessage.newBuilder()
- .setArgument(Any.getDefaultInstance())
+ .setArgument(TypedValue.getDefaultInstance())
.setEgressNamespace("org.foo")
.setEgressType("bar")))
.build();
@@ -210,13 +217,18 @@ public class RequestReplyFunctionTest {
assertFalse(context.egresses.isEmpty());
assertEquals(
- new EgressIdentifier<>("org.foo", "bar", Any.class), context.egresses.get(0).getKey());
+ new EgressIdentifier<>("org.foo", "bar", TypedValue.class),
+ context.egresses.get(0).getKey());
}
@Test
public void retryBatchOnIncompleteInvocationContextResponse() {
- Any any = Any.pack(TestUtils.DUMMY_PAYLOAD);
- functionUnderTest.invoke(context, any);
+ TypedValue argument =
+ TypedValue.newBuilder()
+ .setTypename("io.statefun.foo/bar")
+ .setValue(ByteString.copyFromUtf8("Hello!"))
+ .build();
+ functionUnderTest.invoke(context, argument);
FromFunction response =
FromFunction.newBuilder()
@@ -237,7 +249,7 @@ public class RequestReplyFunctionTest {
// re-sent batch should have identical invocation input messages
assertTrue(client.wasSentToFunction.hasInvocation());
assertThat(client.capturedInvocationBatchSize(), is(1));
- assertThat(client.capturedInvocation(0).getArgument(), is(any));
+ assertThat(client.capturedInvocation(0).getArgument(), is(argument));
// re-sent batch should have new state as well as originally registered state
assertThat(client.capturedStateNames().size(), is(2));
@@ -246,22 +258,22 @@ public class RequestReplyFunctionTest {
@Test
public void backlogMetricsIncreasedOnInvoke() {
- functionUnderTest.invoke(context, Any.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
// following should be accounted into backlog metrics
- functionUnderTest.invoke(context, Any.getDefaultInstance());
- functionUnderTest.invoke(context, Any.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
assertThat(context.functionTypeMetrics().numBacklog, is(2));
}
@Test
public void backlogMetricsDecreasedOnNextSuccess() {
- functionUnderTest.invoke(context, Any.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
// following should be accounted into backlog metrics
- functionUnderTest.invoke(context, Any.getDefaultInstance());
- functionUnderTest.invoke(context, Any.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
+ functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
// complete one message, should fully consume backlog
context.needsWaiting = false;
@@ -271,11 +283,14 @@ public class RequestReplyFunctionTest {
}
private static PersistedRemoteFunctionValues testInitialRegisteredState(
- String existingStateName) {
+ String existingStateName, String typename) {
final PersistedRemoteFunctionValues states = new PersistedRemoteFunctionValues();
states.registerStates(
Collections.singletonList(
- PersistedValueSpec.newBuilder().setStateName(existingStateName).build()));
+ PersistedValueSpec.newBuilder()
+ .setTypeTypename(typename)
+ .setStateName(existingStateName)
+ .build()));
return states;
}
@@ -318,7 +333,7 @@ public class RequestReplyFunctionTest {
return wasSentToFunction.getInvocation().getInvocations(n);
}
- ByteString capturedState(int n) {
+ TypedValue capturedState(int n) {
return wasSentToFunction.getInvocation().getState(n).getStateValue();
}
diff --git a/statefun-flink/statefun-flink-io-bundle/pom.xml b/statefun-flink/statefun-flink-io-bundle/pom.xml
index 51acb36..251955a 100644
--- a/statefun-flink/statefun-flink-io-bundle/pom.xml
+++ b/statefun-flink/statefun-flink-io-bundle/pom.xml
@@ -29,6 +29,10 @@ under the License.
<artifactId>statefun-flink-io-bundle</artifactId>
+ <properties>
+ <additional-sources.dir>target/additional-sources</additional-sources.dir>
+ </properties>
+
<dependencies>
<!-- Stateful Functions sdk -->
<dependency>
@@ -37,6 +41,12 @@ under the License.
<version>${project.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>statefun-sdk-protos</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+
<!-- statefun-flink spi -->
<dependency>
<groupId>org.apache.flink</groupId>
diff --git a/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaEgressSerializer.java b/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaEgressSerializer.java
index fb8a484..c232ba3 100644
--- a/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaEgressSerializer.java
+++ b/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaEgressSerializer.java
@@ -17,11 +17,12 @@
*/
package org.apache.flink.statefun.flink.io.kafka;
-import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import java.nio.charset.StandardCharsets;
+import org.apache.flink.statefun.flink.common.types.TypedValueUtil;
import org.apache.flink.statefun.sdk.egress.generated.KafkaProducerRecord;
import org.apache.flink.statefun.sdk.kafka.KafkaEgressSerializer;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.apache.kafka.clients.producer.ProducerRecord;
/**
@@ -31,24 +32,24 @@ import org.apache.kafka.clients.producer.ProducerRecord;
* <p>This serializer expects Protobuf messages of type {@link KafkaProducerRecord}, and simply
* transforms those into Kafka's {@link ProducerRecord}.
*/
-public final class GenericKafkaEgressSerializer implements KafkaEgressSerializer<Any> {
+public final class GenericKafkaEgressSerializer implements KafkaEgressSerializer<TypedValue> {
private static final long serialVersionUID = 1L;
@Override
- public ProducerRecord<byte[], byte[]> serialize(Any any) {
- KafkaProducerRecord protobufProducerRecord = asKafkaProducerRecord(any);
+ public ProducerRecord<byte[], byte[]> serialize(TypedValue message) {
+ KafkaProducerRecord protobufProducerRecord = asKafkaProducerRecord(message);
return toProducerRecord(protobufProducerRecord);
}
- private static KafkaProducerRecord asKafkaProducerRecord(Any message) {
- if (!message.is(KafkaProducerRecord.class)) {
+ private static KafkaProducerRecord asKafkaProducerRecord(TypedValue message) {
+ if (!TypedValueUtil.isProtobufTypeOf(message, KafkaProducerRecord.getDescriptor())) {
throw new IllegalStateException(
"The generic Kafka egress expects only messages of type "
+ KafkaProducerRecord.class.getName());
}
try {
- return message.unpack(KafkaProducerRecord.class);
+ return KafkaProducerRecord.parseFrom(message.getValue());
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(
"Unable to unpack message as a " + KafkaProducerRecord.class.getName(), e);
diff --git a/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaSinkProvider.java b/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaSinkProvider.java
index 2590b5f..fd87a69 100644
--- a/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaSinkProvider.java
+++ b/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaSinkProvider.java
@@ -23,7 +23,6 @@ import static org.apache.flink.statefun.flink.io.kafka.KafkaEgressSpecJsonParser
import static org.apache.flink.statefun.flink.io.kafka.KafkaEgressSpecJsonParser.kafkaClientProperties;
import static org.apache.flink.statefun.flink.io.kafka.KafkaEgressSpecJsonParser.optionalDeliverySemantic;
-import com.google.protobuf.Any;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode;
import org.apache.flink.statefun.flink.io.spi.JsonEgressSpec;
import org.apache.flink.statefun.flink.io.spi.SinkProvider;
@@ -31,6 +30,7 @@ import org.apache.flink.statefun.sdk.io.EgressIdentifier;
import org.apache.flink.statefun.sdk.io.EgressSpec;
import org.apache.flink.statefun.sdk.kafka.KafkaEgressBuilder;
import org.apache.flink.statefun.sdk.kafka.KafkaEgressSpec;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
final class GenericKafkaSinkProvider implements SinkProvider {
@@ -84,10 +84,10 @@ final class GenericKafkaSinkProvider implements SinkProvider {
private static void validateConsumedType(EgressIdentifier<?> id) {
Class<?> consumedType = id.consumedType();
- if (Any.class != consumedType) {
+ if (TypedValue.class != consumedType) {
throw new IllegalArgumentException(
"Generic Kafka egress is only able to consume messages types of "
- + Any.class.getName()
+ + TypedValue.class.getName()
+ " but "
+ consumedType.getName()
+ " is provided.");
diff --git a/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kinesis/polyglot/GenericKinesisEgressSerializer.java b/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kinesis/polyglot/GenericKinesisEgressSerializer.java
index 4b1c522..1459b15 100644
--- a/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kinesis/polyglot/GenericKinesisEgressSerializer.java
+++ b/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kinesis/polyglot/GenericKinesisEgressSerializer.java
@@ -18,18 +18,19 @@
package org.apache.flink.statefun.flink.io.kinesis.polyglot;
-import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
+import org.apache.flink.statefun.flink.common.types.TypedValueUtil;
import org.apache.flink.statefun.sdk.egress.generated.KinesisEgressRecord;
import org.apache.flink.statefun.sdk.kinesis.egress.EgressRecord;
import org.apache.flink.statefun.sdk.kinesis.egress.KinesisEgressSerializer;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
-public final class GenericKinesisEgressSerializer implements KinesisEgressSerializer<Any> {
+public final class GenericKinesisEgressSerializer implements KinesisEgressSerializer<TypedValue> {
private static final long serialVersionUID = 1L;
@Override
- public EgressRecord serialize(Any value) {
+ public EgressRecord serialize(TypedValue value) {
final KinesisEgressRecord kinesisEgressRecord = asKinesisEgressRecord(value);
final EgressRecord.Builder builder =
@@ -46,14 +47,14 @@ public final class GenericKinesisEgressSerializer implements KinesisEgressSerial
return builder.build();
}
- private static KinesisEgressRecord asKinesisEgressRecord(Any message) {
- if (!message.is(KinesisEgressRecord.class)) {
+ private static KinesisEgressRecord asKinesisEgressRecord(TypedValue message) {
+ if (!TypedValueUtil.isProtobufTypeOf(message, KinesisEgressRecord.getDescriptor())) {
throw new IllegalStateException(
"The generic Kinesis egress expects only messages of type "
+ KinesisEgressRecord.class.getName());
}
try {
- return message.unpack(KinesisEgressRecord.class);
+ return KinesisEgressRecord.parseFrom(message.getValue());
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(
"Unable to unpack message as a " + KinesisEgressRecord.class.getName(), e);
diff --git a/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kinesis/polyglot/GenericKinesisSinkProvider.java b/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kinesis/polyglot/GenericKinesisSinkProvider.java
index ad8fc1f..d5f5f29 100644
--- a/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kinesis/polyglot/GenericKinesisSinkProvider.java
+++ b/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kinesis/polyglot/GenericKinesisSinkProvider.java
@@ -22,7 +22,6 @@ import static org.apache.flink.statefun.flink.io.kinesis.polyglot.AwsAuthSpecJso
import static org.apache.flink.statefun.flink.io.kinesis.polyglot.KinesisEgressSpecJsonParser.clientConfigProperties;
import static org.apache.flink.statefun.flink.io.kinesis.polyglot.KinesisEgressSpecJsonParser.optionalMaxOutstandingRecords;
-import com.google.protobuf.Any;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode;
import org.apache.flink.statefun.flink.io.kinesis.KinesisSinkProvider;
import org.apache.flink.statefun.flink.io.spi.JsonEgressSpec;
@@ -31,6 +30,7 @@ import org.apache.flink.statefun.sdk.io.EgressIdentifier;
import org.apache.flink.statefun.sdk.io.EgressSpec;
import org.apache.flink.statefun.sdk.kinesis.egress.KinesisEgressBuilder;
import org.apache.flink.statefun.sdk.kinesis.egress.KinesisEgressSpec;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
public final class GenericKinesisSinkProvider implements SinkProvider {
@@ -74,10 +74,10 @@ public final class GenericKinesisSinkProvider implements SinkProvider {
private static void validateConsumedType(EgressIdentifier<?> id) {
Class<?> consumedType = id.consumedType();
- if (Any.class != consumedType) {
+ if (TypedValue.class != consumedType) {
throw new IllegalArgumentException(
"Generic Kinesis egress is only able to consume messages types of "
- + Any.class.getName()
+ + TypedValue.class.getName()
+ " but "
+ consumedType.getName()
+ " is provided.");
diff --git a/statefun-flink/statefun-flink-io-bundle/src/test/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaSinkProviderTest.java b/statefun-flink/statefun-flink-io-bundle/src/test/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaSinkProviderTest.java
index 151574d..d0dcc50 100644
--- a/statefun-flink/statefun-flink-io-bundle/src/test/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaSinkProviderTest.java
+++ b/statefun-flink/statefun-flink-io-bundle/src/test/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaSinkProviderTest.java
@@ -21,10 +21,10 @@ import static org.apache.flink.statefun.flink.io.testutils.YamlUtils.loadAsJsonF
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.MatcherAssert.assertThat;
-import com.google.protobuf.Any;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode;
import org.apache.flink.statefun.flink.io.spi.JsonEgressSpec;
import org.apache.flink.statefun.sdk.io.EgressIdentifier;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
import org.junit.Test;
@@ -38,7 +38,7 @@ public class GenericKafkaSinkProviderTest {
JsonEgressSpec<?> spec =
new JsonEgressSpec<>(
KafkaEgressTypes.GENERIC_KAFKA_EGRESS_TYPE,
- new EgressIdentifier<>("foo", "bar", Any.class),
+ new EgressIdentifier<>("foo", "bar", TypedValue.class),
egressDefinition);
GenericKafkaSinkProvider provider = new GenericKafkaSinkProvider();
diff --git a/statefun-flink/statefun-flink-io-bundle/src/test/java/org/apache/flink/statefun/flink/io/kinesis/GenericKinesisSinkProviderTest.java b/statefun-flink/statefun-flink-io-bundle/src/test/java/org/apache/flink/statefun/flink/io/kinesis/GenericKinesisSinkProviderTest.java
index 2a6b19b..adfc8f6 100644
--- a/statefun-flink/statefun-flink-io-bundle/src/test/java/org/apache/flink/statefun/flink/io/kinesis/GenericKinesisSinkProviderTest.java
+++ b/statefun-flink/statefun-flink-io-bundle/src/test/java/org/apache/flink/statefun/flink/io/kinesis/GenericKinesisSinkProviderTest.java
@@ -21,11 +21,11 @@ import static org.apache.flink.statefun.flink.io.testutils.YamlUtils.loadAsJsonF
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.MatcherAssert.assertThat;
-import com.google.protobuf.Any;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode;
import org.apache.flink.statefun.flink.io.kinesis.polyglot.GenericKinesisSinkProvider;
import org.apache.flink.statefun.flink.io.spi.JsonEgressSpec;
import org.apache.flink.statefun.sdk.io.EgressIdentifier;
+import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisProducer;
import org.junit.Test;
@@ -39,7 +39,7 @@ public class GenericKinesisSinkProviderTest {
JsonEgressSpec<?> spec =
new JsonEgressSpec<>(
PolyglotKinesisIOTypes.GENERIC_KINESIS_EGRESS_TYPE,
- new EgressIdentifier<>("foo", "bar", Any.class),
+ new EgressIdentifier<>("foo", "bar", TypedValue.class),
egressDefinition);
GenericKinesisSinkProvider provider = new GenericKinesisSinkProvider();
diff --git a/statefun-python-sdk/statefun/core.py b/statefun-python-sdk/statefun/core.py
index 8499a71..8e342d0 100644
--- a/statefun-python-sdk/statefun/core.py
+++ b/statefun-python-sdk/statefun/core.py
@@ -46,6 +46,13 @@ class AnyStateHandle(object):
self.modified = False
self.deleted = False
+ #
+ # TODO This should reflect the actual type URL.
+ # TODO we can support that only after reworking the SDK.
+ #
+ def typename(self):
+ return "type.googleapis.com/google.protobuf.Any"
+
def bytes(self):
if self.deleted:
raise AssertionError("can not obtain the bytes of a delete handle")
diff --git a/statefun-python-sdk/statefun/request_reply.py b/statefun-python-sdk/statefun/request_reply.py
index 6be41d0..f58e6d7 100644
--- a/statefun-python-sdk/statefun/request_reply.py
+++ b/statefun-python-sdk/statefun/request_reply.py
@@ -21,14 +21,13 @@ from google.protobuf.any_pb2 import Any
from statefun.core import SdkAddress
from statefun.core import Expiration
-from statefun.core import AnyStateHandle
from statefun.core import parse_typename
from statefun.core import StateRegistrationError
# generated function protocol
from statefun.request_reply_pb2 import FromFunction
from statefun.request_reply_pb2 import ToFunction
-
+from statefun.typed_value_utils import to_proto_any, from_proto_any, to_proto_any_state, from_proto_any_state
class InvocationContext:
def __init__(self, functions):
@@ -88,7 +87,7 @@ class InvocationContext:
@staticmethod
def provided_state_values(to_function):
- return {s.state_name: AnyStateHandle(s.state_value) for s in to_function.invocation.state}
+ return {s.state_name: to_proto_any_state(s.state_value) for s in to_function.invocation.state}
@staticmethod
def add_outgoing_messages(context, invocation_result):
@@ -100,7 +99,7 @@ class InvocationContext:
outgoing.target.namespace = namespace
outgoing.target.type = type
outgoing.target.id = id
- outgoing.argument.CopyFrom(message)
+ outgoing.argument.CopyFrom(from_proto_any(message))
@staticmethod
def add_mutations(context, invocation_result):
@@ -114,7 +113,7 @@ class InvocationContext:
mutation.mutation_type = FromFunction.PersistedValueMutation.MutationType.Value('DELETE')
else:
mutation.mutation_type = FromFunction.PersistedValueMutation.MutationType.Value('MODIFY')
- mutation.state_value = handle.bytes()
+ mutation.state_value.CopyFrom(from_proto_any_state(handle))
@staticmethod
def add_delayed_messages(context, invocation_result):
@@ -127,7 +126,7 @@ class InvocationContext:
outgoing.target.type = type
outgoing.target.id = id
outgoing.delay_in_ms = delay
- outgoing.argument.CopyFrom(message)
+ outgoing.argument.CopyFrom(from_proto_any(message))
@staticmethod
def add_egress(context, invocation_result):
@@ -138,7 +137,7 @@ class InvocationContext:
namespace, type = parse_typename(typename)
outgoing.egress_namespace = namespace
outgoing.egress_type = type
- outgoing.argument.CopyFrom(message)
+ outgoing.argument.CopyFrom(from_proto_any(message))
@staticmethod
def add_missing_state_specs(missing_state_specs, incomplete_context_response):
@@ -147,6 +146,10 @@ class InvocationContext:
missing_value = missing_values.add()
missing_value.state_name = state_spec.name
+ # TODO see the comment on statefun.core.AnyStateHandle.typename for
+ # TODO the reason to use this specific typename
+ missing_value.type_typename = "type.googleapis.com/google.protobuf.Any"
+
protocol_expiration_spec = FromFunction.ExpirationSpec()
sdk_expiration_spec = state_spec.expiration
if not sdk_expiration_spec:
@@ -181,9 +184,10 @@ class RequestReplyHandler:
fun = target_function.func
for invocation in batch:
context.prepare(invocation)
- unpacked = target_function.unpack_any(invocation.argument)
+ any_arg = to_proto_any(invocation.argument)
+ unpacked = target_function.unpack_any(any_arg)
if not unpacked:
- fun(context, invocation.argument)
+ fun(context, any_arg)
else:
fun(context, unpacked)
@@ -207,9 +211,10 @@ class AsyncRequestReplyHandler:
fun = target_function.func
for invocation in batch:
context.prepare(invocation)
- unpacked = target_function.unpack_any(invocation.argument)
+ any_arg = to_proto_any(invocation.argument)
+ unpacked = target_function.unpack_any(any_arg)
if not unpacked:
- await fun(context, invocation.argument)
+ await fun(context, any_arg)
else:
await fun(context, unpacked)
diff --git a/statefun-python-sdk/statefun/typed_value_utils.py b/statefun-python-sdk/statefun/typed_value_utils.py
new file mode 100644
index 0000000..8706800
--- /dev/null
+++ b/statefun-python-sdk/statefun/typed_value_utils.py
@@ -0,0 +1,49 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+from google.protobuf.any_pb2 import Any
+
+from statefun.core import AnyStateHandle
+from statefun.request_reply_pb2 import TypedValue
+
+#
+# Utility methods to convert back and forth between Protobuf Any and our TypedValue.
+# TODO this conversion is needed only because the Python SDK still works with Protobuf Any messages;
+# TODO it will go away once the SDK works directly with TypedValues.
+#
+
+def to_proto_any(typed_value: TypedValue):
+ proto_any = Any()
+ proto_any.type_url = typed_value.typename
+ proto_any.value = typed_value.value
+ return proto_any
+
+def from_proto_any(proto_any: Any):
+ typed_value = TypedValue()
+ typed_value.typename = proto_any.type_url
+ typed_value.value = proto_any.value
+ return typed_value
+
+def from_proto_any_state(any_state_handle: AnyStateHandle):
+ typed_value = TypedValue()
+ typed_value.typename = any_state_handle.typename()
+ typed_value.value = any_state_handle.bytes()
+ return typed_value
+
+def to_proto_any_state(typed_value: TypedValue) -> AnyStateHandle:
+ return AnyStateHandle(typed_value.value)
diff --git a/statefun-python-sdk/tests/request_reply_test.py b/statefun-python-sdk/tests/request_reply_test.py
index 80691f9..157bba6 100644
--- a/statefun-python-sdk/tests/request_reply_test.py
+++ b/statefun-python-sdk/tests/request_reply_test.py
@@ -23,7 +23,7 @@ from google.protobuf.json_format import MessageToDict
from google.protobuf.any_pb2 import Any
from tests.examples_pb2 import LoginEvent, SeenCount
-from statefun.request_reply_pb2 import ToFunction, FromFunction
+from statefun.request_reply_pb2 import ToFunction, FromFunction, TypedValue
from statefun import RequestReplyHandler, AsyncRequestReplyHandler
from statefun import StatefulFunctions, StateSpec, AfterWrite, StateRegistrationError
from statefun import kafka_egress_record, kinesis_egress_record
@@ -43,9 +43,7 @@ class InvocationBuilder(object):
state = self.to_function.invocation.state.add()
state.state_name = name
if value:
- any = Any()
- any.Pack(value)
- state.state_value = any.SerializeToString()
+ state.state_value.CopyFrom(self.to_typed_value_any_state(value))
return self
def with_invocation(self, arg, caller=None):
@@ -53,13 +51,31 @@ class InvocationBuilder(object):
if caller:
(ns, type, id) = caller
InvocationBuilder.set_address(ns, type, id, invocation.caller)
- invocation.argument.Pack(arg)
+ invocation.argument.CopyFrom(self.to_typed_value(arg))
return self
def SerializeToString(self):
return self.to_function.SerializeToString()
@staticmethod
+ def to_typed_value(proto_msg):
+ any = Any()
+ any.Pack(proto_msg)
+ typed_value = TypedValue()
+ typed_value.typename = any.type_url
+ typed_value.value = any.value
+ return typed_value
+
+ @staticmethod
+ def to_typed_value_any_state(proto_msg):
+ any = Any()
+ any.Pack(proto_msg)
+ typed_value = TypedValue()
+ typed_value.typename = "type.googleapis.com/google.protobuf.Any"
+ typed_value.value = any.SerializeToString()
+ return typed_value
+
+ @staticmethod
def set_address(namespace, type, id, address):
address.namespace = namespace
address.type = type
@@ -184,14 +200,14 @@ class RequestReplyTestCase(unittest.TestCase):
self.assertEqual(first_out_message['target']['namespace'], 'org.foo')
self.assertEqual(first_out_message['target']['type'], 'greeter-java')
self.assertEqual(first_out_message['target']['id'], '0')
- self.assertEqual(first_out_message['argument']['@type'], 'type.googleapis.com/k8s.demo.SeenCount')
+ self.assertEqual(first_out_message['argument']['typename'], 'type.googleapis.com/k8s.demo.SeenCount')
# assert second outgoing message
second_out_message = json_at(result_json, NTH_OUTGOING_MESSAGE(1))
self.assertEqual(second_out_message['target']['namespace'], 'bar.baz')
self.assertEqual(second_out_message['target']['type'], 'foo')
self.assertEqual(second_out_message['target']['id'], '12345')
- self.assertEqual(second_out_message['argument']['@type'], 'type.googleapis.com/k8s.demo.SeenCount')
+ self.assertEqual(second_out_message['argument']['typename'], 'type.googleapis.com/k8s.demo.SeenCount')
# assert state mutations
first_mutation = json_at(result_json, NTH_STATE_MUTATION(0))
@@ -207,7 +223,7 @@ class RequestReplyTestCase(unittest.TestCase):
first_egress = json_at(result_json, NTH_EGRESS(0))
self.assertEqual(first_egress['egress_namespace'], 'foo.bar.baz')
self.assertEqual(first_egress['egress_type'], 'my-egress')
- self.assertEqual(first_egress['argument']['@type'], 'type.googleapis.com/k8s.demo.SeenCount')
+ self.assertEqual(first_egress['argument']['typename'], 'type.googleapis.com/k8s.demo.SeenCount')
def test_integration_incomplete_context(self):
functions = StatefulFunctions()
@@ -309,7 +325,7 @@ class AsyncRequestReplyTestCase(unittest.TestCase):
self.assertEqual(second_out_message['target']['namespace'], 'bar.baz')
self.assertEqual(second_out_message['target']['type'], 'foo')
self.assertEqual(second_out_message['target']['id'], '12345')
- self.assertEqual(second_out_message['argument']['@type'], 'type.googleapis.com/k8s.demo.SeenCount')
+ self.assertEqual(second_out_message['argument']['typename'], 'type.googleapis.com/k8s.demo.SeenCount')
def test_integration_incomplete_context(self):
functions = StatefulFunctions()
diff --git a/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto b/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto
index 2ebd8f9..e0895a4 100644
--- a/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto
+++ b/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto
@@ -23,8 +23,6 @@ package io.statefun.sdk.reqreply;
option java_package = "org.apache.flink.statefun.sdk.reqreply.generated";
option java_multiple_files = true;
-import "google/protobuf/any.proto";
-
// -------------------------------------------------------------------------------------------------------------------
// Common message definitions
// -------------------------------------------------------------------------------------------------------------------
@@ -39,6 +37,11 @@ message Address {
string id = 3;
}
+message TypedValue {
+ string typename = 1;
+ bytes value = 2;
+}
+
// -------------------------------------------------------------------------------------------------------------------
// Messages sent to a Remote Function
// -------------------------------------------------------------------------------------------------------------------
@@ -51,7 +54,7 @@ message ToFunction {
// The unique name of the persisted state.
string state_name = 1;
// The serialized state value
- bytes state_value = 2;
+ TypedValue state_value = 2;
}
// Invocation represents a remote function call, it associated with an (optional) return address,
@@ -60,7 +63,7 @@ message ToFunction {
// The address of the function that requested the invocation (possibly absent)
Address caller = 1;
// The invocation argument (aka the message sent to the target function)
- google.protobuf.Any argument = 2;
+ TypedValue argument = 2;
}
// InvocationBatchRequest represents a request to invoke a remote function. It is always associated with a target
@@ -94,7 +97,7 @@ message FromFunction {
}
MutationType mutation_type = 1;
string state_name = 2;
- bytes state_value = 3;
+ TypedValue state_value = 3;
}
// Invocation represents a remote function call, it associated with a (mandatory) target address,
@@ -103,7 +106,7 @@ message FromFunction {
// The target function to invoke
Address target = 1;
// The invocation argument (aka the message sent to the target function)
- google.protobuf.Any argument = 2;
+ TypedValue argument = 2;
}
// DelayedInvocation represents a delayed remote function call with a target address, an argument
@@ -114,19 +117,19 @@ message FromFunction {
// the target address to send this message to
Address target = 2;
// the invocation argument
- google.protobuf.Any argument = 3;
+ TypedValue argument = 3;
}
// EgressMessage an argument to forward to an egress.
// An egress is identified by a namespace and type (see EgressIdentifier SDK class).
- // The argument is a google.protobuf.Any
+ // The argument is an io.statefun.sdk.reqreply.TypedValue.
message EgressMessage {
// The target egress namespace
string egress_namespace = 1;
// The target egress type
string egress_type = 2;
// egress argument
- google.protobuf.Any argument = 3;
+ TypedValue argument = 3;
}
// InvocationResponse represents a result of an io.statefun.sdk.reqreply.ToFunction.InvocationBatchRequest