You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tz...@apache.org on 2020/09/24 05:59:42 UTC
[flink-statefun] branch master updated: [FLINK-19176] Add pluggable
statefun payload serializer
This is an automated email from the ASF dual-hosted git repository.
tzulitai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-statefun.git
The following commit(s) were added to refs/heads/master by this push:
new f9968cd [FLINK-19176] Add pluggable statefun payload serializer
f9968cd is described below
commit f9968cdcdd6947127ceb5af0dc5389469273d99f
Author: Galen Warren <ga...@users.noreply.github.com>
AuthorDate: Mon Sep 14 23:07:02 2020 -0400
[FLINK-19176] Add pluggable statefun payload serializer
This closes #152.
---
docs/deployment-and-operations/configurations.md | 8 +-
.../flink/core/StatefulFunctionsConfig.java | 29 +++++
.../core/StatefulFunctionsConfigValidator.java | 25 +++++
.../flink/core/StatefulFunctionsUniverse.java | 14 +--
.../core/functions/FunctionGroupOperator.java | 2 +-
.../flink/core/message/MessageFactory.java | 31 +++++-
.../flink/core/message/MessageFactoryKey.java | 44 ++++++++
.../flink/core/message/MessageFactoryType.java | 3 +-
.../flink/core/message/MessageTypeInformation.java | 14 +--
.../flink/core/message/MessageTypeSerializer.java | 71 ++++++++++---
.../flink/statefun/flink/core/spi/Modules.java | 6 +-
.../core/translation/CheckpointToMessage.java | 12 +--
.../flink/core/translation/EmbeddedTranslator.java | 4 +-
.../core/translation/IngressRouterOperator.java | 12 +--
.../translation/StatefulFunctionTranslator.java | 2 +-
.../core/types/StaticallyRegisteredTypes.java | 10 +-
.../flink/core/StatefulFunctionsConfigTest.java | 34 +++++-
.../flink/statefun/flink/core/TestUtils.java | 3 +-
.../flink/core/functions/ReductionsTest.java | 4 +-
.../flink/core/jsonmodule/JsonModuleTest.java | 4 +-
.../flink/core/message/JavaPayloadSerializer.java | 67 ++++++++++++
.../statefun/flink/core/message/MessageTest.java | 21 ++--
.../message/MessageTypeSerializerSnapshotTest.java | 117 +++++++++++++++++++++
.../core/message/MessageTypeSerializerTest.java | 7 +-
.../operator/FunctionsStateBootstrapOperator.java | 4 +-
25 files changed, 474 insertions(+), 74 deletions(-)
diff --git a/docs/deployment-and-operations/configurations.md b/docs/deployment-and-operations/configurations.md
index d09ed25..46b16d3 100644
--- a/docs/deployment-and-operations/configurations.md
+++ b/docs/deployment-and-operations/configurations.md
@@ -47,7 +47,13 @@ These may be set through your job's ``flink-conf.yaml``.
<td><h5>statefun.message.serializer</h5></td>
<td style="word-wrap: break-word;">WITH_PROTOBUF_PAYLOADS</td>
<td>Message Serializer</td>
- <td>The serializer to use for on the wire messages. Options are WITH_PROTOBUF_PAYLOADS, WITH_KRYO_PAYLOADS, WITH_RAW_PAYLOADS.</td>
+ <td>The serializer to use for on the wire messages. Options are WITH_PROTOBUF_PAYLOADS, WITH_KRYO_PAYLOADS, WITH_RAW_PAYLOADS, WITH_CUSTOM_PAYLOADS.</td>
+ </tr>
+ <tr>
+ <td><h5>statefun.message.custom-payload-serializer-class</h5></td>
+ <td style="word-wrap: break-word;">(none)</td>
+ <td>String</td>
+ <td>The custom payload serializer class to use with the WITH_CUSTOM_PAYLOADS serializer, which must implement MessagePayloadSerializer.</td>
</tr>
<tr>
<td><h5>statefun.flink-job-name</h5></td>
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfig.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfig.java
index 4982571..a38d9a5 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfig.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfig.java
@@ -30,6 +30,7 @@ import org.apache.flink.configuration.ConfigOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.configuration.description.Description;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
import org.apache.flink.statefun.sdk.spi.StatefulFunctionModule;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -72,6 +73,13 @@ public class StatefulFunctionsConfig implements Serializable {
.defaultValue(MessageFactoryType.WITH_PROTOBUF_PAYLOADS)
.withDescription("The serializer to use for on the wire messages.");
+ public static final ConfigOption<String> USER_MESSAGE_CUSTOM_PAYLOAD_SERIALIZER_CLASS =
+ ConfigOptions.key("statefun.message.custom-payload-serializer-class")
+ .stringType()
+ .noDefaultValue()
+ .withDescription(
+ "The custom payload serializer class to use with the WITH_CUSTOM_PAYLOADS serializer, which must implement MessagePayloadSerializer.");
+
public static final ConfigOption<String> FLINK_JOB_NAME =
ConfigOptions.key("statefun.flink-job-name")
.stringType()
@@ -114,6 +122,8 @@ public class StatefulFunctionsConfig implements Serializable {
private MessageFactoryType factoryType;
+ private String customPayloadSerializerClassName;
+
private String flinkJobName;
private byte[] universeInitializerClassBytes;
@@ -133,6 +143,8 @@ public class StatefulFunctionsConfig implements Serializable {
*/
private StatefulFunctionsConfig(Configuration configuration) {
this.factoryType = configuration.get(USER_MESSAGE_SERIALIZER);
+ this.customPayloadSerializerClassName =
+ configuration.get(USER_MESSAGE_CUSTOM_PAYLOAD_SERIALIZER_CLASS);
this.flinkJobName = configuration.get(FLINK_JOB_NAME);
this.feedbackBufferSize = configuration.get(TOTAL_MEMORY_USED_FOR_FEEDBACK_CHECKPOINTING);
this.maxAsyncOperationsPerTask = configuration.get(ASYNC_MAX_OPERATIONS_PER_TASK);
@@ -152,11 +164,28 @@ public class StatefulFunctionsConfig implements Serializable {
return factoryType;
}
+ /**
+ * Returns the custom payload serializer class name, used when the factory type is WITH_CUSTOM_PAYLOADS.
+ */
+ public String getCustomPayloadSerializerClassName() {
+ return customPayloadSerializerClassName;
+ }
+
+ /** Returns the factory key. */
+ public MessageFactoryKey getFactoryKey() {
+ return MessageFactoryKey.forType(this.factoryType, this.customPayloadSerializerClassName);
+ }
+
/** Sets the factory type used to serialize messages. */
public void setFactoryType(MessageFactoryType factoryType) {
this.factoryType = Objects.requireNonNull(factoryType);
}
+ /** Sets the custom payload serializer class name. */
+ public void setCustomPayloadSerializerClassName(String customPayloadSerializerClassName) {
+ this.customPayloadSerializerClassName = customPayloadSerializerClassName;
+ }
+
/** Returns the Flink job name that appears in the Web UI. */
public String getFlinkJobName() {
return flinkJobName;
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfigValidator.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfigValidator.java
index c4f658c..2f31d0d 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfigValidator.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfigValidator.java
@@ -27,6 +27,8 @@ import java.util.Set;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.CoreOptions;
import org.apache.flink.statefun.flink.core.exceptions.StatefulFunctionsInvalidConfigException;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
+import org.apache.flink.util.StringUtils;
public final class StatefulFunctionsConfigValidator {
@@ -40,6 +42,7 @@ public final class StatefulFunctionsConfigValidator {
static void validate(Configuration configuration) {
validateParentFirstClassloaderPatterns(configuration);
+ validateCustomPayloadSerializerClassName(configuration);
}
private static void validateParentFirstClassloaderPatterns(Configuration configuration) {
@@ -61,4 +64,26 @@ public final class StatefulFunctionsConfigValidator {
}
return parentFirstClassloaderPatterns;
}
+
+ private static void validateCustomPayloadSerializerClassName(Configuration configuration) {
+
+ MessageFactoryType factoryType =
+ configuration.get(StatefulFunctionsConfig.USER_MESSAGE_SERIALIZER);
+ String customPayloadSerializerClassName =
+ configuration.get(StatefulFunctionsConfig.USER_MESSAGE_CUSTOM_PAYLOAD_SERIALIZER_CLASS);
+
+ if (factoryType == MessageFactoryType.WITH_CUSTOM_PAYLOADS) {
+ if (StringUtils.isNullOrWhitespaceOnly(customPayloadSerializerClassName)) {
+ throw new StatefulFunctionsInvalidConfigException(
+ StatefulFunctionsConfig.USER_MESSAGE_CUSTOM_PAYLOAD_SERIALIZER_CLASS,
+ "custom payload serializer class must be supplied with WITH_CUSTOM_PAYLOADS serializer");
+ }
+ } else {
+ if (customPayloadSerializerClassName != null) {
+ throw new StatefulFunctionsInvalidConfigException(
+ StatefulFunctionsConfig.USER_MESSAGE_CUSTOM_PAYLOAD_SERIALIZER_CLASS,
+ "custom payload serializer class may only be supplied with WITH_CUSTOM_PAYLOADS serializer");
+ }
+ }
+ }
}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsUniverse.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsUniverse.java
index 3b48b3e..bbf6661 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsUniverse.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsUniverse.java
@@ -23,7 +23,7 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import javax.annotation.Nullable;
-import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
import org.apache.flink.statefun.flink.core.types.StaticallyRegisteredTypes;
import org.apache.flink.statefun.flink.io.spi.FlinkIoModule;
import org.apache.flink.statefun.flink.io.spi.SinkProvider;
@@ -50,11 +50,11 @@ public final class StatefulFunctionsUniverse
private final Map<EgressType, SinkProvider> sinks = new HashMap<>();
private final StaticallyRegisteredTypes types;
- private final MessageFactoryType messageFactoryType;
+ private final MessageFactoryKey messageFactoryKey;
- public StatefulFunctionsUniverse(MessageFactoryType messageFactoryType) {
- this.messageFactoryType = messageFactoryType;
- this.types = new StaticallyRegisteredTypes(messageFactoryType);
+ public StatefulFunctionsUniverse(MessageFactoryKey messageFactoryKey) {
+ this.messageFactoryKey = messageFactoryKey;
+ this.types = new StaticallyRegisteredTypes(messageFactoryKey);
}
@Override
@@ -138,7 +138,7 @@ public final class StatefulFunctionsUniverse
String.format("A binding for the key %s was previously defined.", key));
}
- public MessageFactoryType messageFactoryType() {
- return messageFactoryType;
+ public MessageFactoryKey messageFactoryKey() {
+ return messageFactoryKey;
}
}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java
index 381a672..79e7a53 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java
@@ -133,7 +133,7 @@ public class FunctionGroupOperator extends AbstractStreamOperator<Message>
delayedMessagesBufferState(delayedMessageStateDescriptor),
sideOutputs,
output,
- MessageFactory.forType(statefulFunctionsUniverse.messageFactoryType()),
+ MessageFactory.forKey(statefulFunctionsUniverse.messageFactoryKey()),
new MailboxExecutorFacade(mailboxExecutor, "Stateful Functions Mailbox"),
getRuntimeContext().getMetricGroup().addGroup("functions"),
asyncOperationState);
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java
index ee2b643..e4d2d3a 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java
@@ -18,6 +18,7 @@
package org.apache.flink.statefun.flink.core.message;
import java.io.IOException;
+import java.lang.reflect.Constructor;
import java.util.Objects;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
@@ -29,8 +30,8 @@ import org.apache.flink.statefun.sdk.Address;
public final class MessageFactory {
- public static MessageFactory forType(MessageFactoryType type) {
- return new MessageFactory(forPayloadType(type));
+ public static MessageFactory forKey(MessageFactoryKey key) {
+ return new MessageFactory(forPayloadKey(key));
}
private final ProtobufSerializer<Envelope> envelopeSerializer;
@@ -93,8 +94,8 @@ public final class MessageFactory {
return Envelope.newBuilder().setCheckpoint(checkpoint).build();
}
- private static MessagePayloadSerializer forPayloadType(MessageFactoryType type) {
- switch (type) {
+ private static MessagePayloadSerializer forPayloadKey(MessageFactoryKey key) {
+ switch (key.getType()) {
case WITH_KRYO_PAYLOADS:
return new MessagePayloadSerializerKryo();
case WITH_PROTOBUF_PAYLOADS:
@@ -103,8 +104,28 @@ public final class MessageFactory {
return new MessagePayloadSerializerRaw();
case WITH_PROTOBUF_PAYLOADS_MULTILANG:
return new MessagePayloadSerializerMultiLanguage();
+ case WITH_CUSTOM_PAYLOADS:
+ String className =
+ key.getCustomPayloadSerializerClassName()
+ .orElseThrow(
+ () ->
+ new UnsupportedOperationException(
+ "WITH_CUSTOM_PAYLOADS requires custom payload serializer class name to be specified in MessageFactoryKey"));
+ return forCustomPayloadSerializer(className);
default:
- throw new IllegalArgumentException("unknown serialization method " + type);
+ throw new IllegalArgumentException("unknown serialization method " + key.getType());
+ }
+ }
+
+ private static MessagePayloadSerializer forCustomPayloadSerializer(String className) {
+ try {
+ Class<?> clazz =
+ Class.forName(className, true, Thread.currentThread().getContextClassLoader());
+ Constructor<?> constructor = clazz.getConstructor();
+ return (MessagePayloadSerializer) constructor.newInstance();
+ } catch (Throwable ex) {
+ throw new UnsupportedOperationException(
+ String.format("Failed to create custom payload serializer: %s", className), ex);
}
}
}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactoryKey.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactoryKey.java
new file mode 100644
index 0000000..f5db425
--- /dev/null
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactoryKey.java
@@ -0,0 +1,44 @@
+package org.apache.flink.statefun.flink.core.message;
+
+import java.io.Serializable;
+import java.util.Objects;
+import java.util.Optional;
+
+public final class MessageFactoryKey implements Serializable {
+ private static final long serialVersionUID = 1L;
+
+ private final MessageFactoryType type;
+ private final String customPayloadSerializerClassName;
+
+ private MessageFactoryKey(MessageFactoryType type, String customPayloadSerializerClassName) {
+ this.type = Objects.requireNonNull(type);
+ this.customPayloadSerializerClassName = customPayloadSerializerClassName;
+ }
+
+ public static MessageFactoryKey forType(
+ MessageFactoryType type, String customPayloadSerializerClassName) {
+ return new MessageFactoryKey(type, customPayloadSerializerClassName);
+ }
+
+ public MessageFactoryType getType() {
+ return this.type;
+ }
+
+ public Optional<String> getCustomPayloadSerializerClassName() {
+ return Optional.ofNullable(customPayloadSerializerClassName);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ MessageFactoryKey that = (MessageFactoryKey) o;
+ return type == that.type
+ && Objects.equals(customPayloadSerializerClassName, that.customPayloadSerializerClassName);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(type, customPayloadSerializerClassName);
+ }
+}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactoryType.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactoryType.java
index 17e5a8b..404e5ed 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactoryType.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactoryType.java
@@ -21,5 +21,6 @@ public enum MessageFactoryType {
WITH_KRYO_PAYLOADS,
WITH_PROTOBUF_PAYLOADS,
WITH_RAW_PAYLOADS,
- WITH_PROTOBUF_PAYLOADS_MULTILANG
+ WITH_PROTOBUF_PAYLOADS_MULTILANG,
+ WITH_CUSTOM_PAYLOADS,
}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageTypeInformation.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageTypeInformation.java
index 277b04d..90e8dd0 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageTypeInformation.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageTypeInformation.java
@@ -24,12 +24,12 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
public class MessageTypeInformation extends TypeInformation<Message> {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 2L;
- private final MessageFactoryType messageFactoryType;
+ private final MessageFactoryKey messageFactoryKey;
- public MessageTypeInformation(MessageFactoryType messageFactoryType) {
- this.messageFactoryType = Objects.requireNonNull(messageFactoryType);
+ public MessageTypeInformation(MessageFactoryKey messageFactoryKey) {
+ this.messageFactoryKey = Objects.requireNonNull(messageFactoryKey);
}
@Override
@@ -64,12 +64,14 @@ public class MessageTypeInformation extends TypeInformation<Message> {
@Override
public TypeSerializer<Message> createSerializer(ExecutionConfig executionConfig) {
- return new MessageTypeSerializer(messageFactoryType);
+ return new MessageTypeSerializer(messageFactoryKey);
}
@Override
public String toString() {
- return "MessageTypeInformation(" + messageFactoryType + ")";
+ return String.format(
+ "MessageTypeInformation(%s: %s)",
+ messageFactoryKey.getType(), messageFactoryKey.getCustomPayloadSerializerClassName());
}
@Override
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializer.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializer.java
index 53c3265..6e78a5d 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializer.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializer.java
@@ -19,6 +19,7 @@ package org.apache.flink.statefun.flink.core.message;
import java.io.IOException;
import java.util.Objects;
+import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility;
import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
@@ -27,16 +28,16 @@ import org.apache.flink.core.memory.DataOutputView;
public final class MessageTypeSerializer extends TypeSerializer<Message> {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 2L;
// -- configuration --
- private final MessageFactoryType messageFactoryType;
+ private final MessageFactoryKey messageFactoryKey;
// -- runtime --
private transient MessageFactory factory;
- MessageTypeSerializer(MessageFactoryType messageFactoryType) {
- this.messageFactoryType = Objects.requireNonNull(messageFactoryType);
+ MessageTypeSerializer(MessageFactoryKey messageFactoryKey) {
+ this.messageFactoryKey = Objects.requireNonNull(messageFactoryKey);
}
@Override
@@ -46,7 +47,7 @@ public final class MessageTypeSerializer extends TypeSerializer<Message> {
@Override
public TypeSerializer<Message> duplicate() {
- return new MessageTypeSerializer(messageFactoryType);
+ return new MessageTypeSerializer(messageFactoryKey);
}
@Override
@@ -101,45 +102,67 @@ public final class MessageTypeSerializer extends TypeSerializer<Message> {
@Override
public TypeSerializerSnapshot<Message> snapshotConfiguration() {
- return new Snapshot(messageFactoryType);
+ return new Snapshot(messageFactoryKey);
}
private MessageFactory factory() {
if (factory == null) {
- factory = MessageFactory.forType(messageFactoryType);
+ factory = MessageFactory.forKey(messageFactoryKey);
}
return factory;
}
public static final class Snapshot implements TypeSerializerSnapshot<Message> {
- private MessageFactoryType messageFactoryType;
+ private MessageFactoryKey messageFactoryKey;
@SuppressWarnings("unused")
public Snapshot() {}
- Snapshot(MessageFactoryType messageFactoryType) {
- this.messageFactoryType = messageFactoryType;
+ Snapshot(MessageFactoryKey messageFactoryKey) {
+ this.messageFactoryKey = messageFactoryKey;
+ }
+
+ @VisibleForTesting
+ MessageFactoryKey getMessageFactoryKey() {
+ return messageFactoryKey;
}
@Override
public int getCurrentVersion() {
- return 1;
+ return 2;
}
@Override
public void writeSnapshot(DataOutputView dataOutputView) throws IOException {
- dataOutputView.writeUTF(messageFactoryType.name());
+
+ // version 1
+ dataOutputView.writeUTF(messageFactoryKey.getType().name());
+
+ // added in version 2
+ writeNullableString(
+ messageFactoryKey.getCustomPayloadSerializerClassName().orElse(null), dataOutputView);
}
@Override
public void readSnapshot(int version, DataInputView dataInputView, ClassLoader classLoader)
throws IOException {
- messageFactoryType = MessageFactoryType.valueOf(dataInputView.readUTF());
+
+ // read values and assign defaults appropriate for version 1
+ MessageFactoryType messageFactoryType = MessageFactoryType.valueOf(dataInputView.readUTF());
+ String customPayloadSerializerClassName = null;
+
+ // if at least version 2, read in the custom payload serializer class name
+ if (version >= 2) {
+ customPayloadSerializerClassName = readNullableString(dataInputView);
+ }
+
+ this.messageFactoryKey =
+ MessageFactoryKey.forType(messageFactoryType, customPayloadSerializerClassName);
}
@Override
public TypeSerializer<Message> restoreSerializer() {
- return new MessageTypeSerializer(messageFactoryType);
+ return new MessageTypeSerializer(messageFactoryKey);
}
@Override
@@ -149,10 +172,28 @@ public final class MessageTypeSerializer extends TypeSerializer<Message> {
return TypeSerializerSchemaCompatibility.incompatible();
}
MessageTypeSerializer casted = (MessageTypeSerializer) typeSerializer;
- if (casted.messageFactoryType == messageFactoryType) {
+ if (casted.messageFactoryKey.equals(messageFactoryKey)) {
return TypeSerializerSchemaCompatibility.compatibleAsIs();
}
return TypeSerializerSchemaCompatibility.incompatible();
}
+
+ private static void writeNullableString(String value, DataOutputView out) throws IOException {
+ if (value != null) {
+ out.writeBoolean(true);
+ out.writeUTF(value);
+ } else {
+ out.writeBoolean(false);
+ }
+ }
+
+ private static String readNullableString(DataInputView in) throws IOException {
+ boolean isPresent = in.readBoolean();
+ if (isPresent) {
+ return in.readUTF();
+ } else {
+ return null;
+ }
+ }
}
}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/spi/Modules.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/spi/Modules.java
index 1f01986..6ca5d24 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/spi/Modules.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/spi/Modules.java
@@ -22,7 +22,7 @@ import org.apache.flink.statefun.flink.common.SetContextClassLoader;
import org.apache.flink.statefun.flink.core.StatefulFunctionsConfig;
import org.apache.flink.statefun.flink.core.StatefulFunctionsUniverse;
import org.apache.flink.statefun.flink.core.jsonmodule.JsonServiceLoader;
-import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
import org.apache.flink.statefun.flink.io.spi.FlinkIoModule;
import org.apache.flink.statefun.sdk.spi.StatefulFunctionModule;
@@ -54,9 +54,9 @@ public final class Modules {
public StatefulFunctionsUniverse createStatefulFunctionsUniverse(
StatefulFunctionsConfig configuration) {
- MessageFactoryType factoryType = configuration.getFactoryType();
+ MessageFactoryKey factoryKey = configuration.getFactoryKey();
- StatefulFunctionsUniverse universe = new StatefulFunctionsUniverse(factoryType);
+ StatefulFunctionsUniverse universe = new StatefulFunctionsUniverse(factoryKey);
final Map<String, String> globalConfiguration = configuration.getGlobalConfigurations();
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/CheckpointToMessage.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/CheckpointToMessage.java
index 29eb07c..3dd2cc3 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/CheckpointToMessage.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/CheckpointToMessage.java
@@ -21,17 +21,17 @@ import java.io.Serializable;
import java.util.function.LongFunction;
import org.apache.flink.statefun.flink.core.message.Message;
import org.apache.flink.statefun.flink.core.message.MessageFactory;
-import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
final class CheckpointToMessage implements Serializable, LongFunction<Message> {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 2L;
- private final MessageFactoryType messageFactoryType;
+ private final MessageFactoryKey messageFactoryKey;
private transient MessageFactory factory;
- CheckpointToMessage(MessageFactoryType messageFactoryType) {
- this.messageFactoryType = messageFactoryType;
+ CheckpointToMessage(MessageFactoryKey messageFactoryKey) {
+ this.messageFactoryKey = messageFactoryKey;
}
@Override
@@ -41,7 +41,7 @@ final class CheckpointToMessage implements Serializable, LongFunction<Message> {
private MessageFactory factory() {
if (factory == null) {
- factory = MessageFactory.forType(messageFactoryType);
+ factory = MessageFactory.forKey(messageFactoryKey);
}
return factory;
}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/EmbeddedTranslator.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/EmbeddedTranslator.java
index 01df738..2e09158 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/EmbeddedTranslator.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/EmbeddedTranslator.java
@@ -51,7 +51,7 @@ public class EmbeddedTranslator {
configuration.setProvider(new EmbeddedUniverseProvider<>(functions));
- StaticallyRegisteredTypes types = new StaticallyRegisteredTypes(configuration.getFactoryType());
+ StaticallyRegisteredTypes types = new StaticallyRegisteredTypes(configuration.getFactoryKey());
Sources sources = Sources.create(types, ingresses);
Sinks sinks = Sinks.create(types, egressesIds);
@@ -75,7 +75,7 @@ public class EmbeddedTranslator {
@Override
public StatefulFunctionsUniverse get(
ClassLoader classLoader, StatefulFunctionsConfig configuration) {
- StatefulFunctionsUniverse u = new StatefulFunctionsUniverse(configuration.getFactoryType());
+ StatefulFunctionsUniverse u = new StatefulFunctionsUniverse(configuration.getFactoryKey());
functions.forEach(u::bindFunctionProvider);
return u;
}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/IngressRouterOperator.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/IngressRouterOperator.java
index b2de476..dc002bc 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/IngressRouterOperator.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/IngressRouterOperator.java
@@ -27,6 +27,7 @@ import org.apache.flink.statefun.flink.core.StatefulFunctionsUniverse;
import org.apache.flink.statefun.flink.core.StatefulFunctionsUniverses;
import org.apache.flink.statefun.flink.core.message.Message;
import org.apache.flink.statefun.flink.core.message.MessageFactory;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
import org.apache.flink.statefun.sdk.Address;
import org.apache.flink.statefun.sdk.io.IngressIdentifier;
@@ -67,9 +68,9 @@ public final class IngressRouterOperator<T> extends AbstractStreamOperator<Messa
StatefulFunctionsUniverses.get(
Thread.currentThread().getContextClassLoader(), configuration);
- LOG.info("Using message factory type " + universe.messageFactoryType());
+ LOG.info("Using message factory key " + universe.messageFactoryKey());
- this.downstream = new DownstreamCollector<>(universe.messageFactoryType(), output);
+ this.downstream = new DownstreamCollector<>(universe.messageFactoryKey(), output);
this.routers = loadRoutersAttachedToIngress(id, universe.routers());
}
@@ -98,12 +99,11 @@ public final class IngressRouterOperator<T> extends AbstractStreamOperator<Messa
private final StreamRecord<Message> reuse = new StreamRecord<>(null);
private final Output<StreamRecord<Message>> output;
- DownstreamCollector(
- MessageFactoryType messageFactoryType, Output<StreamRecord<Message>> output) {
- this.factory = MessageFactory.forType(messageFactoryType);
+ DownstreamCollector(MessageFactoryKey messageFactoryKey, Output<StreamRecord<Message>> output) {
+ this.factory = MessageFactory.forKey(messageFactoryKey);
this.output = Objects.requireNonNull(output);
this.multiLanguagePayloads =
- messageFactoryType == MessageFactoryType.WITH_PROTOBUF_PAYLOADS_MULTILANG;
+ messageFactoryKey.getType() == MessageFactoryType.WITH_PROTOBUF_PAYLOADS_MULTILANG;
}
@Override
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/StatefulFunctionTranslator.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/StatefulFunctionTranslator.java
index 48b1bca..5ea77d3 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/StatefulFunctionTranslator.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/StatefulFunctionTranslator.java
@@ -91,7 +91,7 @@ final class StatefulFunctionTranslator {
private SingleOutputStreamOperator<Void> feedbackOperator(
SingleOutputStreamOperator<Message> functionOut) {
- LongFunction<Message> toMessage = new CheckpointToMessage(configuration.getFactoryType());
+ LongFunction<Message> toMessage = new CheckpointToMessage(configuration.getFactoryKey());
FeedbackSinkOperator<Message> sinkOperator = new FeedbackSinkOperator<>(feedbackKey, toMessage);
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/types/StaticallyRegisteredTypes.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/types/StaticallyRegisteredTypes.java
index dd2c743..2b976c7 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/types/StaticallyRegisteredTypes.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/types/StaticallyRegisteredTypes.java
@@ -24,7 +24,7 @@ import javax.annotation.Nullable;
import javax.annotation.concurrent.NotThreadSafe;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.statefun.flink.common.protobuf.ProtobufTypeInformation;
-import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
import org.apache.flink.statefun.flink.core.message.MessageTypeInformation;
/**
@@ -37,11 +37,11 @@ public final class StaticallyRegisteredTypes {
private final Map<Class<?>, TypeInformation<?>> registeredTypes = new HashMap<>();
- public StaticallyRegisteredTypes(MessageFactoryType messageFactoryType) {
- this.messageFactoryType = messageFactoryType;
+ public StaticallyRegisteredTypes(MessageFactoryKey messageFactoryKey) {
+ this.messageFactoryKey = messageFactoryKey;
}
- private final MessageFactoryType messageFactoryType;
+ private final MessageFactoryKey messageFactoryKey;
public <T> TypeInformation<T> registerType(Class<T> type) {
return (TypeInformation<T>) registeredTypes.computeIfAbsent(type, this::typeInformation);
@@ -62,7 +62,7 @@ public final class StaticallyRegisteredTypes {
return new ProtobufTypeInformation<>(message);
}
if (org.apache.flink.statefun.flink.core.message.Message.class.isAssignableFrom(valueType)) {
- return new MessageTypeInformation(messageFactoryType);
+ return new MessageTypeInformation(messageFactoryKey);
}
// TODO: we may want to restrict the allowed typeInfo here to those that respect schema
// evolution.
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfigTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfigTest.java
index baa704b..dc794bb 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfigTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfigTest.java
@@ -17,9 +17,11 @@
*/
package org.apache.flink.statefun.flink.core;
+import java.util.Optional;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.CoreOptions;
import org.apache.flink.configuration.MemorySize;
+import org.apache.flink.statefun.flink.core.exceptions.StatefulFunctionsInvalidConfigException;
import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.hamcrest.Matchers;
@@ -28,6 +30,8 @@ import org.junit.Test;
public class StatefulFunctionsConfigTest {
+ private final String serializerClassName = "com.sample.Serializer";
+
@Test
public void testSetConfigurations() {
final String testName = "test-name";
@@ -35,7 +39,9 @@ public class StatefulFunctionsConfigTest {
Configuration configuration = new Configuration();
configuration.set(StatefulFunctionsConfig.FLINK_JOB_NAME, testName);
configuration.set(
- StatefulFunctionsConfig.USER_MESSAGE_SERIALIZER, MessageFactoryType.WITH_KRYO_PAYLOADS);
+ StatefulFunctionsConfig.USER_MESSAGE_SERIALIZER, MessageFactoryType.WITH_CUSTOM_PAYLOADS);
+ configuration.set(
+ StatefulFunctionsConfig.USER_MESSAGE_CUSTOM_PAYLOAD_SERIALIZER_CLASS, serializerClassName);
configuration.set(
StatefulFunctionsConfig.TOTAL_MEMORY_USED_FOR_FEEDBACK_CHECKPOINTING,
MemorySize.ofMebiBytes(100));
@@ -51,7 +57,11 @@ public class StatefulFunctionsConfigTest {
StatefulFunctionsConfig.fromFlinkConfiguration(configuration);
Assert.assertEquals(stateFunConfig.getFlinkJobName(), testName);
- Assert.assertEquals(stateFunConfig.getFactoryType(), MessageFactoryType.WITH_KRYO_PAYLOADS);
+ Assert.assertEquals(
+ stateFunConfig.getFactoryKey().getType(), MessageFactoryType.WITH_CUSTOM_PAYLOADS);
+ Assert.assertEquals(
+ stateFunConfig.getFactoryKey().getCustomPayloadSerializerClassName(),
+ Optional.of(serializerClassName));
Assert.assertEquals(stateFunConfig.getFeedbackBufferSize(), MemorySize.ofMebiBytes(100));
Assert.assertEquals(stateFunConfig.getMaxAsyncOperationsPerTask(), 100);
Assert.assertThat(
@@ -60,7 +70,7 @@ public class StatefulFunctionsConfigTest {
stateFunConfig.getGlobalConfigurations(), Matchers.hasEntry("key2", "value2"));
}
- private static Configuration validConfiguration() {
+ private static Configuration baseConfiguration() {
Configuration configuration = new Configuration();
configuration.set(StatefulFunctionsConfig.FLINK_JOB_NAME, "name");
configuration.set(
@@ -75,4 +85,22 @@ public class StatefulFunctionsConfigTest {
configuration.set(ExecutionCheckpointingOptions.MAX_CONCURRENT_CHECKPOINTS, 1);
return configuration;
}
+
+ @Test(expected = StatefulFunctionsInvalidConfigException.class)
+ public void invalidCustomSerializerThrows() {
+ Configuration configuration = baseConfiguration();
+ configuration.set(
+ StatefulFunctionsConfig.USER_MESSAGE_SERIALIZER, MessageFactoryType.WITH_CUSTOM_PAYLOADS);
+ StatefulFunctionsConfigValidator.validate(configuration);
+ }
+
+ @Test(expected = StatefulFunctionsInvalidConfigException.class)
+ public void invalidNonCustomSerializerThrows() {
+ Configuration configuration = baseConfiguration();
+ configuration.set(
+ StatefulFunctionsConfig.USER_MESSAGE_SERIALIZER, MessageFactoryType.WITH_KRYO_PAYLOADS);
+ configuration.set(
+ StatefulFunctionsConfig.USER_MESSAGE_CUSTOM_PAYLOAD_SERIALIZER_CLASS, serializerClassName);
+ StatefulFunctionsConfigValidator.validate(configuration);
+ }
}
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/TestUtils.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/TestUtils.java
index ac3ad88..c49270e 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/TestUtils.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/TestUtils.java
@@ -19,6 +19,7 @@ package org.apache.flink.statefun.flink.core;
import org.apache.flink.statefun.flink.core.generated.EnvelopeAddress;
import org.apache.flink.statefun.flink.core.message.MessageFactory;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
import org.apache.flink.statefun.sdk.Address;
import org.apache.flink.statefun.sdk.FunctionType;
@@ -27,7 +28,7 @@ import org.apache.flink.statefun.sdk.FunctionType;
public class TestUtils {
public static final MessageFactory ENVELOPE_FACTORY =
- MessageFactory.forType(MessageFactoryType.WITH_KRYO_PAYLOADS);
+ MessageFactory.forKey(MessageFactoryKey.forType(MessageFactoryType.WITH_KRYO_PAYLOADS, null));
public static final FunctionType FUNCTION_TYPE = new FunctionType("test", "a");
public static final Address FUNCTION_1_ADDR = new Address(FUNCTION_TYPE, "a-1");
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java
index bdb67e3..d389841 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java
@@ -73,6 +73,7 @@ import org.apache.flink.statefun.flink.core.StatefulFunctionsUniverse;
import org.apache.flink.statefun.flink.core.TestUtils;
import org.apache.flink.statefun.flink.core.backpressure.ThresholdBackPressureValve;
import org.apache.flink.statefun.flink.core.message.Message;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
import org.apache.flink.streaming.api.operators.InternalTimerService;
import org.apache.flink.streaming.api.operators.Output;
@@ -91,7 +92,8 @@ public class ReductionsTest {
Reductions reductions =
Reductions.create(
new ThresholdBackPressureValve(-1),
- new StatefulFunctionsUniverse(MessageFactoryType.WITH_KRYO_PAYLOADS),
+ new StatefulFunctionsUniverse(
+ MessageFactoryKey.forType(MessageFactoryType.WITH_KRYO_PAYLOADS, null)),
new FakeRuntimeContext(),
new FakeKeyedStateBackend(),
new FakeTimerServiceFactory(),
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/jsonmodule/JsonModuleTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/jsonmodule/JsonModuleTest.java
index 44b5bbe..3001684 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/jsonmodule/JsonModuleTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/jsonmodule/JsonModuleTest.java
@@ -28,6 +28,7 @@ import java.util.Collection;
import java.util.Collections;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.flink.statefun.flink.core.StatefulFunctionsUniverse;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
import org.apache.flink.statefun.sdk.FunctionType;
import org.apache.flink.statefun.sdk.io.EgressIdentifier;
@@ -118,6 +119,7 @@ public class JsonModuleTest {
}
private static StatefulFunctionsUniverse emptyUniverse() {
- return new StatefulFunctionsUniverse(MessageFactoryType.WITH_PROTOBUF_PAYLOADS);
+ return new StatefulFunctionsUniverse(
+ MessageFactoryKey.forType(MessageFactoryType.WITH_PROTOBUF_PAYLOADS, null));
}
}
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/JavaPayloadSerializer.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/JavaPayloadSerializer.java
new file mode 100644
index 0000000..bb686c3
--- /dev/null
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/JavaPayloadSerializer.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.statefun.flink.core.message;
+
+import com.google.protobuf.ByteString;
+import java.io.*;
+import javax.annotation.Nonnull;
+import org.apache.flink.statefun.flink.core.generated.Payload;
+
+// A payload serializer that uses standard Java serialization; used only for testing custom payload
+// serialization. Note that deserialize() does not use the targetClassLoader argument.
+public class JavaPayloadSerializer implements MessagePayloadSerializer {
+
+ @Override
+ public Payload serialize(@Nonnull Object payloadObject) {
+ try {
+ String className = payloadObject.getClass().getName();
+ try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) {
+ try (ObjectOutputStream out = new ObjectOutputStream(bos)) {
+ out.writeObject(payloadObject);
+ out.flush();
+ byte[] bytes = bos.toByteArray();
+ return Payload.newBuilder()
+ .setClassName(className)
+ .setPayloadBytes(ByteString.copyFrom(bytes))
+ .build();
+ }
+ }
+ } catch (Throwable ex) {
+ throw new RuntimeException(ex);
+ }
+ }
+
+ @Override
+ public Object deserialize(@Nonnull ClassLoader targetClassLoader, @Nonnull Payload payload) {
+ try {
+ try (ByteArrayInputStream bis =
+ new ByteArrayInputStream(payload.getPayloadBytes().toByteArray())) {
+ try (ObjectInput in = new ObjectInputStream(bis)) {
+ return in.readObject();
+ }
+ }
+ } catch (Throwable ex) {
+ throw new RuntimeException(ex);
+ }
+ }
+
+ @Override
+ public Object copy(@Nonnull ClassLoader targetClassLoader, @Nonnull Object what) {
+ return deserialize(targetClassLoader, serialize(what));
+ }
+}
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTest.java
index 04056ad..31a9c35 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTest.java
@@ -34,27 +34,36 @@ import org.junit.runners.Parameterized.Parameters;
@RunWith(Parameterized.class)
public class MessageTest {
private final MessageFactoryType type;
+ private final String customPayloadSerializerClassName;
private final Object payload;
- public MessageTest(MessageFactoryType type, Object payload) {
+ public MessageTest(
+ MessageFactoryType type, String customPayloadSerializerClassName, Object payload) {
this.type = type;
+ this.customPayloadSerializerClassName = customPayloadSerializerClassName;
this.payload = payload;
}
@Parameters(name = "{0}")
public static Iterable<? extends Object[]> data() {
return Arrays.asList(
- new Object[] {MessageFactoryType.WITH_KRYO_PAYLOADS, DUMMY_PAYLOAD},
- new Object[] {MessageFactoryType.WITH_PROTOBUF_PAYLOADS, DUMMY_PAYLOAD},
- new Object[] {MessageFactoryType.WITH_RAW_PAYLOADS, DUMMY_PAYLOAD.toByteArray()},
+ new Object[] {MessageFactoryType.WITH_KRYO_PAYLOADS, null, DUMMY_PAYLOAD},
+ new Object[] {MessageFactoryType.WITH_PROTOBUF_PAYLOADS, null, DUMMY_PAYLOAD},
+ new Object[] {MessageFactoryType.WITH_RAW_PAYLOADS, null, DUMMY_PAYLOAD.toByteArray()},
new Object[] {
- MessageFactoryType.WITH_PROTOBUF_PAYLOADS_MULTILANG, Any.pack(DUMMY_PAYLOAD)
+ MessageFactoryType.WITH_PROTOBUF_PAYLOADS_MULTILANG, null, Any.pack(DUMMY_PAYLOAD)
+ },
+ new Object[] {
+ MessageFactoryType.WITH_CUSTOM_PAYLOADS,
+ "org.apache.flink.statefun.flink.core.message.JavaPayloadSerializer",
+ DUMMY_PAYLOAD
});
}
@Test
public void roundTrip() throws IOException {
- MessageFactory factory = MessageFactory.forType(type);
+ MessageFactory factory =
+ MessageFactory.forKey(MessageFactoryKey.forType(type, customPayloadSerializerClassName));
Message fromSdk = factory.from(FUNCTION_1_ADDR, FUNCTION_2_ADDR, payload);
DataOutputSerializer out = new DataOutputSerializer(32);
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializerSnapshotTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializerSnapshotTest.java
new file mode 100644
index 0000000..f9c0c78
--- /dev/null
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializerSnapshotTest.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.statefun.flink.core.message;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.Arrays;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class MessageTypeSerializerSnapshotTest {
+
+ private static final String serializerClassName = "com.domain.Serializer";
+
+ private static class SnapshotData {
+ public int version;
+ public byte[] bytes;
+ }
+
+ private static interface SnapshotDataProvider {
+ SnapshotData provide(MessageFactoryKey messageFactoryKey) throws IOException;
+ }
+
+ private final MessageFactoryKey messageFactoryKey;
+ private final SnapshotDataProvider snapshotDataProvider;
+
+ public MessageTypeSerializerSnapshotTest(
+ MessageFactoryKey messageFactoryKey, SnapshotDataProvider snapshotDataProvider) {
+ this.messageFactoryKey = messageFactoryKey;
+ this.snapshotDataProvider = snapshotDataProvider;
+ }
+
+ @Parameterized.Parameters(name = "{0}")
+ public static Iterable<? extends Object[]> data() throws IOException {
+
+ MessageFactoryKey kryoFactoryKey =
+ MessageFactoryKey.forType(MessageFactoryType.WITH_KRYO_PAYLOADS, null);
+ MessageFactoryKey customFactoryKey =
+ MessageFactoryKey.forType(MessageFactoryType.WITH_CUSTOM_PAYLOADS, serializerClassName);
+
+ // generates snapshot data for V1, without customPayloadSerializerClassName
+ SnapshotDataProvider snapshotDataProviderV1 =
+ messageFactoryKey -> {
+ try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) {
+ DataOutputView dataOutputView = new DataOutputViewStreamWrapper(bos);
+ dataOutputView.writeUTF(messageFactoryKey.getType().name());
+ return new SnapshotData() {
+ {
+ version = 1;
+ bytes = bos.toByteArray();
+ }
+ };
+ }
+ };
+
+ // generates snapshot data for V2, the current version
+ SnapshotDataProvider snapshotDataProviderV2 =
+ messageFactoryKey -> {
+ MessageTypeSerializer.Snapshot snapshot =
+ new MessageTypeSerializer.Snapshot(messageFactoryKey);
+ try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) {
+ DataOutputView dataOutputView = new DataOutputViewStreamWrapper(bos);
+ snapshot.writeSnapshot(dataOutputView);
+ return new SnapshotData() {
+ {
+ version = 2;
+ bytes = bos.toByteArray();
+ }
+ };
+ }
+ };
+
+ return Arrays.asList(
+ new Object[] {kryoFactoryKey, snapshotDataProviderV1},
+ new Object[] {kryoFactoryKey, snapshotDataProviderV2},
+ new Object[] {customFactoryKey, snapshotDataProviderV2});
+ }
+
+ @Test
+ public void roundTrip() throws IOException {
+
+ SnapshotData snapshotData = this.snapshotDataProvider.provide(this.messageFactoryKey);
+ MessageTypeSerializer.Snapshot snapshot =
+ new MessageTypeSerializer.Snapshot(this.messageFactoryKey);
+ ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
+
+ try (ByteArrayInputStream bis = new ByteArrayInputStream(snapshotData.bytes)) {
+ DataInputView dataInputView = new DataInputViewStreamWrapper(bis);
+ snapshot.readSnapshot(snapshotData.version, dataInputView, classLoader);
+ }
+
+ // make sure the deserialized state matches what was used to serialize
+ assert (snapshot.getMessageFactoryKey().equals(this.messageFactoryKey));
+ }
+}
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializerTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializerTest.java
index 80c1945..89c7726 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializerTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializerTest.java
@@ -39,7 +39,9 @@ public class MessageTypeSerializerTest extends SerializerTestBase<Message> {
Message b = (Message) o2;
DataOutputSerializer aOut = new DataOutputSerializer(32);
DataOutputSerializer bOut = new DataOutputSerializer(32);
- MessageFactory factory = MessageFactory.forType(MessageFactoryType.WITH_KRYO_PAYLOADS);
+ MessageFactory factory =
+ MessageFactory.forKey(
+ MessageFactoryKey.forType(MessageFactoryType.WITH_KRYO_PAYLOADS, null));
try {
a.writeTo(factory, aOut);
} catch (IOException e) {
@@ -57,7 +59,8 @@ public class MessageTypeSerializerTest extends SerializerTestBase<Message> {
@Override
protected TypeSerializer<Message> createSerializer() {
- return new MessageTypeInformation(MessageFactoryType.WITH_KRYO_PAYLOADS)
+ return new MessageTypeInformation(
+ MessageFactoryKey.forType(MessageFactoryType.WITH_KRYO_PAYLOADS, null))
.createSerializer(new ExecutionConfig());
}
diff --git a/statefun-flink/statefun-flink-state-processor/src/main/java/org/apache/flink/statefun/flink/state/processor/operator/FunctionsStateBootstrapOperator.java b/statefun-flink/statefun-flink-state-processor/src/main/java/org/apache/flink/statefun/flink/state/processor/operator/FunctionsStateBootstrapOperator.java
index 09e6b59..ec33811 100644
--- a/statefun-flink/statefun-flink-state-processor/src/main/java/org/apache/flink/statefun/flink/state/processor/operator/FunctionsStateBootstrapOperator.java
+++ b/statefun-flink/statefun-flink-state-processor/src/main/java/org/apache/flink/statefun/flink/state/processor/operator/FunctionsStateBootstrapOperator.java
@@ -24,6 +24,7 @@ import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.state.api.output.SnapshotUtils;
import org.apache.flink.state.api.output.TaggedOperatorSubtaskState;
import org.apache.flink.statefun.flink.core.functions.FunctionGroupOperator;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
import org.apache.flink.statefun.flink.core.state.FlinkState;
import org.apache.flink.statefun.flink.core.state.State;
@@ -95,6 +96,7 @@ public final class FunctionsStateBootstrapOperator
runtimeContext,
keyedStateBackend,
new DynamicallyRegisteredTypes(
- new StaticallyRegisteredTypes(MessageFactoryType.WITH_RAW_PAYLOADS)));
+ new StaticallyRegisteredTypes(
+ MessageFactoryKey.forType(MessageFactoryType.WITH_RAW_PAYLOADS, null))));
}
}