You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tz...@apache.org on 2020/09/24 05:59:42 UTC

[flink-statefun] branch master updated: [FLINK-19176] Add pluggable statefun payload serializer

This is an automated email from the ASF dual-hosted git repository.

tzulitai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-statefun.git


The following commit(s) were added to refs/heads/master by this push:
     new f9968cd  [FLINK-19176] Add pluggable statefun payload serializer
f9968cd is described below

commit f9968cdcdd6947127ceb5af0dc5389469273d99f
Author: Galen Warren <ga...@users.noreply.github.com>
AuthorDate: Mon Sep 14 23:07:02 2020 -0400

    [FLINK-19176] Add pluggable statefun payload serializer
    
    This closes #152.
---
 docs/deployment-and-operations/configurations.md   |   8 +-
 .../flink/core/StatefulFunctionsConfig.java        |  29 +++++
 .../core/StatefulFunctionsConfigValidator.java     |  25 +++++
 .../flink/core/StatefulFunctionsUniverse.java      |  14 +--
 .../core/functions/FunctionGroupOperator.java      |   2 +-
 .../flink/core/message/MessageFactory.java         |  31 +++++-
 .../flink/core/message/MessageFactoryKey.java      |  44 ++++++++
 .../flink/core/message/MessageFactoryType.java     |   3 +-
 .../flink/core/message/MessageTypeInformation.java |  14 +--
 .../flink/core/message/MessageTypeSerializer.java  |  71 ++++++++++---
 .../flink/statefun/flink/core/spi/Modules.java     |   6 +-
 .../core/translation/CheckpointToMessage.java      |  12 +--
 .../flink/core/translation/EmbeddedTranslator.java |   4 +-
 .../core/translation/IngressRouterOperator.java    |  12 +--
 .../translation/StatefulFunctionTranslator.java    |   2 +-
 .../core/types/StaticallyRegisteredTypes.java      |  10 +-
 .../flink/core/StatefulFunctionsConfigTest.java    |  34 +++++-
 .../flink/statefun/flink/core/TestUtils.java       |   3 +-
 .../flink/core/functions/ReductionsTest.java       |   4 +-
 .../flink/core/jsonmodule/JsonModuleTest.java      |   4 +-
 .../flink/core/message/JavaPayloadSerializer.java  |  67 ++++++++++++
 .../statefun/flink/core/message/MessageTest.java   |  21 ++--
 .../message/MessageTypeSerializerSnapshotTest.java | 117 +++++++++++++++++++++
 .../core/message/MessageTypeSerializerTest.java    |   7 +-
 .../operator/FunctionsStateBootstrapOperator.java  |   4 +-
 25 files changed, 474 insertions(+), 74 deletions(-)

diff --git a/docs/deployment-and-operations/configurations.md b/docs/deployment-and-operations/configurations.md
index d09ed25..46b16d3 100644
--- a/docs/deployment-and-operations/configurations.md
+++ b/docs/deployment-and-operations/configurations.md
@@ -47,7 +47,13 @@ These may be set through your job's ``flink-conf.yaml``.
             <td><h5>statefun.message.serializer</h5></td>
             <td style="word-wrap: break-word;">WITH_PROTOBUF_PAYLOADS</td>
             <td>Message Serializer</td>
-            <td>The serializer to use for on the wire messages. Options are WITH_PROTOBUF_PAYLOADS, WITH_KRYO_PAYLOADS, WITH_RAW_PAYLOADS.</td>
+            <td>The serializer to use for on-the-wire messages. Options are WITH_PROTOBUF_PAYLOADS, WITH_KRYO_PAYLOADS, WITH_RAW_PAYLOADS, WITH_CUSTOM_PAYLOADS.</td>
+        </tr>
+		<tr>
+            <td><h5>statefun.message.custom-payload-serializer-class</h5></td>
+            <td style="word-wrap: break-word;">(none)</td>
+            <td>String</td>
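+            <td>The fully-qualified name of the custom payload serializer class to use with the WITH_CUSTOM_PAYLOADS serializer. The class must implement org.apache.flink.statefun.flink.core.message.MessagePayloadSerializer and have a public no-argument constructor.</td>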
         </tr>
 		<tr>
             <td><h5>statefun.flink-job-name</h5></td>
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfig.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfig.java
index 4982571..a38d9a5 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfig.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfig.java
@@ -30,6 +30,7 @@ import org.apache.flink.configuration.ConfigOptions;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.MemorySize;
 import org.apache.flink.configuration.description.Description;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
 import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
 import org.apache.flink.statefun.sdk.spi.StatefulFunctionModule;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -72,6 +73,13 @@ public class StatefulFunctionsConfig implements Serializable {
           .defaultValue(MessageFactoryType.WITH_PROTOBUF_PAYLOADS)
           .withDescription("The serializer to use for on the wire messages.");
 
+  public static final ConfigOption<String> USER_MESSAGE_CUSTOM_PAYLOAD_SERIALIZER_CLASS =
+      ConfigOptions.key("statefun.message.custom-payload-serializer-class")
+          .stringType()
+          .noDefaultValue()
+          .withDescription(
+              "The custom payload serializer class to use with the WITH_CUSTOM_PAYLOADS serializer, which must implement MessagePayloadSerializer.");
+
   public static final ConfigOption<String> FLINK_JOB_NAME =
       ConfigOptions.key("statefun.flink-job-name")
           .stringType()
@@ -114,6 +122,8 @@ public class StatefulFunctionsConfig implements Serializable {
 
   private MessageFactoryType factoryType;
 
+  private String customPayloadSerializerClassName;
+
   private String flinkJobName;
 
   private byte[] universeInitializerClassBytes;
@@ -133,6 +143,8 @@ public class StatefulFunctionsConfig implements Serializable {
    */
   private StatefulFunctionsConfig(Configuration configuration) {
     this.factoryType = configuration.get(USER_MESSAGE_SERIALIZER);
+    this.customPayloadSerializerClassName =
+        configuration.get(USER_MESSAGE_CUSTOM_PAYLOAD_SERIALIZER_CLASS);
     this.flinkJobName = configuration.get(FLINK_JOB_NAME);
     this.feedbackBufferSize = configuration.get(TOTAL_MEMORY_USED_FOR_FEEDBACK_CHECKPOINTING);
     this.maxAsyncOperationsPerTask = configuration.get(ASYNC_MAX_OPERATIONS_PER_TASK);
@@ -152,11 +164,29 @@ public class StatefulFunctionsConfig implements Serializable {
     return factoryType;
   }
 
+  /**
+   * Returns the name of the custom payload serializer class, which is only used with the
+   * WITH_CUSTOM_PAYLOADS factory type; may be {@code null} if none was configured.
+   */
+  public String getCustomPayloadSerializerClassName() {
+    return customPayloadSerializerClassName;
+  }
+
+  /** Returns the key identifying the factory type and, if set, the custom payload serializer. */
+  public MessageFactoryKey getFactoryKey() {
+    return MessageFactoryKey.forType(this.factoryType, this.customPayloadSerializerClassName);
+  }
+
   /** Sets the factory type used to serialize messages. */
   public void setFactoryType(MessageFactoryType factoryType) {
     this.factoryType = Objects.requireNonNull(factoryType);
   }
 
+  /** Sets the custom payload serializer class name; only used with WITH_CUSTOM_PAYLOADS. */
+  public void setCustomPayloadSerializerClassName(String customPayloadSerializerClassName) {
+    this.customPayloadSerializerClassName = customPayloadSerializerClassName;
+  }
+
   /** Returns the Flink job name that appears in the Web UI. */
   public String getFlinkJobName() {
     return flinkJobName;
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfigValidator.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfigValidator.java
index c4f658c..2f31d0d 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfigValidator.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfigValidator.java
@@ -27,6 +27,8 @@ import java.util.Set;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.CoreOptions;
 import org.apache.flink.statefun.flink.core.exceptions.StatefulFunctionsInvalidConfigException;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
+import org.apache.flink.util.StringUtils;
 
 public final class StatefulFunctionsConfigValidator {
 
@@ -40,6 +42,7 @@ public final class StatefulFunctionsConfigValidator {
 
   static void validate(Configuration configuration) {
     validateParentFirstClassloaderPatterns(configuration);
+    validateCustomPayloadSerializerClassName(configuration);
   }
 
   private static void validateParentFirstClassloaderPatterns(Configuration configuration) {
@@ -61,4 +64,26 @@ public final class StatefulFunctionsConfigValidator {
     }
     return parentFirstClassloaderPatterns;
   }
+
+  private static void validateCustomPayloadSerializerClassName(Configuration configuration) {
+
+    MessageFactoryType factoryType =
+        configuration.get(StatefulFunctionsConfig.USER_MESSAGE_SERIALIZER);
+    String customPayloadSerializerClassName =
+        configuration.get(StatefulFunctionsConfig.USER_MESSAGE_CUSTOM_PAYLOAD_SERIALIZER_CLASS);
+
+    if (factoryType == MessageFactoryType.WITH_CUSTOM_PAYLOADS) {
+      if (StringUtils.isNullOrWhitespaceOnly(customPayloadSerializerClassName)) {
+        throw new StatefulFunctionsInvalidConfigException(
+            StatefulFunctionsConfig.USER_MESSAGE_CUSTOM_PAYLOAD_SERIALIZER_CLASS,
+            "custom payload serializer class must be supplied with WITH_CUSTOM_PAYLOADS serializer");
+      }
+    } else {
+      if (customPayloadSerializerClassName != null) {
+        throw new StatefulFunctionsInvalidConfigException(
+            StatefulFunctionsConfig.USER_MESSAGE_CUSTOM_PAYLOAD_SERIALIZER_CLASS,
+            "custom payload serializer class may only be supplied with WITH_CUSTOM_PAYLOADS serializer");
+      }
+    }
+  }
 }
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsUniverse.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsUniverse.java
index 3b48b3e..bbf6661 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsUniverse.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/StatefulFunctionsUniverse.java
@@ -23,7 +23,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import javax.annotation.Nullable;
-import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
 import org.apache.flink.statefun.flink.core.types.StaticallyRegisteredTypes;
 import org.apache.flink.statefun.flink.io.spi.FlinkIoModule;
 import org.apache.flink.statefun.flink.io.spi.SinkProvider;
@@ -50,11 +50,11 @@ public final class StatefulFunctionsUniverse
   private final Map<EgressType, SinkProvider> sinks = new HashMap<>();
 
   private final StaticallyRegisteredTypes types;
-  private final MessageFactoryType messageFactoryType;
+  private final MessageFactoryKey messageFactoryKey;
 
-  public StatefulFunctionsUniverse(MessageFactoryType messageFactoryType) {
-    this.messageFactoryType = messageFactoryType;
-    this.types = new StaticallyRegisteredTypes(messageFactoryType);
+  public StatefulFunctionsUniverse(MessageFactoryKey messageFactoryKey) {
+    this.messageFactoryKey = messageFactoryKey;
+    this.types = new StaticallyRegisteredTypes(messageFactoryKey);
   }
 
   @Override
@@ -138,7 +138,7 @@ public final class StatefulFunctionsUniverse
         String.format("A binding for the key %s was previously defined.", key));
   }
 
-  public MessageFactoryType messageFactoryType() {
-    return messageFactoryType;
+  public MessageFactoryKey messageFactoryKey() {
+    return messageFactoryKey;
   }
 }
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java
index 381a672..79e7a53 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java
@@ -133,7 +133,7 @@ public class FunctionGroupOperator extends AbstractStreamOperator<Message>
             delayedMessagesBufferState(delayedMessageStateDescriptor),
             sideOutputs,
             output,
-            MessageFactory.forType(statefulFunctionsUniverse.messageFactoryType()),
+            MessageFactory.forKey(statefulFunctionsUniverse.messageFactoryKey()),
             new MailboxExecutorFacade(mailboxExecutor, "Stateful Functions Mailbox"),
             getRuntimeContext().getMetricGroup().addGroup("functions"),
             asyncOperationState);
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java
index ee2b643..e4d2d3a 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java
@@ -18,6 +18,7 @@
 package org.apache.flink.statefun.flink.core.message;
 
 import java.io.IOException;
+import java.lang.reflect.Constructor;
 import java.util.Objects;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
@@ -29,8 +30,8 @@ import org.apache.flink.statefun.sdk.Address;
 
 public final class MessageFactory {
 
-  public static MessageFactory forType(MessageFactoryType type) {
-    return new MessageFactory(forPayloadType(type));
+  public static MessageFactory forKey(MessageFactoryKey key) {
+    return new MessageFactory(forPayloadKey(key));
   }
 
   private final ProtobufSerializer<Envelope> envelopeSerializer;
@@ -93,8 +94,8 @@ public final class MessageFactory {
     return Envelope.newBuilder().setCheckpoint(checkpoint).build();
   }
 
-  private static MessagePayloadSerializer forPayloadType(MessageFactoryType type) {
-    switch (type) {
+  private static MessagePayloadSerializer forPayloadKey(MessageFactoryKey key) {
+    switch (key.getType()) {
       case WITH_KRYO_PAYLOADS:
         return new MessagePayloadSerializerKryo();
       case WITH_PROTOBUF_PAYLOADS:
@@ -103,8 +104,37 @@ public final class MessageFactory {
         return new MessagePayloadSerializerRaw();
       case WITH_PROTOBUF_PAYLOADS_MULTILANG:
         return new MessagePayloadSerializerMultiLanguage();
+      case WITH_CUSTOM_PAYLOADS:
+        String className =
+            key.getCustomPayloadSerializerClassName()
+                .orElseThrow(
+                    () ->
+                        new UnsupportedOperationException(
+                            "WITH_CUSTOM_PAYLOADS requires custom payload serializer class name to be specified in MessageFactoryKey"));
+        return forCustomPayloadSerializer(className);
       default:
-        throw new IllegalArgumentException("unknown serialization method " + type);
+        throw new IllegalArgumentException("unknown serialization method " + key.getType());
+    }
+  }
+
+  /**
+   * Reflectively instantiates a custom {@link MessagePayloadSerializer} through its public no-arg
+   * constructor, loading the class with the current thread's context class loader.
+   *
+   * @throws UnsupportedOperationException if the class cannot be loaded, does not implement {@link
+   *     MessagePayloadSerializer}, or cannot be instantiated.
+   */
+  private static MessagePayloadSerializer forCustomPayloadSerializer(String className) {
+    try {
+      Class<? extends MessagePayloadSerializer> clazz =
+          Class.forName(className, true, Thread.currentThread().getContextClassLoader())
+              .asSubclass(MessagePayloadSerializer.class);
+      Constructor<? extends MessagePayloadSerializer> constructor = clazz.getConstructor();
+      return constructor.newInstance();
+    } catch (ReflectiveOperationException | RuntimeException | LinkageError ex) {
+      // Fatal errors other than linkage problems (e.g. OutOfMemoryError) are deliberately not caught.
+      throw new UnsupportedOperationException(
+          String.format("Failed to create custom payload serializer: %s", className), ex);
     }
   }
 }
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactoryKey.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactoryKey.java
new file mode 100644
index 0000000..f5db425
--- /dev/null
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactoryKey.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.statefun.flink.core.message;
+
+import java.io.Serializable;
+import java.util.Objects;
+import java.util.Optional;
+
+/**
+ * Identifies how message payloads are serialized.
+ *
+ * <p>A key consists of a {@link MessageFactoryType} and, for the {@code WITH_CUSTOM_PAYLOADS} type,
+ * the fully-qualified name of the custom payload serializer class. Keys are immutable and are
+ * compared by value.
+ */
+public final class MessageFactoryKey implements Serializable {
+  private static final long serialVersionUID = 1L;
+
+  private final MessageFactoryType type;
+  private final String customPayloadSerializerClassName;
+
+  private MessageFactoryKey(MessageFactoryType type, String customPayloadSerializerClassName) {
+    this.type = Objects.requireNonNull(type, "type");
+    this.customPayloadSerializerClassName = customPayloadSerializerClassName;
+  }
+
+  /**
+   * Creates a key for the given factory type.
+   *
+   * @param type the message factory type; must not be null.
+   * @param customPayloadSerializerClassName the custom payload serializer class name; may be null
+   *     when no custom payload serializer is configured.
+   * @return a new key.
+   */
+  public static MessageFactoryKey forType(
+      MessageFactoryType type, String customPayloadSerializerClassName) {
+    return new MessageFactoryKey(type, customPayloadSerializerClassName);
+  }
+
+  /** Returns the message factory type. */
+  public MessageFactoryType getType() {
+    return this.type;
+  }
+
+  /** Returns the custom payload serializer class name, if one was specified. */
+  public Optional<String> getCustomPayloadSerializerClassName() {
+    return Optional.ofNullable(customPayloadSerializerClassName);
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    MessageFactoryKey that = (MessageFactoryKey) o;
+    return type == that.type
+        && Objects.equals(customPayloadSerializerClassName, that.customPayloadSerializerClassName);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(type, customPayloadSerializerClassName);
+  }
+
+  @Override
+  public String toString() {
+    return "MessageFactoryKey{type="
+        + type
+        + ", customPayloadSerializerClassName="
+        + customPayloadSerializerClassName
+        + '}';
+  }
+}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactoryType.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactoryType.java
index 17e5a8b..404e5ed 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactoryType.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactoryType.java
@@ -21,5 +21,6 @@ public enum MessageFactoryType {
   WITH_KRYO_PAYLOADS,
   WITH_PROTOBUF_PAYLOADS,
   WITH_RAW_PAYLOADS,
-  WITH_PROTOBUF_PAYLOADS_MULTILANG
+  WITH_PROTOBUF_PAYLOADS_MULTILANG,
+  WITH_CUSTOM_PAYLOADS,
 }
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageTypeInformation.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageTypeInformation.java
index 277b04d..90e8dd0 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageTypeInformation.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageTypeInformation.java
@@ -24,12 +24,12 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
 
 public class MessageTypeInformation extends TypeInformation<Message> {
 
-  private static final long serialVersionUID = 1L;
+  private static final long serialVersionUID = 2L;
 
-  private final MessageFactoryType messageFactoryType;
+  private final MessageFactoryKey messageFactoryKey;
 
-  public MessageTypeInformation(MessageFactoryType messageFactoryType) {
-    this.messageFactoryType = Objects.requireNonNull(messageFactoryType);
+  public MessageTypeInformation(MessageFactoryKey messageFactoryKey) {
+    this.messageFactoryKey = Objects.requireNonNull(messageFactoryKey);
   }
 
   @Override
@@ -64,12 +64,17 @@ public class MessageTypeInformation extends TypeInformation<Message> {
 
   @Override
   public TypeSerializer<Message> createSerializer(ExecutionConfig executionConfig) {
-    return new MessageTypeSerializer(messageFactoryType);
+    return new MessageTypeSerializer(messageFactoryKey);
   }
 
   @Override
   public String toString() {
-    return "MessageTypeInformation(" + messageFactoryType + ")";
+    // Renders e.g. "MessageTypeInformation(WITH_CUSTOM_PAYLOADS: com.example.MySerializer)", or
+    // just "MessageTypeInformation(WITH_PROTOBUF_PAYLOADS)" when no class name is configured.
+    return "MessageTypeInformation("
+        + messageFactoryKey.getType()
+        + messageFactoryKey.getCustomPayloadSerializerClassName().map(c -> ": " + c).orElse("")
+        + ")";
   }
 
   @Override
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializer.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializer.java
index 53c3265..6e78a5d 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializer.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializer.java
@@ -19,6 +19,7 @@ package org.apache.flink.statefun.flink.core.message;
 
 import java.io.IOException;
 import java.util.Objects;
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility;
 import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
@@ -27,16 +28,16 @@ import org.apache.flink.core.memory.DataOutputView;
 
 public final class MessageTypeSerializer extends TypeSerializer<Message> {
 
-  private static final long serialVersionUID = 1L;
+  private static final long serialVersionUID = 2L;
 
   // -- configuration --
-  private final MessageFactoryType messageFactoryType;
+  private final MessageFactoryKey messageFactoryKey;
 
   // -- runtime --
   private transient MessageFactory factory;
 
-  MessageTypeSerializer(MessageFactoryType messageFactoryType) {
-    this.messageFactoryType = Objects.requireNonNull(messageFactoryType);
+  MessageTypeSerializer(MessageFactoryKey messageFactoryKey) {
+    this.messageFactoryKey = Objects.requireNonNull(messageFactoryKey);
   }
 
   @Override
@@ -46,7 +47,7 @@ public final class MessageTypeSerializer extends TypeSerializer<Message> {
 
   @Override
   public TypeSerializer<Message> duplicate() {
-    return new MessageTypeSerializer(messageFactoryType);
+    return new MessageTypeSerializer(messageFactoryKey);
   }
 
   @Override
@@ -101,45 +102,67 @@ public final class MessageTypeSerializer extends TypeSerializer<Message> {
 
   @Override
   public TypeSerializerSnapshot<Message> snapshotConfiguration() {
-    return new Snapshot(messageFactoryType);
+    return new Snapshot(messageFactoryKey);
   }
 
   private MessageFactory factory() {
     if (factory == null) {
-      factory = MessageFactory.forType(messageFactoryType);
+      factory = MessageFactory.forKey(messageFactoryKey);
     }
     return factory;
   }
 
   public static final class Snapshot implements TypeSerializerSnapshot<Message> {
-    private MessageFactoryType messageFactoryType;
+    private MessageFactoryKey messageFactoryKey;
 
     @SuppressWarnings("unused")
     public Snapshot() {}
 
-    Snapshot(MessageFactoryType messageFactoryType) {
-      this.messageFactoryType = messageFactoryType;
+    Snapshot(MessageFactoryKey messageFactoryKey) {
+      this.messageFactoryKey = messageFactoryKey;
+    }
+
+    @VisibleForTesting
+    MessageFactoryKey getMessageFactoryKey() {
+      return messageFactoryKey;
     }
 
     @Override
     public int getCurrentVersion() {
-      return 1;
+      return 2;
     }
 
     @Override
     public void writeSnapshot(DataOutputView dataOutputView) throws IOException {
-      dataOutputView.writeUTF(messageFactoryType.name());
+
+      // version 1
+      dataOutputView.writeUTF(messageFactoryKey.getType().name());
+
+      // added in version 2
+      writeNullableString(
+          messageFactoryKey.getCustomPayloadSerializerClassName().orElse(null), dataOutputView);
     }
 
     @Override
     public void readSnapshot(int version, DataInputView dataInputView, ClassLoader classLoader)
         throws IOException {
-      messageFactoryType = MessageFactoryType.valueOf(dataInputView.readUTF());
+
+      // read values and assign defaults appropriate for version 1
+      MessageFactoryType messageFactoryType = MessageFactoryType.valueOf(dataInputView.readUTF());
+      String customPayloadSerializerClassName = null;
+
+      // if at least version 2, read in the custom payload serializer class name
+      if (version >= 2) {
+        customPayloadSerializerClassName = readNullableString(dataInputView);
+      }
+
+      this.messageFactoryKey =
+          MessageFactoryKey.forType(messageFactoryType, customPayloadSerializerClassName);
     }
 
     @Override
     public TypeSerializer<Message> restoreSerializer() {
-      return new MessageTypeSerializer(messageFactoryType);
+      return new MessageTypeSerializer(messageFactoryKey);
     }
 
     @Override
@@ -149,10 +172,28 @@ public final class MessageTypeSerializer extends TypeSerializer<Message> {
         return TypeSerializerSchemaCompatibility.incompatible();
       }
       MessageTypeSerializer casted = (MessageTypeSerializer) typeSerializer;
-      if (casted.messageFactoryType == messageFactoryType) {
+      if (casted.messageFactoryKey.equals(messageFactoryKey)) {
         return TypeSerializerSchemaCompatibility.compatibleAsIs();
       }
       return TypeSerializerSchemaCompatibility.incompatible();
     }
+
+    private static void writeNullableString(String value, DataOutputView out) throws IOException {
+      if (value != null) {
+        out.writeBoolean(true);
+        out.writeUTF(value);
+      } else {
+        out.writeBoolean(false);
+      }
+    }
+
+    private static String readNullableString(DataInputView in) throws IOException {
+      boolean isPresent = in.readBoolean();
+      if (isPresent) {
+        return in.readUTF();
+      } else {
+        return null;
+      }
+    }
   }
 }
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/spi/Modules.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/spi/Modules.java
index 1f01986..6ca5d24 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/spi/Modules.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/spi/Modules.java
@@ -22,7 +22,7 @@ import org.apache.flink.statefun.flink.common.SetContextClassLoader;
 import org.apache.flink.statefun.flink.core.StatefulFunctionsConfig;
 import org.apache.flink.statefun.flink.core.StatefulFunctionsUniverse;
 import org.apache.flink.statefun.flink.core.jsonmodule.JsonServiceLoader;
-import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
 import org.apache.flink.statefun.flink.io.spi.FlinkIoModule;
 import org.apache.flink.statefun.sdk.spi.StatefulFunctionModule;
 
@@ -54,9 +54,9 @@ public final class Modules {
 
   public StatefulFunctionsUniverse createStatefulFunctionsUniverse(
       StatefulFunctionsConfig configuration) {
-    MessageFactoryType factoryType = configuration.getFactoryType();
+    MessageFactoryKey factoryKey = configuration.getFactoryKey();
 
-    StatefulFunctionsUniverse universe = new StatefulFunctionsUniverse(factoryType);
+    StatefulFunctionsUniverse universe = new StatefulFunctionsUniverse(factoryKey);
 
     final Map<String, String> globalConfiguration = configuration.getGlobalConfigurations();
 
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/CheckpointToMessage.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/CheckpointToMessage.java
index 29eb07c..3dd2cc3 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/CheckpointToMessage.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/CheckpointToMessage.java
@@ -21,17 +21,17 @@ import java.io.Serializable;
 import java.util.function.LongFunction;
 import org.apache.flink.statefun.flink.core.message.Message;
 import org.apache.flink.statefun.flink.core.message.MessageFactory;
-import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
 
 final class CheckpointToMessage implements Serializable, LongFunction<Message> {
 
-  private static final long serialVersionUID = 1L;
+  private static final long serialVersionUID = 2L;
 
-  private final MessageFactoryType messageFactoryType;
+  private final MessageFactoryKey messageFactoryKey;
   private transient MessageFactory factory;
 
-  CheckpointToMessage(MessageFactoryType messageFactoryType) {
-    this.messageFactoryType = messageFactoryType;
+  CheckpointToMessage(MessageFactoryKey messageFactoryKey) {
+    this.messageFactoryKey = messageFactoryKey;
   }
 
   @Override
@@ -41,7 +41,7 @@ final class CheckpointToMessage implements Serializable, LongFunction<Message> {
 
   private MessageFactory factory() {
     if (factory == null) {
-      factory = MessageFactory.forType(messageFactoryType);
+      factory = MessageFactory.forKey(messageFactoryKey);
     }
     return factory;
   }
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/EmbeddedTranslator.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/EmbeddedTranslator.java
index 01df738..2e09158 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/EmbeddedTranslator.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/EmbeddedTranslator.java
@@ -51,7 +51,7 @@ public class EmbeddedTranslator {
 
     configuration.setProvider(new EmbeddedUniverseProvider<>(functions));
 
-    StaticallyRegisteredTypes types = new StaticallyRegisteredTypes(configuration.getFactoryType());
+    StaticallyRegisteredTypes types = new StaticallyRegisteredTypes(configuration.getFactoryKey());
     Sources sources = Sources.create(types, ingresses);
     Sinks sinks = Sinks.create(types, egressesIds);
 
@@ -75,7 +75,7 @@ public class EmbeddedTranslator {
     @Override
     public StatefulFunctionsUniverse get(
         ClassLoader classLoader, StatefulFunctionsConfig configuration) {
-      StatefulFunctionsUniverse u = new StatefulFunctionsUniverse(configuration.getFactoryType());
+      StatefulFunctionsUniverse u = new StatefulFunctionsUniverse(configuration.getFactoryKey());
       functions.forEach(u::bindFunctionProvider);
       return u;
     }
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/IngressRouterOperator.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/IngressRouterOperator.java
index b2de476..dc002bc 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/IngressRouterOperator.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/IngressRouterOperator.java
@@ -27,6 +27,7 @@ import org.apache.flink.statefun.flink.core.StatefulFunctionsUniverse;
 import org.apache.flink.statefun.flink.core.StatefulFunctionsUniverses;
 import org.apache.flink.statefun.flink.core.message.Message;
 import org.apache.flink.statefun.flink.core.message.MessageFactory;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
 import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
 import org.apache.flink.statefun.sdk.Address;
 import org.apache.flink.statefun.sdk.io.IngressIdentifier;
@@ -67,9 +68,9 @@ public final class IngressRouterOperator<T> extends AbstractStreamOperator<Messa
         StatefulFunctionsUniverses.get(
             Thread.currentThread().getContextClassLoader(), configuration);
 
-    LOG.info("Using message factory type " + universe.messageFactoryType());
+    LOG.info("Using message factory key " + universe.messageFactoryKey());
 
-    this.downstream = new DownstreamCollector<>(universe.messageFactoryType(), output);
+    this.downstream = new DownstreamCollector<>(universe.messageFactoryKey(), output);
     this.routers = loadRoutersAttachedToIngress(id, universe.routers());
   }
 
@@ -98,12 +99,11 @@ public final class IngressRouterOperator<T> extends AbstractStreamOperator<Messa
     private final StreamRecord<Message> reuse = new StreamRecord<>(null);
     private final Output<StreamRecord<Message>> output;
 
-    DownstreamCollector(
-        MessageFactoryType messageFactoryType, Output<StreamRecord<Message>> output) {
-      this.factory = MessageFactory.forType(messageFactoryType);
+    DownstreamCollector(MessageFactoryKey messageFactoryKey, Output<StreamRecord<Message>> output) {
+      this.factory = MessageFactory.forKey(messageFactoryKey);
       this.output = Objects.requireNonNull(output);
       this.multiLanguagePayloads =
-          messageFactoryType == MessageFactoryType.WITH_PROTOBUF_PAYLOADS_MULTILANG;
+          messageFactoryKey.getType() == MessageFactoryType.WITH_PROTOBUF_PAYLOADS_MULTILANG;
     }
 
     @Override
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/StatefulFunctionTranslator.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/StatefulFunctionTranslator.java
index 48b1bca..5ea77d3 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/StatefulFunctionTranslator.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/translation/StatefulFunctionTranslator.java
@@ -91,7 +91,7 @@ final class StatefulFunctionTranslator {
   private SingleOutputStreamOperator<Void> feedbackOperator(
       SingleOutputStreamOperator<Message> functionOut) {
 
-    LongFunction<Message> toMessage = new CheckpointToMessage(configuration.getFactoryType());
+    LongFunction<Message> toMessage = new CheckpointToMessage(configuration.getFactoryKey());
 
     FeedbackSinkOperator<Message> sinkOperator = new FeedbackSinkOperator<>(feedbackKey, toMessage);
 
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/types/StaticallyRegisteredTypes.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/types/StaticallyRegisteredTypes.java
index dd2c743..2b976c7 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/types/StaticallyRegisteredTypes.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/types/StaticallyRegisteredTypes.java
@@ -24,7 +24,7 @@ import javax.annotation.Nullable;
 import javax.annotation.concurrent.NotThreadSafe;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.statefun.flink.common.protobuf.ProtobufTypeInformation;
-import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
 import org.apache.flink.statefun.flink.core.message.MessageTypeInformation;
 
 /**
@@ -37,11 +37,11 @@ public final class StaticallyRegisteredTypes {
 
   private final Map<Class<?>, TypeInformation<?>> registeredTypes = new HashMap<>();
 
-  public StaticallyRegisteredTypes(MessageFactoryType messageFactoryType) {
-    this.messageFactoryType = messageFactoryType;
+  public StaticallyRegisteredTypes(MessageFactoryKey messageFactoryKey) {
+    this.messageFactoryKey = messageFactoryKey;
   }
 
-  private final MessageFactoryType messageFactoryType;
+  private final MessageFactoryKey messageFactoryKey;
 
   public <T> TypeInformation<T> registerType(Class<T> type) {
     return (TypeInformation<T>) registeredTypes.computeIfAbsent(type, this::typeInformation);
@@ -62,7 +62,7 @@ public final class StaticallyRegisteredTypes {
       return new ProtobufTypeInformation<>(message);
     }
     if (org.apache.flink.statefun.flink.core.message.Message.class.isAssignableFrom(valueType)) {
-      return new MessageTypeInformation(messageFactoryType);
+      return new MessageTypeInformation(messageFactoryKey);
     }
     // TODO: we may want to restrict the allowed typeInfo here to theses that respect shcema
     // evaluation.
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfigTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfigTest.java
index baa704b..dc794bb 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfigTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/StatefulFunctionsConfigTest.java
@@ -17,9 +17,11 @@
  */
 package org.apache.flink.statefun.flink.core;
 
+import java.util.Optional;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.CoreOptions;
 import org.apache.flink.configuration.MemorySize;
+import org.apache.flink.statefun.flink.core.exceptions.StatefulFunctionsInvalidConfigException;
 import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
 import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
 import org.hamcrest.Matchers;
@@ -28,6 +30,8 @@ import org.junit.Test;
 
 public class StatefulFunctionsConfigTest {
 
+  private final String serializerClassName = "com.sample.Serializer";
+
   @Test
   public void testSetConfigurations() {
     final String testName = "test-name";
@@ -35,7 +39,9 @@ public class StatefulFunctionsConfigTest {
     Configuration configuration = new Configuration();
     configuration.set(StatefulFunctionsConfig.FLINK_JOB_NAME, testName);
     configuration.set(
-        StatefulFunctionsConfig.USER_MESSAGE_SERIALIZER, MessageFactoryType.WITH_KRYO_PAYLOADS);
+        StatefulFunctionsConfig.USER_MESSAGE_SERIALIZER, MessageFactoryType.WITH_CUSTOM_PAYLOADS);
+    configuration.set(
+        StatefulFunctionsConfig.USER_MESSAGE_CUSTOM_PAYLOAD_SERIALIZER_CLASS, serializerClassName);
     configuration.set(
         StatefulFunctionsConfig.TOTAL_MEMORY_USED_FOR_FEEDBACK_CHECKPOINTING,
         MemorySize.ofMebiBytes(100));
@@ -51,7 +57,11 @@ public class StatefulFunctionsConfigTest {
         StatefulFunctionsConfig.fromFlinkConfiguration(configuration);
 
     Assert.assertEquals(stateFunConfig.getFlinkJobName(), testName);
-    Assert.assertEquals(stateFunConfig.getFactoryType(), MessageFactoryType.WITH_KRYO_PAYLOADS);
+    Assert.assertEquals(
+        stateFunConfig.getFactoryKey().getType(), MessageFactoryType.WITH_CUSTOM_PAYLOADS);
+    Assert.assertEquals(
+        stateFunConfig.getFactoryKey().getCustomPayloadSerializerClassName(),
+        Optional.of(serializerClassName));
     Assert.assertEquals(stateFunConfig.getFeedbackBufferSize(), MemorySize.ofMebiBytes(100));
     Assert.assertEquals(stateFunConfig.getMaxAsyncOperationsPerTask(), 100);
     Assert.assertThat(
@@ -60,7 +70,7 @@ public class StatefulFunctionsConfigTest {
         stateFunConfig.getGlobalConfigurations(), Matchers.hasEntry("key2", "value2"));
   }
 
-  private static Configuration validConfiguration() {
+  private static Configuration baseConfiguration() {
     Configuration configuration = new Configuration();
     configuration.set(StatefulFunctionsConfig.FLINK_JOB_NAME, "name");
     configuration.set(
@@ -75,4 +85,22 @@ public class StatefulFunctionsConfigTest {
     configuration.set(ExecutionCheckpointingOptions.MAX_CONCURRENT_CHECKPOINTS, 1);
     return configuration;
   }
+
+  @Test(expected = StatefulFunctionsInvalidConfigException.class)
+  public void invalidCustomSerializerThrows() {
+    Configuration configuration = baseConfiguration();
+    configuration.set(
+        StatefulFunctionsConfig.USER_MESSAGE_SERIALIZER, MessageFactoryType.WITH_CUSTOM_PAYLOADS);
+    StatefulFunctionsConfigValidator.validate(configuration);
+  }
+
+  @Test(expected = StatefulFunctionsInvalidConfigException.class)
+  public void invalidNonCustomSerializerThrows() {
+    Configuration configuration = baseConfiguration();
+    configuration.set(
+        StatefulFunctionsConfig.USER_MESSAGE_SERIALIZER, MessageFactoryType.WITH_KRYO_PAYLOADS);
+    configuration.set(
+        StatefulFunctionsConfig.USER_MESSAGE_CUSTOM_PAYLOAD_SERIALIZER_CLASS, serializerClassName);
+    StatefulFunctionsConfigValidator.validate(configuration);
+  }
 }
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/TestUtils.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/TestUtils.java
index ac3ad88..c49270e 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/TestUtils.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/TestUtils.java
@@ -19,6 +19,7 @@ package org.apache.flink.statefun.flink.core;
 
 import org.apache.flink.statefun.flink.core.generated.EnvelopeAddress;
 import org.apache.flink.statefun.flink.core.message.MessageFactory;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
 import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
 import org.apache.flink.statefun.sdk.Address;
 import org.apache.flink.statefun.sdk.FunctionType;
@@ -27,7 +28,7 @@ import org.apache.flink.statefun.sdk.FunctionType;
 public class TestUtils {
 
   public static final MessageFactory ENVELOPE_FACTORY =
-      MessageFactory.forType(MessageFactoryType.WITH_KRYO_PAYLOADS);
+      MessageFactory.forKey(MessageFactoryKey.forType(MessageFactoryType.WITH_KRYO_PAYLOADS, null));
 
   public static final FunctionType FUNCTION_TYPE = new FunctionType("test", "a");
   public static final Address FUNCTION_1_ADDR = new Address(FUNCTION_TYPE, "a-1");
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java
index bdb67e3..d389841 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java
@@ -73,6 +73,7 @@ import org.apache.flink.statefun.flink.core.StatefulFunctionsUniverse;
 import org.apache.flink.statefun.flink.core.TestUtils;
 import org.apache.flink.statefun.flink.core.backpressure.ThresholdBackPressureValve;
 import org.apache.flink.statefun.flink.core.message.Message;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
 import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
 import org.apache.flink.streaming.api.operators.InternalTimerService;
 import org.apache.flink.streaming.api.operators.Output;
@@ -91,7 +92,8 @@ public class ReductionsTest {
     Reductions reductions =
         Reductions.create(
             new ThresholdBackPressureValve(-1),
-            new StatefulFunctionsUniverse(MessageFactoryType.WITH_KRYO_PAYLOADS),
+            new StatefulFunctionsUniverse(
+                MessageFactoryKey.forType(MessageFactoryType.WITH_KRYO_PAYLOADS, null)),
             new FakeRuntimeContext(),
             new FakeKeyedStateBackend(),
             new FakeTimerServiceFactory(),
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/jsonmodule/JsonModuleTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/jsonmodule/JsonModuleTest.java
index 44b5bbe..3001684 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/jsonmodule/JsonModuleTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/jsonmodule/JsonModuleTest.java
@@ -28,6 +28,7 @@ import java.util.Collection;
 import java.util.Collections;
 import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
 import org.apache.flink.statefun.flink.core.StatefulFunctionsUniverse;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
 import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
 import org.apache.flink.statefun.sdk.FunctionType;
 import org.apache.flink.statefun.sdk.io.EgressIdentifier;
@@ -118,6 +119,7 @@ public class JsonModuleTest {
   }
 
   private static StatefulFunctionsUniverse emptyUniverse() {
-    return new StatefulFunctionsUniverse(MessageFactoryType.WITH_PROTOBUF_PAYLOADS);
+    return new StatefulFunctionsUniverse(
+        MessageFactoryKey.forType(MessageFactoryType.WITH_PROTOBUF_PAYLOADS, null));
   }
 }
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/JavaPayloadSerializer.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/JavaPayloadSerializer.java
new file mode 100644
index 0000000..bb686c3
--- /dev/null
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/JavaPayloadSerializer.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.statefun.flink.core.message;
+
+import com.google.protobuf.ByteString;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.ObjectStreamClass;
+import java.io.UncheckedIOException;
+import javax.annotation.Nonnull;
+import org.apache.flink.statefun.flink.core.generated.Payload;
+
+/**
+ * A {@link MessagePayloadSerializer} based on plain Java serialization. It exists to exercise the
+ * custom payload serializer code path ({@link MessageFactoryType#WITH_CUSTOM_PAYLOADS}) in tests;
+ * payload objects must therefore be {@link java.io.Serializable}.
+ */
+public class JavaPayloadSerializer implements MessagePayloadSerializer {
+
+  @Override
+  public Payload serialize(@Nonnull Object payloadObject) {
+    String className = payloadObject.getClass().getName();
+    ByteArrayOutputStream bos = new ByteArrayOutputStream();
+    // closing the ObjectOutputStream flushes it, so bos holds the complete stream afterwards
+    try (ObjectOutputStream out = new ObjectOutputStream(bos)) {
+      out.writeObject(payloadObject);
+    } catch (IOException e) {
+      throw new UncheckedIOException("Failed to serialize a payload of type " + className, e);
+    }
+    return Payload.newBuilder()
+        .setClassName(className)
+        .setPayloadBytes(ByteString.copyFrom(bos.toByteArray()))
+        .build();
+  }
+
+  @Override
+  public Object deserialize(@Nonnull ClassLoader targetClassLoader, @Nonnull Payload payload) {
+    // resolve payload classes through the supplied targetClassLoader, not the default loader
+    InputStream bytes = payload.getPayloadBytes().newInput();
+    try (ObjectInputStream in = new ClassLoaderObjectInputStream(bytes, targetClassLoader)) {
+      return in.readObject();
+    } catch (IOException e) {
+      throw new UncheckedIOException(
+          "Failed to deserialize a payload of type " + payload.getClassName(), e);
+    } catch (ClassNotFoundException e) {
+      throw new IllegalStateException(
+          "Payload class is not visible to the target class loader: " + payload.getClassName(),
+          e);
+    }
+  }
+
+  @Override
+  public Object copy(@Nonnull ClassLoader targetClassLoader, @Nonnull Object what) {
+    // a round trip through the serialized form yields an independent copy
+    return deserialize(targetClassLoader, serialize(what));
+  }
+
+  /** An {@link ObjectInputStream} that resolves classes through a specific class loader. */
+  private static final class ClassLoaderObjectInputStream extends ObjectInputStream {
+
+    private final ClassLoader classLoader;
+
+    ClassLoaderObjectInputStream(InputStream in, ClassLoader classLoader) throws IOException {
+      super(in);
+      this.classLoader = classLoader;
+    }
+
+    @Override
+    protected Class<?> resolveClass(ObjectStreamClass desc)
+        throws IOException, ClassNotFoundException {
+      try {
+        return Class.forName(desc.getName(), false, classLoader);
+      } catch (ClassNotFoundException e) {
+        // fall back to the default resolution, which also handles primitive type descriptors
+        return super.resolveClass(desc);
+      }
+    }
+  }
+}
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTest.java
index 04056ad..31a9c35 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTest.java
@@ -34,27 +34,36 @@ import org.junit.runners.Parameterized.Parameters;
 @RunWith(Parameterized.class)
 public class MessageTest {
   private final MessageFactoryType type;
+  private final String customPayloadSerializerClassName;
   private final Object payload;
 
-  public MessageTest(MessageFactoryType type, Object payload) {
+  public MessageTest(
+      MessageFactoryType type, String customPayloadSerializerClassName, Object payload) {
     this.type = type;
+    this.customPayloadSerializerClassName = customPayloadSerializerClassName;
     this.payload = payload;
   }
 
   @Parameters(name = "{0}")
   public static Iterable<? extends Object[]> data() {
     return Arrays.asList(
-        new Object[] {MessageFactoryType.WITH_KRYO_PAYLOADS, DUMMY_PAYLOAD},
-        new Object[] {MessageFactoryType.WITH_PROTOBUF_PAYLOADS, DUMMY_PAYLOAD},
-        new Object[] {MessageFactoryType.WITH_RAW_PAYLOADS, DUMMY_PAYLOAD.toByteArray()},
+        new Object[] {MessageFactoryType.WITH_KRYO_PAYLOADS, null, DUMMY_PAYLOAD},
+        new Object[] {MessageFactoryType.WITH_PROTOBUF_PAYLOADS, null, DUMMY_PAYLOAD},
+        new Object[] {MessageFactoryType.WITH_RAW_PAYLOADS, null, DUMMY_PAYLOAD.toByteArray()},
         new Object[] {
-          MessageFactoryType.WITH_PROTOBUF_PAYLOADS_MULTILANG, Any.pack(DUMMY_PAYLOAD)
+          MessageFactoryType.WITH_PROTOBUF_PAYLOADS_MULTILANG, null, Any.pack(DUMMY_PAYLOAD)
+        },
+        new Object[] {
+          MessageFactoryType.WITH_CUSTOM_PAYLOADS,
+          "org.apache.flink.statefun.flink.core.message.JavaPayloadSerializer",
+          DUMMY_PAYLOAD
         });
   }
 
   @Test
   public void roundTrip() throws IOException {
-    MessageFactory factory = MessageFactory.forType(type);
+    MessageFactory factory =
+        MessageFactory.forKey(MessageFactoryKey.forType(type, customPayloadSerializerClassName));
 
     Message fromSdk = factory.from(FUNCTION_1_ADDR, FUNCTION_2_ADDR, payload);
     DataOutputSerializer out = new DataOutputSerializer(32);
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializerSnapshotTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializerSnapshotTest.java
new file mode 100644
index 0000000..f9c0c78
--- /dev/null
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializerSnapshotTest.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.statefun.flink.core.message;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.Arrays;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+/**
+ * Verifies that {@link MessageTypeSerializer.Snapshot#readSnapshot} restores the {@link
+ * MessageFactoryKey} that was written, for both the current (version 2) snapshot format and the
+ * legacy version 1 format, which only contains the {@link MessageFactoryType} name.
+ */
+@RunWith(Parameterized.class)
+public class MessageTypeSerializerSnapshotTest {
+
+  private static final String serializerClassName = "com.domain.Serializer";
+
+  /** Serialized snapshot bytes together with the snapshot format version they were written in. */
+  private static final class SnapshotData {
+    final int version;
+    final byte[] bytes;
+
+    SnapshotData(int version, byte[] bytes) {
+      this.version = version;
+      this.bytes = bytes;
+    }
+  }
+
+  /** Produces serialized snapshot data for a given {@link MessageFactoryKey}. */
+  @FunctionalInterface
+  private interface SnapshotDataProvider {
+    SnapshotData provide(MessageFactoryKey messageFactoryKey) throws IOException;
+  }
+
+  private final MessageFactoryKey messageFactoryKey;
+  private final SnapshotDataProvider snapshotDataProvider;
+
+  public MessageTypeSerializerSnapshotTest(
+      MessageFactoryKey messageFactoryKey, SnapshotDataProvider snapshotDataProvider) {
+    this.messageFactoryKey = messageFactoryKey;
+    this.snapshotDataProvider = snapshotDataProvider;
+  }
+
+  @Parameterized.Parameters(name = "{0}")
+  public static Iterable<? extends Object[]> data() {
+    MessageFactoryKey kryoFactoryKey =
+        MessageFactoryKey.forType(MessageFactoryType.WITH_KRYO_PAYLOADS, null);
+    MessageFactoryKey customFactoryKey =
+        MessageFactoryKey.forType(MessageFactoryType.WITH_CUSTOM_PAYLOADS, serializerClassName);
+
+    // legacy version 1 format: only the MessageFactoryType name, no custom serializer class name
+    SnapshotDataProvider snapshotDataProviderV1 =
+        messageFactoryKey -> {
+          ByteArrayOutputStream bos = new ByteArrayOutputStream();
+          DataOutputView dataOutputView = new DataOutputViewStreamWrapper(bos);
+          dataOutputView.writeUTF(messageFactoryKey.getType().name());
+          return new SnapshotData(1, bos.toByteArray());
+        };
+
+    // version 2, the current format, as written by MessageTypeSerializer.Snapshot itself
+    SnapshotDataProvider snapshotDataProviderV2 =
+        messageFactoryKey -> {
+          ByteArrayOutputStream bos = new ByteArrayOutputStream();
+          DataOutputView dataOutputView = new DataOutputViewStreamWrapper(bos);
+          new MessageTypeSerializer.Snapshot(messageFactoryKey).writeSnapshot(dataOutputView);
+          return new SnapshotData(2, bos.toByteArray());
+        };
+
+    return Arrays.asList(
+        new Object[] {kryoFactoryKey, snapshotDataProviderV1},
+        new Object[] {kryoFactoryKey, snapshotDataProviderV2},
+        new Object[] {customFactoryKey, snapshotDataProviderV2});
+  }
+
+  @Test
+  public void roundTrip() throws IOException {
+    SnapshotData snapshotData = snapshotDataProvider.provide(messageFactoryKey);
+
+    // start from an empty snapshot, so the assertion below can only pass if readSnapshot()
+    // actually restores the key from the serialized bytes
+    MessageTypeSerializer.Snapshot snapshot = new MessageTypeSerializer.Snapshot();
+    ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
+    try (ByteArrayInputStream bis = new ByteArrayInputStream(snapshotData.bytes)) {
+      DataInputView dataInputView = new DataInputViewStreamWrapper(bis);
+      snapshot.readSnapshot(snapshotData.version, dataInputView, classLoader);
+    }
+
+    // a JUnit assertion, unlike the `assert` statement, also runs when -ea is not set
+    Assert.assertEquals(messageFactoryKey, snapshot.getMessageFactoryKey());
+  }
+}
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializerTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializerTest.java
index 80c1945..89c7726 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializerTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/message/MessageTypeSerializerTest.java
@@ -39,7 +39,9 @@ public class MessageTypeSerializerTest extends SerializerTestBase<Message> {
             Message b = (Message) o2;
             DataOutputSerializer aOut = new DataOutputSerializer(32);
             DataOutputSerializer bOut = new DataOutputSerializer(32);
-            MessageFactory factory = MessageFactory.forType(MessageFactoryType.WITH_KRYO_PAYLOADS);
+            MessageFactory factory =
+                MessageFactory.forKey(
+                    MessageFactoryKey.forType(MessageFactoryType.WITH_KRYO_PAYLOADS, null));
             try {
               a.writeTo(factory, aOut);
             } catch (IOException e) {
@@ -57,7 +59,8 @@ public class MessageTypeSerializerTest extends SerializerTestBase<Message> {
 
   @Override
   protected TypeSerializer<Message> createSerializer() {
-    return new MessageTypeInformation(MessageFactoryType.WITH_KRYO_PAYLOADS)
+    return new MessageTypeInformation(
+            MessageFactoryKey.forType(MessageFactoryType.WITH_KRYO_PAYLOADS, null))
         .createSerializer(new ExecutionConfig());
   }
 
diff --git a/statefun-flink/statefun-flink-state-processor/src/main/java/org/apache/flink/statefun/flink/state/processor/operator/FunctionsStateBootstrapOperator.java b/statefun-flink/statefun-flink-state-processor/src/main/java/org/apache/flink/statefun/flink/state/processor/operator/FunctionsStateBootstrapOperator.java
index 09e6b59..ec33811 100644
--- a/statefun-flink/statefun-flink-state-processor/src/main/java/org/apache/flink/statefun/flink/state/processor/operator/FunctionsStateBootstrapOperator.java
+++ b/statefun-flink/statefun-flink-state-processor/src/main/java/org/apache/flink/statefun/flink/state/processor/operator/FunctionsStateBootstrapOperator.java
@@ -24,6 +24,7 @@ import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.state.api.output.SnapshotUtils;
 import org.apache.flink.state.api.output.TaggedOperatorSubtaskState;
 import org.apache.flink.statefun.flink.core.functions.FunctionGroupOperator;
+import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
 import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
 import org.apache.flink.statefun.flink.core.state.FlinkState;
 import org.apache.flink.statefun.flink.core.state.State;
@@ -95,6 +96,7 @@ public final class FunctionsStateBootstrapOperator
         runtimeContext,
         keyedStateBackend,
         new DynamicallyRegisteredTypes(
-            new StaticallyRegisteredTypes(MessageFactoryType.WITH_RAW_PAYLOADS)));
+            new StaticallyRegisteredTypes(
+                MessageFactoryKey.forType(MessageFactoryType.WITH_RAW_PAYLOADS, null))));
   }
 }