Posted to commits@beam.apache.org by jo...@apache.org on 2023/12/07 15:11:23 UTC

(beam) branch master updated: Add Error Handling to Kafka IO (#29546)

This is an automated email from the ASF dual-hosted git repository.

johncasey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 5d11c20cd4e Add Error Handling to Kafka IO (#29546)
5d11c20cd4e is described below

commit 5d11c20cd4e5a0df132f4cd2cff0d880a0d75323
Author: johnjcasey <95...@users.noreply.github.com>
AuthorDate: Thu Dec 7 10:11:16 2023 -0500

    Add Error Handling to Kafka IO (#29546)
    
    * Update 2.50 release notes to include new Kafka topicPattern feature
    
    * Create groovy class for io performance tests
    Create gradle task and github actions config for GCS using this.
    
    * delete unnecessary class
    
    * fix env call
    
    * fix call to gradle
    
    * run on hosted runner for testing
    
    * add additional checkout
    
    * add destination for triggered tests
    
    * move env variables to correct location
    
    * try uploading against separate dataset
    
    * try without a user
    
    * update branch checkout, try to view the failure log
    
    * run on failure
    
    * update to use correct BigQuery instance
    
    * convert to matrix
    
    * add result reporting
    
    * add failure clause
    
    * remove failure clause, update to run on self-hosted
    
    * address comments, clean up build
    
    * clarify branching
    
    * Add error handling base implementation & test DLQ enabled class
    
    * Add test cases
    
    * apply spotless
    
    * Fix Checkstyles
    
    * Fix Checkstyles
    
    * make DLH serializable
    
    * rename dead letter to bad record
    
    * make DLH serializable
    
    * Change bad record router name, and use multioutputreceiver instead of process context
    
    * Refactor BadRecord to be nested
    
    * clean up checkstyle
    
    * Update error handler test
    
    * Add metrics for counting error records, and for measuring feature usage
    
    * apply spotless
    
    * fix checkstyle
    
    * make metric reporting static
    
    * spotless
    
    * Rework annotations to be an explicit label on a PTransform, instead of using java annotations
    
    * fix checkstyle
    
    * Address comments
    
    * Address comments
    
    * Fix test cases, spotless
    
    * remove flatting without error collections
    
    * fix nullness
    
    * spotless + encoding issues
    
    * spotless
    
    * throw error when error handler isn't used
    
    * add concrete bad record error handler class
    
    * spotless, fix test category
    
    * fix checkstyle
    
    * clean up comments
    
    * fix test case
    
    * initial wiring of error handler into KafkaIO Read
    
    * remove "failing transform" field on bad record, add note to CHANGES.md
    
    * fix failing test cases
    
    * fix failing test cases
    
    * apply spotless
    
    * Add tests
    
    * Add tests
    
    * fix test case
    
    * add documentation
    
    * wire error handler into kafka write
    
    * fix failing test case
    
    * Add tests for writing to kafka with exception handling
    
    * fix sdf testing
    
    * fix sdf testing
    
    * spotless
    
    * deflake tests
    
    * add error handling to kafka streaming example
    
    update error handler to be serializable to support using it as a member of an auto-value based PTransform
    
    * apply final comments
    
    * apply final comments
    
    * apply final comments
    
    * add line to CHANGES.md
    
    * fix spotless
    
    * fix checkstyle
    
    * make sink transform static for serialization
    
    * spotless
    
    * fix typo
    
    * fix typo
    
    * fix spotbugs
---
 CHANGES.md                                         |   1 +
 .../beam/gradle/kafka/KafkaTestUtilities.groovy    |   3 +-
 .../org/apache/beam/examples/KafkaStreaming.java   |  67 ++++++--
 .../sdk/transforms/errorhandling/ErrorHandler.java |  32 ++--
 sdks/java/io/kafka/kafka-01103/build.gradle        |   1 +
 sdks/java/io/kafka/kafka-100/build.gradle          |   3 +-
 sdks/java/io/kafka/kafka-111/build.gradle          |   1 +
 sdks/java/io/kafka/kafka-201/build.gradle          |   1 +
 sdks/java/io/kafka/kafka-211/build.gradle          |   1 +
 sdks/java/io/kafka/kafka-222/build.gradle          |   1 +
 sdks/java/io/kafka/kafka-231/build.gradle          |   1 +
 sdks/java/io/kafka/kafka-241/build.gradle          |   1 +
 sdks/java/io/kafka/kafka-251/build.gradle          |   1 +
 sdks/java/io/kafka/kafka-integration-test.gradle   |   2 +-
 .../java/org/apache/beam/sdk/io/kafka/KafkaIO.java | 127 ++++++++++++++--
 .../KafkaIOReadImplementationCompatibility.java    |   1 +
 .../org/apache/beam/sdk/io/kafka/KafkaWriter.java  |  46 ++++--
 .../beam/sdk/io/kafka/ReadFromKafkaDoFn.java       | 110 +++++++++-----
 .../beam/sdk/io/kafka/KafkaIOExternalTest.java     |   8 +-
 .../org/apache/beam/sdk/io/kafka/KafkaIOIT.java    | 121 +++++++++++----
 .../org/apache/beam/sdk/io/kafka/KafkaIOTest.java  | 114 ++++++++++++--
 .../beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java   | 168 ++++++++++++++++++---
 22 files changed, 659 insertions(+), 152 deletions(-)
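
For readers skimming the patch, here is a minimal usage sketch of the new read-side API, distilled from the KafkaStreaming.java example and the KafkaIO Javadoc changes below. The broker address, topic name, and the LogBadRecords sink are illustrative placeholders, not part of this commit:

    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.io.kafka.KafkaIO;
    import org.apache.beam.sdk.transforms.DoFn;
    import org.apache.beam.sdk.transforms.PTransform;
    import org.apache.beam.sdk.transforms.ParDo;
    import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
    import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.BadRecordErrorHandler;
    import org.apache.beam.sdk.values.KV;
    import org.apache.beam.sdk.values.PCollection;
    import org.apache.kafka.common.serialization.IntegerDeserializer;
    import org.apache.kafka.common.serialization.StringDeserializer;

    public class KafkaReadWithErrorHandler {

      // Placeholder sink: logs each BadRecord and re-emits it.
      static class LogBadRecords
          extends PTransform<PCollection<BadRecord>, PCollection<BadRecord>> {
        @Override
        public PCollection<BadRecord> expand(PCollection<BadRecord> input) {
          return input.apply(
              "Log bad records",
              ParDo.of(
                  new DoFn<BadRecord, BadRecord>() {
                    @ProcessElement
                    public void processElement(
                        @Element BadRecord record, OutputReceiver<BadRecord> out) {
                      System.out.println(record);
                      out.output(record);
                    }
                  }));
        }
      }

      public static void main(String[] args) throws Exception {
        Pipeline pipeline = Pipeline.create();

        PCollection<KV<String, Integer>> records;
        // Register the handler on the pipeline, pass it to the read, and close it once every
        // transform that uses it has been applied; closing wires the collected bad records
        // into the sink transform.
        try (BadRecordErrorHandler<PCollection<BadRecord>> errorHandler =
            pipeline.registerBadRecordErrorHandler(new LogBadRecords())) {
          records =
              pipeline.apply(
                  KafkaIO.<String, Integer>read()
                      .withBootstrapServers("localhost:9092") // placeholder broker
                      .withTopic("my-topic") // placeholder topic
                      .withKeyDeserializer(StringDeserializer.class)
                      .withValueDeserializer(IntegerDeserializer.class)
                      // Records that fail deserialization are routed to the handler
                      // (SDF read implementation only; see the warning added below).
                      .withBadRecordErrorHandler(errorHandler)
                      .withoutMetadata());
        }

        // ... apply downstream transforms to `records` ...
        pipeline.run().waitUntilFinish();
      }
    }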

diff --git a/CHANGES.md b/CHANGES.md
index 4b977bf3790..7686b7a92d9 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -66,6 +66,7 @@
 * TextIO now supports skipping multiple header lines (Java) ([#17990](https://github.com/apache/beam/issues/17990)).
 * Python GCSIO is now implemented with GCP GCS Client instead of apitools ([#25676](https://github.com/apache/beam/issues/25676))
 * Adding support for LowCardinality DataType in ClickHouse (Java) ([#29533](https://github.com/apache/beam/pull/29533)).
+* Added support for handling bad records to KafkaIO (Java) ([#29546](https://github.com/apache/beam/pull/29546))
 
 ## New Features / Improvements
 
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/kafka/KafkaTestUtilities.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/kafka/KafkaTestUtilities.groovy
index cd2875fdb51..bb08e79edd3 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/kafka/KafkaTestUtilities.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/kafka/KafkaTestUtilities.groovy
@@ -40,7 +40,7 @@ class KafkaTestUtilities {
         '"keySizeBytes": "10",' +
         '"valueSizeBytes": "90"' +
         '}',
-        "--readTimeout=120",
+        "--readTimeout=60",
         "--kafkaTopic=beam",
         "--withTestcontainers=true",
         "--kafkaContainerVersion=5.5.2",
@@ -56,6 +56,7 @@ class KafkaTestUtilities {
           excludeTestsMatching "*SDFResumesCorrectly" //Kafka SDF does not work for kafka versions <2.0.1
           excludeTestsMatching "*StopReadingFunction" //Kafka SDF does not work for kafka versions <2.0.1
           excludeTestsMatching "*WatermarkUpdateWithSparseMessages" //Kafka SDF does not work for kafka versions <2.0.1
+          excludeTestsMatching "*KafkaIOSDFReadWithErrorHandler"
         }
       }
     }
diff --git a/examples/java/src/main/java/org/apache/beam/examples/KafkaStreaming.java b/examples/java/src/main/java/org/apache/beam/examples/KafkaStreaming.java
index 34a4b646555..602c34d4219 100644
--- a/examples/java/src/main/java/org/apache/beam/examples/KafkaStreaming.java
+++ b/examples/java/src/main/java/org/apache/beam/examples/KafkaStreaming.java
@@ -49,8 +49,11 @@ import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
+import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.BadRecordErrorHandler;
 import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
@@ -60,6 +63,8 @@ import org.apache.beam.sdk.transforms.windowing.Trigger;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.kafka.common.errors.SerializationException;
+import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.IntegerDeserializer;
 import org.apache.kafka.common.serialization.IntegerSerializer;
 import org.apache.kafka.common.serialization.StringDeserializer;
@@ -97,7 +102,7 @@ public class KafkaStreaming {
      * to use your own Kafka server.
      */
     @Description("Kafka server host")
-    @Default.String("kafka_server:9092")
+    @Default.String("localhost:9092")
     String getKafkaHost();
 
     void setKafkaHost(String value);
@@ -208,15 +213,22 @@ public class KafkaStreaming {
       // Start reading form Kafka with the latest offset
       consumerConfig.put("auto.offset.reset", "latest");
 
-      PCollection<KV<String, Integer>> pCollection =
-          pipeline.apply(
-              KafkaIO.<String, Integer>read()
-                  .withBootstrapServers(options.getKafkaHost())
-                  .withTopic(TOPIC_NAME)
-                  .withKeyDeserializer(StringDeserializer.class)
-                  .withValueDeserializer(IntegerDeserializer.class)
-                  .withConsumerConfigUpdates(consumerConfig)
-                  .withoutMetadata());
+      // Register an error handler for any deserialization errors.
+      // Errors are simulated with an intentionally failing deserializer
+      PCollection<KV<String, Integer>> pCollection;
+      try (BadRecordErrorHandler<PCollection<BadRecord>> errorHandler =
+          pipeline.registerBadRecordErrorHandler(new LogErrors())) {
+        pCollection =
+            pipeline.apply(
+                KafkaIO.<String, Integer>read()
+                    .withBootstrapServers(options.getKafkaHost())
+                    .withTopic(TOPIC_NAME)
+                    .withKeyDeserializer(StringDeserializer.class)
+                    .withValueDeserializer(IntermittentlyFailingIntegerDeserializer.class)
+                    .withConsumerConfigUpdates(consumerConfig)
+                    .withBadRecordErrorHandler(errorHandler)
+                    .withoutMetadata());
+      }
 
       pCollection
           // Apply a window and a trigger to output repeatedly.
@@ -317,4 +329,39 @@ public class KafkaStreaming {
       c.output(c.element());
     }
   }
+
+  // Simple PTransform to log Error information
+  static class LogErrors extends PTransform<PCollection<BadRecord>, PCollection<BadRecord>> {
+
+    @Override
+    public PCollection<BadRecord> expand(PCollection<BadRecord> input) {
+      return input.apply("Log Errors", ParDo.of(new LogErrorFn()));
+    }
+
+    static class LogErrorFn extends DoFn<BadRecord, BadRecord> {
+      @ProcessElement
+      public void processElement(@Element BadRecord record, OutputReceiver<BadRecord> receiver) {
+        System.out.println(record);
+        receiver.output(record);
+      }
+    }
+  }
+
+  // Intentionally failing deserializer to simulate bad data from Kafka
+  public static class IntermittentlyFailingIntegerDeserializer implements Deserializer<Integer> {
+
+    public static final IntegerDeserializer INTEGER_DESERIALIZER = new IntegerDeserializer();
+    public int deserializeCount = 0;
+
+    public IntermittentlyFailingIntegerDeserializer() {}
+
+    @Override
+    public Integer deserialize(String topic, byte[] data) {
+      deserializeCount++;
+      if (deserializeCount % 10 == 0) {
+        throw new SerializationException("Expected Serialization Exception");
+      }
+      return INTEGER_DESERIALIZER.deserialize(topic, data);
+    }
+  }
 }
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandler.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandler.java
index 9e0298d885e..e02965b7202 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandler.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandler.java
@@ -17,6 +17,9 @@
  */
 package org.apache.beam.sdk.transforms.errorhandling;
 
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.Serializable;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.List;
@@ -49,22 +52,24 @@ import org.slf4j.LoggerFactory;
  *     <p>Simple usage with one DLQ
  *     <pre>{@code
  * PCollection<?> records = ...;
- * try (ErrorHandler<E,T> errorHandler = pipeline.registerErrorHandler(SomeSink.write())) {
- *  PCollection<?> results = records.apply(SomeIO.write().withDeadLetterQueue(errorHandler));
+ * try (BadRecordErrorHandler<T> errorHandler = pipeline.registerBadRecordErrorHandler(SomeSink.write())) {
+ *  PCollection<?> results = records.apply(SomeIO.write().withErrorHandler(errorHandler));
  * }
  * results.apply(SomeOtherTransform);
  * }</pre>
  *     Usage with multiple DLQ stages
  *     <pre>{@code
  * PCollection<?> records = ...;
- * try (ErrorHandler<E,T> errorHandler = pipeline.registerErrorHandler(SomeSink.write())) {
- *  PCollection<?> results = records.apply(SomeIO.write().withDeadLetterQueue(errorHandler))
- *                        .apply(OtherTransform.builder().withDeadLetterQueue(errorHandler));
+ * try (BadRecordErrorHandler<T> errorHandler = pipeline.registerBadRecordErrorHandler(SomeSink.write())) {
+ *  PCollection<?> results = records.apply(SomeIO.write().withErrorHandler(errorHandler))
+ *                        .apply(OtherTransform.builder().withErrorHandler(errorHandler));
  * }
  * results.apply(SomeOtherTransform);
  * }</pre>
+ *     This is marked as serializable despite never being needed on the runner, to enable it to be a
+ *     parameter of an AutoValue-configured PTransform.
  */
-public interface ErrorHandler<ErrorT, OutputT extends POutput> extends AutoCloseable {
+public interface ErrorHandler<ErrorT, OutputT extends POutput> extends AutoCloseable, Serializable {
 
   void addErrorCollection(PCollection<ErrorT> errorCollection);
 
@@ -79,13 +84,16 @@ public interface ErrorHandler<ErrorT, OutputT extends POutput> extends AutoClose
     private static final Logger LOG = LoggerFactory.getLogger(PTransformErrorHandler.class);
     private final PTransform<PCollection<ErrorT>, OutputT> sinkTransform;
 
-    private final Pipeline pipeline;
+    // transient as Pipelines are not serializable
+    private final transient Pipeline pipeline;
 
     private final Coder<ErrorT> coder;
 
-    private final List<PCollection<ErrorT>> errorCollections = new ArrayList<>();
+    // transient as PCollections are not serializable
+    private transient List<PCollection<ErrorT>> errorCollections = new ArrayList<>();
 
-    private @Nullable OutputT sinkOutput = null;
+    // transient as PCollections are not serializable
+    private transient @Nullable OutputT sinkOutput = null;
 
     private boolean closed = false;
 
@@ -103,6 +111,12 @@ public interface ErrorHandler<ErrorT, OutputT extends POutput> extends AutoClose
       this.coder = coder;
     }
 
+    private void readObject(ObjectInputStream aInputStream)
+        throws ClassNotFoundException, IOException {
+      aInputStream.defaultReadObject();
+      errorCollections = new ArrayList<>();
+    }
+
     @Override
     public void addErrorCollection(PCollection<ErrorT> errorCollection) {
       errorCollections.add(errorCollection);
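
The KafkaIO changes below use the author-side half of this interface: tag failures with BadRecordRouter.BAD_RECORD_TAG, then register the resulting collection through addErrorCollection. Here is a minimal sketch of that pattern, using a hypothetical ParseStrings transform that is not part of this commit (the parsing logic and coder choice are illustrative only):

    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.beam.sdk.transforms.DoFn;
    import org.apache.beam.sdk.transforms.PTransform;
    import org.apache.beam.sdk.transforms.ParDo;
    import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
    import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter;
    import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler;
    import org.apache.beam.sdk.values.PCollection;
    import org.apache.beam.sdk.values.PCollectionTuple;
    import org.apache.beam.sdk.values.TupleTag;
    import org.apache.beam.sdk.values.TupleTagList;

    // Hypothetical transform: parses strings to integers, routing parse failures to the
    // user-supplied error handler instead of failing the bundle.
    class ParseStrings extends PTransform<PCollection<String>, PCollection<Integer>> {

      private final ErrorHandler<BadRecord, ?> errorHandler;
      private final BadRecordRouter badRecordRouter = BadRecordRouter.RECORDING_ROUTER;

      ParseStrings(ErrorHandler<BadRecord, ?> errorHandler) {
        this.errorHandler = errorHandler;
      }

      @Override
      public PCollection<Integer> expand(PCollection<String> input) {
        TupleTag<Integer> parsed = new TupleTag<Integer>() {};

        PCollectionTuple result =
            input.apply(
                ParDo.of(
                        new DoFn<String, Integer>() {
                          @ProcessElement
                          public void processElement(
                              @Element String element, MultiOutputReceiver receiver)
                              throws Exception {
                            try {
                              receiver.get(parsed).output(Integer.parseInt(element));
                            } catch (NumberFormatException e) {
                              // Route the failing element, its coder, and the exception.
                              badRecordRouter.route(
                                  receiver,
                                  element,
                                  StringUtf8Coder.of(),
                                  e,
                                  "Failed to parse integer");
                            }
                          }
                        })
                    .withOutputTags(parsed, TupleTagList.of(BadRecordRouter.BAD_RECORD_TAG)));

        // Register the failure collection with the user's handler.
        errorHandler.addErrorCollection(
            result
                .get(BadRecordRouter.BAD_RECORD_TAG)
                .setCoder(BadRecord.getCoder(input.getPipeline())));

        return result.get(parsed);
      }
    }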
diff --git a/sdks/java/io/kafka/kafka-01103/build.gradle b/sdks/java/io/kafka/kafka-01103/build.gradle
index a0fa372397a..3a74bf04ef2 100644
--- a/sdks/java/io/kafka/kafka-01103/build.gradle
+++ b/sdks/java/io/kafka/kafka-01103/build.gradle
@@ -18,6 +18,7 @@
 project.ext {
     delimited="0.11.0.3"
     undelimited="01103"
+    sdfCompatible=false
 }
 
 apply from: "../kafka-integration-test.gradle"
\ No newline at end of file
diff --git a/sdks/java/io/kafka/kafka-100/build.gradle b/sdks/java/io/kafka/kafka-100/build.gradle
index 15ce8c0deef..bd5fa67b1cf 100644
--- a/sdks/java/io/kafka/kafka-100/build.gradle
+++ b/sdks/java/io/kafka/kafka-100/build.gradle
@@ -18,6 +18,7 @@
 project.ext {
     delimited="1.0.0"
     undelimited="100"
+    sdfCompatible=false
 }
 
-apply from: "../kafka-integration-test.gradle"
\ No newline at end of file
+apply from: "../kafka-integration-test.gradle"
diff --git a/sdks/java/io/kafka/kafka-111/build.gradle b/sdks/java/io/kafka/kafka-111/build.gradle
index fee4c382ed4..c2b0c8f8282 100644
--- a/sdks/java/io/kafka/kafka-111/build.gradle
+++ b/sdks/java/io/kafka/kafka-111/build.gradle
@@ -18,6 +18,7 @@
 project.ext {
     delimited="1.1.1"
     undelimited="111"
+    sdfCompatible=false
 }
 
 apply from: "../kafka-integration-test.gradle"
\ No newline at end of file
diff --git a/sdks/java/io/kafka/kafka-201/build.gradle b/sdks/java/io/kafka/kafka-201/build.gradle
index d395d0aa626..a26ca4ac19c 100644
--- a/sdks/java/io/kafka/kafka-201/build.gradle
+++ b/sdks/java/io/kafka/kafka-201/build.gradle
@@ -18,6 +18,7 @@
 project.ext {
     delimited="2.0.1"
     undelimited="201"
+    sdfCompatible=true
 }
 
 apply from: "../kafka-integration-test.gradle"
\ No newline at end of file
diff --git a/sdks/java/io/kafka/kafka-211/build.gradle b/sdks/java/io/kafka/kafka-211/build.gradle
index 4de07193b5a..433d6c93f36 100644
--- a/sdks/java/io/kafka/kafka-211/build.gradle
+++ b/sdks/java/io/kafka/kafka-211/build.gradle
@@ -18,6 +18,7 @@
 project.ext {
     delimited="2.1.1"
     undelimited="211"
+    sdfCompatible=true
 }
 
 apply from: "../kafka-integration-test.gradle"
\ No newline at end of file
diff --git a/sdks/java/io/kafka/kafka-222/build.gradle b/sdks/java/io/kafka/kafka-222/build.gradle
index 57de58e8189..0f037e74296 100644
--- a/sdks/java/io/kafka/kafka-222/build.gradle
+++ b/sdks/java/io/kafka/kafka-222/build.gradle
@@ -18,6 +18,7 @@
 project.ext {
     delimited="2.2.2"
     undelimited="222"
+    sdfCompatible=true
 }
 
 apply from: "../kafka-integration-test.gradle"
\ No newline at end of file
diff --git a/sdks/java/io/kafka/kafka-231/build.gradle b/sdks/java/io/kafka/kafka-231/build.gradle
index 3682791c5b6..712158dcd3a 100644
--- a/sdks/java/io/kafka/kafka-231/build.gradle
+++ b/sdks/java/io/kafka/kafka-231/build.gradle
@@ -18,6 +18,7 @@
 project.ext {
     delimited="2.3.1"
     undelimited="231"
+    sdfCompatible=true
 }
 
 apply from: "../kafka-integration-test.gradle"
\ No newline at end of file
diff --git a/sdks/java/io/kafka/kafka-241/build.gradle b/sdks/java/io/kafka/kafka-241/build.gradle
index 358c95aeb2f..c0ac7df674b 100644
--- a/sdks/java/io/kafka/kafka-241/build.gradle
+++ b/sdks/java/io/kafka/kafka-241/build.gradle
@@ -18,6 +18,7 @@
 project.ext {
     delimited="2.4.1"
     undelimited="241"
+    sdfCompatible=true
 }
 
 apply from: "../kafka-integration-test.gradle"
\ No newline at end of file
diff --git a/sdks/java/io/kafka/kafka-251/build.gradle b/sdks/java/io/kafka/kafka-251/build.gradle
index f291ecccc36..4de9f97a738 100644
--- a/sdks/java/io/kafka/kafka-251/build.gradle
+++ b/sdks/java/io/kafka/kafka-251/build.gradle
@@ -18,6 +18,7 @@
 project.ext {
     delimited="2.5.1"
     undelimited="251"
+    sdfCompatible=true
 }
 
 apply from: "../kafka-integration-test.gradle"
\ No newline at end of file
diff --git a/sdks/java/io/kafka/kafka-integration-test.gradle b/sdks/java/io/kafka/kafka-integration-test.gradle
index 778f8a3c456..1aeb0c97f93 100644
--- a/sdks/java/io/kafka/kafka-integration-test.gradle
+++ b/sdks/java/io/kafka/kafka-integration-test.gradle
@@ -39,4 +39,4 @@ dependencies {
 
 configurations.create("kafkaVersion$undelimited")
 
-tasks.register("kafkaVersion${undelimited}BatchIT",KafkaTestUtilities.KafkaBatchIT, project.ext.delimited, project.ext.undelimited, false, configurations, project)
\ No newline at end of file
+tasks.register("kafkaVersion${undelimited}BatchIT",KafkaTestUtilities.KafkaBatchIT, project.ext.delimited, project.ext.undelimited, project.ext.sdfCompatible, configurations, project)
\ No newline at end of file
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index 7e4fc55c6ce..8fd0c34cfa9 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -81,6 +81,11 @@ import org.apache.beam.sdk.transforms.Reshuffle;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.SimpleFunction;
 import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
+import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter;
+import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler;
+import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.BadRecordErrorHandler;
+import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.DefaultErrorHandler;
 import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator;
 import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators.Manual;
 import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators.MonotonicallyIncreasing;
@@ -89,9 +94,11 @@ import org.apache.beam.sdk.util.Preconditions;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.sdk.values.Row;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
 import org.apache.beam.sdk.values.TypeDescriptor;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Joiner;
@@ -167,6 +174,10 @@ import org.slf4j.LoggerFactory;
  *      // signal.
  *      .withCheckStopReadingFn(new SerializedFunction<TopicPartition, Boolean>() {})
  *
+ *      // If you would like to send messages that fail to be parsed from Kafka to an alternate sink,
+ *      // use the error handler pattern as defined in {@link ErrorHandler}
+ *      .withBadRecordErrorHandler(errorHandler)
+ *
  *      // finally, if you don't need Kafka metadata, you can drop it.
  *      .withoutMetadata() // PCollection<KV<Long, String>>
  *   )
@@ -469,6 +480,11 @@ import org.slf4j.LoggerFactory;
  *      // or you can also set a custom timestamp with a function.
  *      .withPublishTimestampFunction((elem, elemTs) -> ...)
  *
+ *      // Optionally, records that fail to serialize can be sent to an error handler
+ *      // See {@link ErrorHandler} for details of configuring a bad record error
+ *      // handler
+ *      .withBadRecordErrorHandler(errorHandler)
+ *
  *      // Optionally enable exactly-once sink (on supported runners). See JavaDoc for withEOS().
  *      .withEOS(20, "eos-sink-group-id");
  *   );
@@ -592,13 +608,7 @@ public class KafkaIO {
    */
   public static <K, V> Write<K, V> write() {
     return new AutoValue_KafkaIO_Write.Builder<K, V>()
-        .setWriteRecordsTransform(
-            new AutoValue_KafkaIO_WriteRecords.Builder<K, V>()
-                .setProducerConfig(WriteRecords.DEFAULT_PRODUCER_PROPERTIES)
-                .setEOS(false)
-                .setNumShards(0)
-                .setConsumerFactoryFn(KafkaIOUtils.KAFKA_CONSUMER_FACTORY_FN)
-                .build())
+        .setWriteRecordsTransform(writeRecords())
         .build();
   }
 
@@ -613,6 +623,8 @@ public class KafkaIO {
         .setEOS(false)
         .setNumShards(0)
         .setConsumerFactoryFn(KafkaIOUtils.KAFKA_CONSUMER_FACTORY_FN)
+        .setBadRecordRouter(BadRecordRouter.THROWING_ROUTER)
+        .setBadRecordErrorHandler(new DefaultErrorHandler<>())
         .build();
   }
 
@@ -691,6 +703,9 @@ public class KafkaIO {
     @Pure
     public abstract @Nullable CheckStopReadingFn getCheckStopReadingFn();
 
+    @Pure
+    public abstract @Nullable ErrorHandler<BadRecord, ?> getBadRecordErrorHandler();
+
     abstract Builder<K, V> toBuilder();
 
     @AutoValue.Builder
@@ -739,6 +754,9 @@ public class KafkaIO {
 
       abstract Builder<K, V> setCheckStopReadingFn(@Nullable CheckStopReadingFn checkStopReadingFn);
 
+      abstract Builder<K, V> setBadRecordErrorHandler(
+          @Nullable ErrorHandler<BadRecord, ?> badRecordErrorHandler);
+
       Builder<K, V> setCheckStopReadingFn(
           @Nullable SerializableFunction<TopicPartition, Boolean> checkStopReadingFn) {
         return setCheckStopReadingFn(CheckStopReadingFnWrapper.of(checkStopReadingFn));
@@ -1312,6 +1330,10 @@ public class KafkaIO {
           .build();
     }
 
+    public Read<K, V> withBadRecordErrorHandler(ErrorHandler<BadRecord, ?> badRecordErrorHandler) {
+      return toBuilder().setBadRecordErrorHandler(badRecordErrorHandler).build();
+    }
+
     /** Returns a {@link PTransform} for PCollection of {@link KV}, dropping Kafka metadata. */
     public PTransform<PBegin, PCollection<KV<K, V>>> withoutMetadata() {
       return new TypedWithoutMetadata<>(this);
@@ -1529,6 +1551,11 @@ public class KafkaIO {
 
       @Override
       public PCollection<KafkaRecord<K, V>> expand(PBegin input) {
+        if (kafkaRead.getBadRecordErrorHandler() != null) {
+          LOG.warn(
+              "The Legacy implementation of Kafka Read does not support writing malformed "
+                  + "messages to an error handler. Use the SDF implementation instead.");
+        }
         // Handles unbounded source to bounded conversion if maxNumRecords or maxReadTime is set.
         Unbounded<KafkaRecord<K, V>> unbounded =
             org.apache.beam.sdk.io.Read.from(
@@ -1576,6 +1603,10 @@ public class KafkaIO {
         if (kafkaRead.getStopReadTime() != null) {
           readTransform = readTransform.withBounded();
         }
+        if (kafkaRead.getBadRecordErrorHandler() != null) {
+          readTransform =
+              readTransform.withBadRecordErrorHandler(kafkaRead.getBadRecordErrorHandler());
+        }
         PCollection<KafkaSourceDescriptor> output;
         if (kafkaRead.isDynamicRead()) {
           Set<String> topics = new HashSet<>();
@@ -1956,6 +1987,8 @@ public class KafkaIO {
   public abstract static class ReadSourceDescriptors<K, V>
       extends PTransform<PCollection<KafkaSourceDescriptor>, PCollection<KafkaRecord<K, V>>> {
 
+    private final TupleTag<KV<KafkaSourceDescriptor, KafkaRecord<K, V>>> records = new TupleTag<>();
+
     private static final Logger LOG = LoggerFactory.getLogger(ReadSourceDescriptors.class);
 
     @Pure
@@ -1997,6 +2030,12 @@ public class KafkaIO {
     @Pure
     abstract @Nullable TimestampPolicyFactory<K, V> getTimestampPolicyFactory();
 
+    @Pure
+    abstract BadRecordRouter getBadRecordRouter();
+
+    @Pure
+    abstract ErrorHandler<BadRecord, ?> getBadRecordErrorHandler();
+
     abstract boolean isBounded();
 
     abstract ReadSourceDescriptors.Builder<K, V> toBuilder();
@@ -2041,6 +2080,12 @@ public class KafkaIO {
       abstract ReadSourceDescriptors.Builder<K, V> setTimestampPolicyFactory(
           TimestampPolicyFactory<K, V> policy);
 
+      abstract ReadSourceDescriptors.Builder<K, V> setBadRecordRouter(
+          BadRecordRouter badRecordRouter);
+
+      abstract ReadSourceDescriptors.Builder<K, V> setBadRecordErrorHandler(
+          ErrorHandler<BadRecord, ?> badRecordErrorHandler);
+
       abstract ReadSourceDescriptors.Builder<K, V> setBounded(boolean bounded);
 
       abstract ReadSourceDescriptors<K, V> build();
@@ -2052,6 +2097,8 @@ public class KafkaIO {
           .setConsumerConfig(KafkaIOUtils.DEFAULT_CONSUMER_PROPERTIES)
           .setCommitOffsetEnabled(false)
           .setBounded(false)
+          .setBadRecordRouter(BadRecordRouter.THROWING_ROUTER)
+          .setBadRecordErrorHandler(new ErrorHandler.DefaultErrorHandler<>())
           .build()
           .withProcessingTime()
           .withMonotonicallyIncreasingWatermarkEstimator();
@@ -2305,6 +2352,14 @@ public class KafkaIO {
       return toBuilder().setConsumerConfig(consumerConfig).build();
     }
 
+    public ReadSourceDescriptors<K, V> withBadRecordErrorHandler(
+        ErrorHandler<BadRecord, ?> errorHandler) {
+      return toBuilder()
+          .setBadRecordRouter(BadRecordRouter.RECORDING_ROUTER)
+          .setBadRecordErrorHandler(errorHandler)
+          .build();
+    }
+
     ReadAllFromRow<K, V> forExternalBuild() {
       return new ReadAllFromRow<>(this);
     }
@@ -2395,9 +2450,18 @@ public class KafkaIO {
       Coder<KafkaRecord<K, V>> recordCoder = KafkaRecordCoder.of(keyCoder, valueCoder);
 
       try {
+        PCollectionTuple pCollectionTuple =
+            input.apply(
+                ParDo.of(ReadFromKafkaDoFn.<K, V>create(this, records))
+                    .withOutputTags(records, TupleTagList.of(BadRecordRouter.BAD_RECORD_TAG)));
+        getBadRecordErrorHandler()
+            .addErrorCollection(
+                pCollectionTuple
+                    .get(BadRecordRouter.BAD_RECORD_TAG)
+                    .setCoder(BadRecord.getCoder(input.getPipeline())));
         PCollection<KV<KafkaSourceDescriptor, KafkaRecord<K, V>>> outputWithDescriptor =
-            input
-                .apply(ParDo.of(ReadFromKafkaDoFn.<K, V>create(this)))
+            pCollectionTuple
+                .get(records)
                 .setCoder(
                     KvCoder.of(
                         input
@@ -2538,6 +2602,12 @@ public class KafkaIO {
     public abstract @Nullable SerializableFunction<Map<String, Object>, ? extends Consumer<?, ?>>
         getConsumerFactoryFn();
 
+    @Pure
+    public abstract BadRecordRouter getBadRecordRouter();
+
+    @Pure
+    public abstract ErrorHandler<BadRecord, ?> getBadRecordErrorHandler();
+
     abstract Builder<K, V> toBuilder();
 
     @AutoValue.Builder
@@ -2565,6 +2635,11 @@ public class KafkaIO {
       abstract Builder<K, V> setConsumerFactoryFn(
           SerializableFunction<Map<String, Object>, ? extends Consumer<?, ?>> fn);
 
+      abstract Builder<K, V> setBadRecordRouter(BadRecordRouter router);
+
+      abstract Builder<K, V> setBadRecordErrorHandler(
+          ErrorHandler<BadRecord, ?> badRecordErrorHandler);
+
       abstract WriteRecords<K, V> build();
     }
 
@@ -2711,6 +2786,14 @@ public class KafkaIO {
       return toBuilder().setConsumerFactoryFn(consumerFactoryFn).build();
     }
 
+    public WriteRecords<K, V> withBadRecordErrorHandler(
+        ErrorHandler<BadRecord, ?> badRecordErrorHandler) {
+      return toBuilder()
+          .setBadRecordRouter(BadRecordRouter.RECORDING_ROUTER)
+          .setBadRecordErrorHandler(badRecordErrorHandler)
+          .build();
+    }
+
     @Override
     public PDone expand(PCollection<ProducerRecord<K, V>> input) {
       checkArgument(
@@ -2722,6 +2805,9 @@ public class KafkaIO {
 
       if (isEOS()) {
         checkArgument(getTopic() != null, "withTopic() is required when isEOS() is true");
+        checkArgument(
+            getBadRecordErrorHandler() instanceof DefaultErrorHandler,
+            "BadRecordErrorHandling isn't supported with Kafka Exactly Once writing");
         KafkaExactlyOnceSink.ensureEOSSupport();
 
         // TODO: Verify that the group_id does not have existing state stored on Kafka unless
@@ -2732,7 +2818,19 @@ public class KafkaIO {
 
         input.apply(new KafkaExactlyOnceSink<>(this));
       } else {
-        input.apply(ParDo.of(new KafkaWriter<>(this)));
+        // Even though the errors are the only output from writing to Kafka, we maintain a
+        // PCollectionTuple with a void tag as the 'primary' output for easy forward
+        // compatibility.
+        PCollectionTuple pCollectionTuple =
+            input.apply(
+                ParDo.of(new KafkaWriter<>(this))
+                    .withOutputTags(
+                        new TupleTag<Void>(), TupleTagList.of(BadRecordRouter.BAD_RECORD_TAG)));
+        getBadRecordErrorHandler()
+            .addErrorCollection(
+                pCollectionTuple
+                    .get(BadRecordRouter.BAD_RECORD_TAG)
+                    .setCoder(BadRecord.getCoder(input.getPipeline())));
       }
       return PDone.in(input.getPipeline());
     }
@@ -2995,6 +3093,15 @@ public class KafkaIO {
           getWriteRecordsTransform().withProducerConfigUpdates(configUpdates));
     }
 
+    /**
+     * Configure a {@link BadRecordErrorHandler} to receive records that fail to serialize when
+     * being sent to Kafka.
+     */
+    public Write<K, V> withBadRecordErrorHandler(ErrorHandler<BadRecord, ?> badRecordErrorHandler) {
+      return withWriteRecordsTransform(
+          getWriteRecordsTransform().withBadRecordErrorHandler(badRecordErrorHandler));
+    }
+
     @Override
     public PDone expand(PCollection<KV<K, V>> input) {
       final String topic = Preconditions.checkStateNotNull(getTopic(), "withTopic() is required");
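
For completeness, a corresponding sketch of the write-side API, modeled on Write#withBadRecordErrorHandler above and the testKafkaIOWriteWithErrorHandler integration test further down. The CountBadRecords sink, broker address, and topic name are illustrative placeholders, not part of this commit:

    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.io.kafka.KafkaIO;
    import org.apache.beam.sdk.transforms.Combine;
    import org.apache.beam.sdk.transforms.Count;
    import org.apache.beam.sdk.transforms.Create;
    import org.apache.beam.sdk.transforms.PTransform;
    import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
    import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.BadRecordErrorHandler;
    import org.apache.beam.sdk.values.KV;
    import org.apache.beam.sdk.values.PCollection;
    import org.apache.kafka.common.serialization.LongSerializer;
    import org.apache.kafka.common.serialization.StringSerializer;

    public class KafkaWriteWithErrorHandler {

      // Placeholder sink: globally counts the bad records instead of persisting them.
      static class CountBadRecords
          extends PTransform<PCollection<BadRecord>, PCollection<Long>> {
        @Override
        public PCollection<Long> expand(PCollection<BadRecord> input) {
          return input.apply(
              "Count bad records",
              Combine.globally(Count.<BadRecord>combineFn()).withoutDefaults());
        }
      }

      public static void main(String[] args) throws Exception {
        Pipeline pipeline = Pipeline.create();

        try (BadRecordErrorHandler<PCollection<Long>> errorHandler =
            pipeline.registerBadRecordErrorHandler(new CountBadRecords())) {
          pipeline
              .apply(Create.of(KV.of("key", 42L)))
              .apply(
                  KafkaIO.<String, Long>write()
                      .withBootstrapServers("localhost:9092") // placeholder broker
                      .withTopic("my-topic") // placeholder topic
                      .withKeySerializer(StringSerializer.class)
                      .withValueSerializer(LongSerializer.class)
                      // Records whose key or value throws a SerializationException in
                      // producer.send() are routed to the handler instead of failing the
                      // bundle. Not supported together with withEOS().
                      .withBadRecordErrorHandler(errorHandler));
        }

        pipeline.run().waitUntilFinish();
      }
    }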
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibility.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibility.java
index b779de1d9cf..a2cc9aaeb4d 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibility.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibility.java
@@ -111,6 +111,7 @@ class KafkaIOReadImplementationCompatibility {
     KEY_DESERIALIZER_PROVIDER,
     VALUE_DESERIALIZER_PROVIDER,
     CHECK_STOP_READING_FN(SDF),
+    BAD_RECORD_ERROR_HANDLER(SDF),
     ;
 
     @Nonnull private final ImmutableSet<KafkaIOReadImplementation> supportedImplementations;
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriter.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriter.java
index c0c9772959f..4f4663aa8cc 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriter.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriter.java
@@ -25,6 +25,7 @@ import org.apache.beam.sdk.io.kafka.KafkaIO.WriteRecords;
 import org.apache.beam.sdk.metrics.Counter;
 import org.apache.beam.sdk.metrics.SinkMetrics;
 import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter;
 import org.apache.beam.sdk.util.Preconditions;
 import org.apache.kafka.clients.producer.Callback;
 import org.apache.kafka.clients.producer.KafkaProducer;
@@ -32,6 +33,7 @@ import org.apache.kafka.clients.producer.Producer;
 import org.apache.kafka.clients.producer.ProducerConfig;
 import org.apache.kafka.clients.producer.ProducerRecord;
 import org.apache.kafka.clients.producer.RecordMetadata;
+import org.apache.kafka.common.errors.SerializationException;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -57,7 +59,7 @@ class KafkaWriter<K, V> extends DoFn<ProducerRecord<K, V>, Void> {
   // Suppression since errors are tracked in SendCallback(), and checked in finishBundle()
   @ProcessElement
   @SuppressWarnings("FutureReturnValueIgnored")
-  public void processElement(ProcessContext ctx) throws Exception {
+  public void processElement(ProcessContext ctx, MultiOutputReceiver receiver) throws Exception {
     Producer<K, V> producer = Preconditions.checkStateNotNull(this.producer);
     checkForFailures();
 
@@ -75,19 +77,31 @@ class KafkaWriter<K, V> extends DoFn<ProducerRecord<K, V>, Void> {
       topicName = spec.getTopic();
     }
 
-    @SuppressWarnings({"nullness", "unused"}) // Kafka library not annotated
-    Future<RecordMetadata> ignored =
-        producer.send(
-            new ProducerRecord<>(
-                topicName,
-                record.partition(),
-                timestampMillis,
-                record.key(),
-                record.value(),
-                record.headers()),
-            callback);
-
-    elementsWritten.inc();
+    try {
+      @SuppressWarnings({"nullness", "unused"}) // Kafka library not annotated
+      Future<RecordMetadata> ignored =
+          producer.send(
+              new ProducerRecord<>(
+                  topicName,
+                  record.partition(),
+                  timestampMillis,
+                  record.key(),
+                  record.value(),
+                  record.headers()),
+              callback);
+
+      elementsWritten.inc();
+    } catch (SerializationException e) {
+      // This exception should only occur during the key and value serialization when
+      // creating the Kafka Record. We can catch the exception here as producer.send serializes
+      // the record before starting the future.
+      badRecordRouter.route(
+          receiver,
+          record,
+          null,
+          e,
+          "Failure serializing Key or Value of Kafka record when writing to Kafka");
+    }
   }
 
   @FinishBundle
@@ -110,6 +124,8 @@ class KafkaWriter<K, V> extends DoFn<ProducerRecord<K, V>, Void> {
   private final WriteRecords<K, V> spec;
   private final Map<String, Object> producerConfig;
 
+  private final BadRecordRouter badRecordRouter;
+
   private transient @Nullable Producer<K, V> producer = null;
   // first exception and number of failures since last invocation of checkForFailures():
   private transient @Nullable Exception sendException = null;
@@ -122,6 +138,8 @@ class KafkaWriter<K, V> extends DoFn<ProducerRecord<K, V>, Void> {
 
     this.producerConfig = new HashMap<>(spec.getProducerConfig());
 
+    this.badRecordRouter = spec.getBadRecordRouter();
+
     if (spec.getKeySerializer() != null) {
       this.producerConfig.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, spec.getKeySerializer());
     }
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java
index 1b6e3addce2..924833290f1 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java
@@ -35,6 +35,7 @@ import org.apache.beam.sdk.metrics.Distribution;
 import org.apache.beam.sdk.metrics.Metrics;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter;
 import org.apache.beam.sdk.transforms.splittabledofn.GrowableOffsetRangeTracker;
 import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator;
 import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker;
@@ -45,6 +46,7 @@ import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators.Monoton
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.Preconditions;
 import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Stopwatch;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier;
@@ -60,6 +62,7 @@ import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.ConsumerRecords;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.SerializationException;
 import org.apache.kafka.common.serialization.Deserializer;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Instant;
@@ -144,29 +147,37 @@ import org.slf4j.LoggerFactory;
 abstract class ReadFromKafkaDoFn<K, V>
     extends DoFn<KafkaSourceDescriptor, KV<KafkaSourceDescriptor, KafkaRecord<K, V>>> {
 
-  static <K, V> ReadFromKafkaDoFn<K, V> create(ReadSourceDescriptors<K, V> transform) {
+  static <K, V> ReadFromKafkaDoFn<K, V> create(
+      ReadSourceDescriptors<K, V> transform,
+      TupleTag<KV<KafkaSourceDescriptor, KafkaRecord<K, V>>> recordTag) {
     if (transform.isBounded()) {
-      return new Bounded<>(transform);
+      return new Bounded<>(transform, recordTag);
     } else {
-      return new Unbounded<>(transform);
+      return new Unbounded<>(transform, recordTag);
     }
   }
 
   @UnboundedPerElement
   private static class Unbounded<K, V> extends ReadFromKafkaDoFn<K, V> {
-    Unbounded(ReadSourceDescriptors<K, V> transform) {
-      super(transform);
+    Unbounded(
+        ReadSourceDescriptors<K, V> transform,
+        TupleTag<KV<KafkaSourceDescriptor, KafkaRecord<K, V>>> recordTag) {
+      super(transform, recordTag);
     }
   }
 
   @BoundedPerElement
   private static class Bounded<K, V> extends ReadFromKafkaDoFn<K, V> {
-    Bounded(ReadSourceDescriptors<K, V> transform) {
-      super(transform);
+    Bounded(
+        ReadSourceDescriptors<K, V> transform,
+        TupleTag<KV<KafkaSourceDescriptor, KafkaRecord<K, V>>> recordTag) {
+      super(transform, recordTag);
     }
   }
 
-  private ReadFromKafkaDoFn(ReadSourceDescriptors<K, V> transform) {
+  private ReadFromKafkaDoFn(
+      ReadSourceDescriptors<K, V> transform,
+      TupleTag<KV<KafkaSourceDescriptor, KafkaRecord<K, V>>> recordTag) {
     this.consumerConfig = transform.getConsumerConfig();
     this.offsetConsumerConfig = transform.getOffsetConsumerConfig();
     this.keyDeserializerProvider =
@@ -178,6 +189,8 @@ abstract class ReadFromKafkaDoFn<K, V>
     this.createWatermarkEstimatorFn = transform.getCreateWatermarkEstimatorFn();
     this.timestampPolicyFactory = transform.getTimestampPolicyFactory();
     this.checkStopReadingFn = transform.getCheckStopReadingFn();
+    this.badRecordRouter = transform.getBadRecordRouter();
+    this.recordTag = recordTag;
   }
 
   private static final Logger LOG = LoggerFactory.getLogger(ReadFromKafkaDoFn.class);
@@ -193,6 +206,10 @@ abstract class ReadFromKafkaDoFn<K, V>
       createWatermarkEstimatorFn;
   private final @Nullable TimestampPolicyFactory<K, V> timestampPolicyFactory;
 
+  private final BadRecordRouter badRecordRouter;
+
+  private final TupleTag<KV<KafkaSourceDescriptor, KafkaRecord<K, V>>> recordTag;
+
   // Valid between bundle start and bundle finish.
   private transient @Nullable Deserializer<K> keyDeserializerInstance = null;
   private transient @Nullable Deserializer<V> valueDeserializerInstance = null;
@@ -361,7 +378,8 @@ abstract class ReadFromKafkaDoFn<K, V>
       @Element KafkaSourceDescriptor kafkaSourceDescriptor,
       RestrictionTracker<OffsetRange, Long> tracker,
       WatermarkEstimator<Instant> watermarkEstimator,
-      OutputReceiver<KV<KafkaSourceDescriptor, KafkaRecord<K, V>>> receiver) {
+      MultiOutputReceiver receiver)
+      throws Exception {
     final LoadingCache<TopicPartition, AverageRecordSize> avgRecordSize =
         Preconditions.checkStateNotNull(this.avgRecordSize);
     final Deserializer<K> keyDeserializerInstance =
@@ -431,36 +449,52 @@ abstract class ReadFromKafkaDoFn<K, V>
           if (!tracker.tryClaim(rawRecord.offset())) {
             return ProcessContinuation.stop();
           }
-          KafkaRecord<K, V> kafkaRecord =
-              new KafkaRecord<>(
-                  rawRecord.topic(),
-                  rawRecord.partition(),
-                  rawRecord.offset(),
-                  ConsumerSpEL.getRecordTimestamp(rawRecord),
-                  ConsumerSpEL.getRecordTimestampType(rawRecord),
-                  ConsumerSpEL.hasHeaders() ? rawRecord.headers() : null,
-                  ConsumerSpEL.deserializeKey(keyDeserializerInstance, rawRecord),
-                  ConsumerSpEL.deserializeValue(valueDeserializerInstance, rawRecord));
-          int recordSize =
-              (rawRecord.key() == null ? 0 : rawRecord.key().length)
-                  + (rawRecord.value() == null ? 0 : rawRecord.value().length);
-          avgRecordSize
-              .getUnchecked(kafkaSourceDescriptor.getTopicPartition())
-              .update(recordSize, rawRecord.offset() - expectedOffset);
-          rawSizes.update(recordSize);
-          expectedOffset = rawRecord.offset() + 1;
-          Instant outputTimestamp;
-          // The outputTimestamp and watermark will be computed by timestampPolicy, where the
-          // WatermarkEstimator should be a manual one.
-          if (timestampPolicy != null) {
-            TimestampPolicyContext context =
-                updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker);
-            outputTimestamp = timestampPolicy.getTimestampForRecord(context, kafkaRecord);
-          } else {
-            Preconditions.checkStateNotNull(this.extractOutputTimestampFn);
-            outputTimestamp = extractOutputTimestampFn.apply(kafkaRecord);
+          try {
+            KafkaRecord<K, V> kafkaRecord =
+                new KafkaRecord<>(
+                    rawRecord.topic(),
+                    rawRecord.partition(),
+                    rawRecord.offset(),
+                    ConsumerSpEL.getRecordTimestamp(rawRecord),
+                    ConsumerSpEL.getRecordTimestampType(rawRecord),
+                    ConsumerSpEL.hasHeaders() ? rawRecord.headers() : null,
+                    ConsumerSpEL.deserializeKey(keyDeserializerInstance, rawRecord),
+                    ConsumerSpEL.deserializeValue(valueDeserializerInstance, rawRecord));
+            int recordSize =
+                (rawRecord.key() == null ? 0 : rawRecord.key().length)
+                    + (rawRecord.value() == null ? 0 : rawRecord.value().length);
+            avgRecordSize
+                .getUnchecked(kafkaSourceDescriptor.getTopicPartition())
+                .update(recordSize, rawRecord.offset() - expectedOffset);
+            rawSizes.update(recordSize);
+            expectedOffset = rawRecord.offset() + 1;
+            Instant outputTimestamp;
+            // The outputTimestamp and watermark will be computed by timestampPolicy, where the
+            // WatermarkEstimator should be a manual one.
+            if (timestampPolicy != null) {
+              TimestampPolicyContext context =
+                  updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker);
+              outputTimestamp = timestampPolicy.getTimestampForRecord(context, kafkaRecord);
+            } else {
+              Preconditions.checkStateNotNull(this.extractOutputTimestampFn);
+              outputTimestamp = extractOutputTimestampFn.apply(kafkaRecord);
+            }
+            receiver
+                .get(recordTag)
+                .outputWithTimestamp(KV.of(kafkaSourceDescriptor, kafkaRecord), outputTimestamp);
+          } catch (SerializationException e) {
+            // This exception should only occur during the key and value deserialization when
+            // creating the Kafka Record
+            badRecordRouter.route(
+                receiver,
+                rawRecord,
+                null,
+                e,
+                "Failure deserializing Key or Value of Kafka record reading from Kafka");
+            if (timestampPolicy != null) {
+              updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker);
+            }
           }
-          receiver.outputWithTimestamp(KV.of(kafkaSourceDescriptor, kafkaRecord), outputTimestamp);
         }
       }
     }
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
index 2ccf7dcc3a9..38bf723a15a 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
@@ -350,13 +350,7 @@ public class KafkaIOExternalTest {
     RunnerApi.PTransform writeComposite =
         result.getComponents().getTransformsOrThrow(transform.getSubtransforms(1));
     RunnerApi.PTransform writeParDo =
-        result
-            .getComponents()
-            .getTransformsOrThrow(
-                result
-                    .getComponents()
-                    .getTransformsOrThrow(writeComposite.getSubtransforms(0))
-                    .getSubtransforms(0));
+        result.getComponents().getTransformsOrThrow(writeComposite.getSubtransforms(0));
 
     RunnerApi.ParDoPayload parDoPayload =
         RunnerApi.ParDoPayload.parseFrom(writeParDo.getSpec().getPayload());
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java
index 2c8ace9c66c..5b976687f2c 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java
@@ -29,6 +29,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Random;
 import java.util.Set;
 import java.util.UUID;
@@ -43,6 +44,9 @@ import org.apache.beam.sdk.io.GenerateSequence;
 import org.apache.beam.sdk.io.Read;
 import org.apache.beam.sdk.io.common.IOITHelper;
 import org.apache.beam.sdk.io.common.IOTestPipelineOptions;
+import org.apache.beam.sdk.io.kafka.KafkaIOTest.ErrorSinkTransform;
+import org.apache.beam.sdk.io.kafka.KafkaIOTest.FailingLongSerializer;
+import org.apache.beam.sdk.io.kafka.ReadFromKafkaDoFnTest.FailingDeserializer;
 import org.apache.beam.sdk.io.synthetic.SyntheticBoundedSource;
 import org.apache.beam.sdk.io.synthetic.SyntheticSourceOptions;
 import org.apache.beam.sdk.options.Default;
@@ -72,6 +76,7 @@ import org.apache.beam.sdk.transforms.Keys;
 import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Values;
+import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.BadRecordErrorHandler;
 import org.apache.beam.sdk.transforms.windowing.CalendarWindows;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
@@ -124,8 +129,6 @@ public class KafkaIOIT {
 
   private static final String RUN_TIME_METRIC_NAME = "run_time";
 
-  private static final String READ_ELEMENT_METRIC_NAME = "kafka_read_element_count";
-
   private static final String NAMESPACE = KafkaIOIT.class.getName();
 
   private static final String TEST_ID = UUID.randomUUID().toString();
@@ -352,6 +355,68 @@ public class KafkaIOIT {
     }
   }
 
+  // This test verifies that bad data from Kafka is properly sent to the error handler
+  @Test
+  public void testKafkaIOSDFReadWithErrorHandler() throws IOException {
+    writePipeline
+        .apply(Create.of(KV.of("key", "val")))
+        .apply(
+            "Write to Kafka",
+            KafkaIO.<String, String>write()
+                .withBootstrapServers(options.getKafkaBootstrapServerAddresses())
+                .withKeySerializer(StringSerializer.class)
+                .withValueSerializer(StringSerializer.class)
+                .withTopic(options.getKafkaTopic() + "-failingDeserialization"));
+
+    PipelineResult writeResult = writePipeline.run();
+    PipelineResult.State writeState = writeResult.waitUntilFinish();
+    assertNotEquals(PipelineResult.State.FAILED, writeState);
+
+    BadRecordErrorHandler<PCollection<Long>> eh =
+        sdfReadPipeline.registerBadRecordErrorHandler(new ErrorSinkTransform());
+    sdfReadPipeline.apply(
+        KafkaIO.<String, String>read()
+            .withBootstrapServers(options.getKafkaBootstrapServerAddresses())
+            .withTopic(options.getKafkaTopic() + "-failingDeserialization")
+            .withConsumerConfigUpdates(ImmutableMap.of("auto.offset.reset", "earliest"))
+            .withKeyDeserializer(FailingDeserializer.class)
+            .withValueDeserializer(FailingDeserializer.class)
+            .withBadRecordErrorHandler(eh));
+    eh.close();
+
+    PAssert.thatSingleton(Objects.requireNonNull(eh.getOutput())).isEqualTo(1L);
+
+    PipelineResult readResult = sdfReadPipeline.run();
+    PipelineResult.State readState =
+        readResult.waitUntilFinish(Duration.standardSeconds(options.getReadTimeout()));
+    cancelIfTimeouted(readResult, readState);
+    assertNotEquals(PipelineResult.State.FAILED, readState);
+  }
+
+  @Test
+  public void testKafkaIOWriteWithErrorHandler() throws IOException {
+
+    BadRecordErrorHandler<PCollection<Long>> eh =
+        writePipeline.registerBadRecordErrorHandler(new ErrorSinkTransform());
+    writePipeline
+        .apply("Create single KV", Create.of(KV.of("key", 4L)))
+        .apply(
+            "Write to Kafka",
+            KafkaIO.<String, Long>write()
+                .withBootstrapServers(options.getKafkaBootstrapServerAddresses())
+                .withKeySerializer(StringSerializer.class)
+                .withValueSerializer(FailingLongSerializer.class)
+                .withTopic(options.getKafkaTopic() + "-failingSerialization")
+                .withBadRecordErrorHandler(eh));
+    eh.close();
+
+    PAssert.thatSingleton(Objects.requireNonNull(eh.getOutput())).isEqualTo(1L);
+
+    PipelineResult writeResult = writePipeline.run();
+    PipelineResult.State writeState = writeResult.waitUntilFinish();
+    assertNotEquals(PipelineResult.State.FAILED, writeState);
+  }
+
   // This test roundtrips a single KV<Null,Null> to verify that externalWithMetadata
   // can handle null keys and values correctly.
   @Test
@@ -484,9 +549,7 @@ public class KafkaIOIT {
   public void testKafkaWithStopReadingFunction() {
     AlwaysStopCheckStopReadingFn checkStopReadingFn = new AlwaysStopCheckStopReadingFn();
 
-    PipelineResult readResult = runWithStopReadingFn(checkStopReadingFn, "stop-reading");
-
-    assertEquals(-1, readElementMetric(readResult, NAMESPACE, READ_ELEMENT_METRIC_NAME));
+    runWithStopReadingFn(checkStopReadingFn, "stop-reading", 0L);
   }
 
   private static class AlwaysStopCheckStopReadingFn implements CheckStopReadingFn {
@@ -500,11 +563,7 @@ public class KafkaIOIT {
   public void testKafkaWithDelayedStopReadingFunction() {
     DelayedCheckStopReadingFn checkStopReadingFn = new DelayedCheckStopReadingFn();
 
-    PipelineResult readResult = runWithStopReadingFn(checkStopReadingFn, "delayed-stop-reading");
-
-    assertEquals(
-        sourceOptions.numRecords,
-        readElementMetric(readResult, NAMESPACE, READ_ELEMENT_METRIC_NAME));
+    runWithStopReadingFn(checkStopReadingFn, "delayed-stop-reading", sourceOptions.numRecords);
   }
 
   public static final Schema KAFKA_TOPIC_SCHEMA =
@@ -644,7 +703,7 @@ public class KafkaIOIT {
 
     @Override
     public Boolean apply(TopicPartition input) {
-      if (checkCount >= 5) {
+      if (checkCount >= 10) {
         return true;
       }
       checkCount++;
@@ -652,7 +711,8 @@ public class KafkaIOIT {
     }
   }
 
-  private PipelineResult runWithStopReadingFn(CheckStopReadingFn function, String topicSuffix) {
+  private void runWithStopReadingFn(
+      CheckStopReadingFn function, String topicSuffix, Long expectedCount) {
     writePipeline
         .apply("Generate records", Read.from(new SyntheticBoundedSource(sourceOptions)))
         .apply("Measure write time", ParDo.of(new TimeMonitor<>(NAMESPACE, WRITE_TIME_METRIC_NAME)))
@@ -661,21 +721,31 @@ public class KafkaIOIT {
             writeToKafka().withTopic(options.getKafkaTopic() + "-" + topicSuffix));
 
     readPipeline.getOptions().as(Options.class).setStreaming(true);
-    readPipeline
-        .apply(
-            "Read from unbounded Kafka",
-            readFromKafka()
-                .withTopic(options.getKafkaTopic() + "-" + topicSuffix)
-                .withCheckStopReadingFn(function))
-        .apply("Measure read time", ParDo.of(new TimeMonitor<>(NAMESPACE, READ_TIME_METRIC_NAME)));
+    PCollection<Long> count =
+        readPipeline
+            .apply(
+                "Read from unbounded Kafka",
+                readFromKafka()
+                    .withTopic(options.getKafkaTopic() + "-" + topicSuffix)
+                    .withCheckStopReadingFn(function))
+            .apply(
+                "Measure read time", ParDo.of(new TimeMonitor<>(NAMESPACE, READ_TIME_METRIC_NAME)))
+            .apply("Window", Window.into(CalendarWindows.years(1)))
+            .apply(
+                "Counting element",
+                Combine.globally(Count.<KafkaRecord<byte[], byte[]>>combineFn()).withoutDefaults());
+
+    if (expectedCount == 0L) {
+      PAssert.that(count).empty();
+    } else {
+      PAssert.thatSingleton(count).isEqualTo(expectedCount);
+    }
 
     PipelineResult writeResult = writePipeline.run();
     writeResult.waitUntilFinish();
 
     PipelineResult readResult = readPipeline.run();
     readResult.waitUntilFinish(Duration.standardSeconds(options.getReadTimeout()));
-
-    return readResult;
   }
 
   @Test
@@ -686,7 +756,7 @@ public class KafkaIOIT {
 
     String topicName = "SparseDataTopicPartition-" + UUID.randomUUID();
     Map<Integer, String> records = new HashMap<>();
-    for (int i = 0; i < 5; i++) {
+    for (int i = 1; i <= 5; i++) {
       records.put(i, String.valueOf(i));
     }
 
@@ -725,7 +795,7 @@ public class KafkaIOIT {
 
       PipelineResult readResult = sdfReadPipeline.run();
 
-      Thread.sleep(options.getReadTimeout() * 1000);
+      Thread.sleep(options.getReadTimeout() * 1000 * 2);
 
       for (String value : records.values()) {
         kafkaIOITExpectedLogs.verifyError(value);
@@ -753,11 +823,6 @@ public class KafkaIOIT {
     }
   }
 
-  private long readElementMetric(PipelineResult result, String namespace, String name) {
-    MetricsReader metricsReader = new MetricsReader(result, namespace);
-    return metricsReader.getCounterMetric(name);
-  }
-
   private Set<NamedTestResult> readMetrics(PipelineResult writeResult, PipelineResult readResult) {
     BiFunction<MetricsReader, String, NamedTestResult> supplier =
         (reader, metricName) -> {
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
index aeb5818e913..b0df82bcdc1 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
@@ -51,6 +51,7 @@ import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Optional;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
@@ -87,6 +88,7 @@ import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.testing.ExpectedLogs;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.Distinct;
@@ -95,11 +97,15 @@ import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.Max;
 import org.apache.beam.sdk.transforms.Min;
+import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.Values;
 import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
+import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.BadRecordErrorHandler;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.CalendarWindows;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.util.CoderUtils;
@@ -121,9 +127,12 @@ import org.apache.kafka.clients.producer.MockProducer;
 import org.apache.kafka.clients.producer.Producer;
 import org.apache.kafka.clients.producer.ProducerConfig;
 import org.apache.kafka.clients.producer.ProducerRecord;
+import org.apache.kafka.clients.producer.internals.DefaultPartitioner;
+import org.apache.kafka.common.Cluster;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.SerializationException;
 import org.apache.kafka.common.header.Header;
 import org.apache.kafka.common.header.Headers;
 import org.apache.kafka.common.header.internals.RecordHeader;
@@ -136,7 +145,10 @@ import org.apache.kafka.common.serialization.LongDeserializer;
 import org.apache.kafka.common.serialization.LongSerializer;
 import org.apache.kafka.common.serialization.Serializer;
 import org.apache.kafka.common.utils.Utils;
+import org.checkerframework.checker.initialization.qual.Initialized;
+import org.checkerframework.checker.nullness.qual.NonNull;
 import org.checkerframework.checker.nullness.qual.Nullable;
+import org.checkerframework.checker.nullness.qual.UnknownKeyFor;
 import org.hamcrest.collection.IsIterableContainingInAnyOrder;
 import org.hamcrest.collection.IsIterableWithSize;
 import org.joda.time.Duration;
@@ -1379,7 +1391,7 @@ public class KafkaIOTest {
 
     int numElements = 1000;
 
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
 
       ProducerSendCompletionThread completionThread =
           new ProducerSendCompletionThread(producerWrapper.mockProducer).start();
@@ -1404,13 +1416,81 @@ public class KafkaIOTest {
     }
   }
 
+  public static class FailingLongSerializer implements Serializer<Long> {
+    // enables instantiation by registries
+    public FailingLongSerializer() {}
+
+    @Override
+    public byte[] serialize(String topic, Long data) {
+      throw new SerializationException("ExpectedSerializationException");
+    }
+
+    @Override
+    public void configure(Map<String, ?> configs, boolean isKey) {
+      // intentionally left blank for compatibility with older kafka versions
+    }
+  }
+
+  @Test
+  public void testSinkWithSerializationErrors() throws Exception {
+    // Attempt to write 10 elements to Kafka; they will all fail to serialize and be sent to
+    // the DLQ.
+
+    int numElements = 10;
+
+    try (MockProducerWrapper producerWrapper =
+        new MockProducerWrapper(new FailingLongSerializer())) {
+
+      ProducerSendCompletionThread completionThread =
+          new ProducerSendCompletionThread(producerWrapper.mockProducer).start();
+
+      String topic = "test";
+
+      BadRecordErrorHandler<PCollection<Long>> eh =
+          p.registerBadRecordErrorHandler(new ErrorSinkTransform());
+
+      p.apply(mkKafkaReadTransform(numElements, new ValueAsTimestampFn()).withoutMetadata())
+          .apply(
+              KafkaIO.<Integer, Long>write()
+                  .withBootstrapServers("none")
+                  .withTopic(topic)
+                  .withKeySerializer(IntegerSerializer.class)
+                  .withValueSerializer(FailingLongSerializer.class)
+                  .withInputTimestamp()
+                  .withProducerFactoryFn(new ProducerFactoryFn(producerWrapper.producerKey))
+                  .withBadRecordErrorHandler(eh));
+
+      eh.close();
+
+      PAssert.thatSingleton(Objects.requireNonNull(eh.getOutput())).isEqualTo(10L);
+
+      p.run();
+
+      completionThread.shutdown();
+
+      verifyProducerRecords(producerWrapper.mockProducer, topic, 0, false, true);
+    }
+  }
+
+  public static class ErrorSinkTransform
+      extends PTransform<PCollection<BadRecord>, PCollection<Long>> {
+
+    @Override
+    public @UnknownKeyFor @NonNull @Initialized PCollection<Long> expand(
+        PCollection<BadRecord> input) {
+      return input
+          .apply("Window", Window.into(CalendarWindows.years(1)))
+          .apply("Combine", Combine.globally(Count.<BadRecord>combineFn()).withoutDefaults());
+    }
+  }
+
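The calendar-year window plus withoutDefaults() above is what lets this sink consume an unbounded (and possibly empty) stream of bad records: Combine.globally can only emit a default value under the global window, so the transform windows first and then counts without a default. A hypothetical bounded-only counterpart, shown purely for comparison and not part of the commit, could skip both:

  // Hypothetical bounded-only variant: under the global window,
  // Count.globally() emits 0 for empty input, so no explicit windowing or
  // withoutDefaults() is needed.
  public static class BoundedErrorCountTransform
      extends PTransform<PCollection<BadRecord>, PCollection<Long>> {

    @Override
    public PCollection<Long> expand(PCollection<BadRecord> input) {
      return input.apply("Count bad records", Count.<BadRecord>globally());
    }
  }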
   @Test
   public void testValuesSink() throws Exception {
     // similar to testSink(), but use values()' interface.
 
     int numElements = 1000;
 
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
 
       ProducerSendCompletionThread completionThread =
           new ProducerSendCompletionThread(producerWrapper.mockProducer).start();
@@ -1442,7 +1522,7 @@ public class KafkaIOTest {
 
     int numElements = 1000;
 
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
 
       ProducerSendCompletionThread completionThread =
           new ProducerSendCompletionThread(producerWrapper.mockProducer).start();
@@ -1474,7 +1554,7 @@ public class KafkaIOTest {
     // Set different output topic names
     int numElements = 1000;
 
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
 
       ProducerSendCompletionThread completionThread =
           new ProducerSendCompletionThread(producerWrapper.mockProducer).start();
@@ -1519,7 +1599,7 @@ public class KafkaIOTest {
     // Set different output topic names
     int numElements = 1;
     SimpleEntry<String, String> header = new SimpleEntry<>("header_key", "header_value");
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
 
       ProducerSendCompletionThread completionThread =
           new ProducerSendCompletionThread(producerWrapper.mockProducer).start();
@@ -1562,7 +1642,7 @@ public class KafkaIOTest {
   public void testSinkProducerRecordsWithCustomTS() throws Exception {
     int numElements = 1000;
 
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
 
       ProducerSendCompletionThread completionThread =
           new ProducerSendCompletionThread(producerWrapper.mockProducer).start();
@@ -1601,7 +1681,7 @@ public class KafkaIOTest {
   public void testSinkProducerRecordsWithCustomPartition() throws Exception {
     int numElements = 1000;
 
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
 
       ProducerSendCompletionThread completionThread =
           new ProducerSendCompletionThread(producerWrapper.mockProducer).start();
@@ -1725,7 +1805,7 @@ public class KafkaIOTest {
 
     int numElements = 1000;
 
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
 
       ProducerSendCompletionThread completionThread =
           new ProducerSendCompletionThread(producerWrapper.mockProducer).start();
@@ -1803,7 +1883,7 @@ public class KafkaIOTest {
 
     int numElements = 1000;
 
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
 
       ProducerSendCompletionThread completionThreadWithErrors =
           new ProducerSendCompletionThread(producerWrapper.mockProducer, 10, 100).start();
@@ -1993,7 +2073,7 @@ public class KafkaIOTest {
 
   @Test
   public void testSinkDisplayData() {
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
       KafkaIO.Write<Integer, Long> write =
           KafkaIO.<Integer, Long>write()
               .withBootstrapServers("myServerA:9092,myServerB:9092")
@@ -2017,7 +2097,7 @@ public class KafkaIOTest {
 
     int numElements = 1000;
 
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
 
       ProducerSendCompletionThread completionThread =
           new ProducerSendCompletionThread(producerWrapper.mockProducer).start();
@@ -2109,14 +2189,22 @@ public class KafkaIOTest {
       }
     }
 
-    MockProducerWrapper() {
+    MockProducerWrapper(Serializer<Long> valueSerializer) {
       producerKey = String.valueOf(ThreadLocalRandom.current().nextLong());
       mockProducer =
           new MockProducer<Integer, Long>(
+              Cluster.empty()
+                  .withPartitions(
+                      ImmutableMap.of(
+                          new TopicPartition("test", 0),
+                          new PartitionInfo("test", 0, null, null, null),
+                          new TopicPartition("test", 1),
+                          new PartitionInfo("test", 1, null, null, null))),
               false, // disable synchronous completion of send. see ProducerSendCompletionThread
               // below.
+              new DefaultPartitioner(),
               new IntegerSerializer(),
-              new LongSerializer()) {
+              valueSerializer) {
 
             // override flush() so that it does not complete all the waiting sends, giving a chance
             // to
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java
index 554c6d2fcaf..48b5b060a29 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.io.kafka;
 
+import static org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.BAD_RECORD_TAG;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 
@@ -41,15 +42,20 @@ import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver;
 import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
 import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
+import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.DefaultErrorHandler;
 import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollection.IsBounded;
 import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TypeDescriptor;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Charsets;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
@@ -64,7 +70,9 @@ import org.apache.kafka.clients.consumer.OffsetResetStrategy;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.SerializationException;
 import org.apache.kafka.common.header.internals.RecordHeaders;
+import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.StringDeserializer;
 import org.checkerframework.checker.initialization.qual.Initialized;
 import org.checkerframework.checker.nullness.qual.NonNull;
@@ -80,19 +88,22 @@ public class ReadFromKafkaDoFnTest {
 
   private final TopicPartition topicPartition = new TopicPartition("topic", 0);
 
+  private static final TupleTag<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> RECORDS =
+      new TupleTag<>();
+
   @Rule public ExpectedException thrown = ExpectedException.none();
 
   private final SimpleMockKafkaConsumer consumer =
       new SimpleMockKafkaConsumer(OffsetResetStrategy.NONE, topicPartition);
 
   private final ReadFromKafkaDoFn<String, String> dofnInstance =
-      ReadFromKafkaDoFn.create(makeReadSourceDescriptor(consumer));
+      ReadFromKafkaDoFn.create(makeReadSourceDescriptor(consumer), RECORDS);
 
   private final ExceptionMockKafkaConsumer exceptionConsumer =
       new ExceptionMockKafkaConsumer(OffsetResetStrategy.NONE, topicPartition);
 
   private final ReadFromKafkaDoFn<String, String> exceptionDofnInstance =
-      ReadFromKafkaDoFn.create(makeReadSourceDescriptor(exceptionConsumer));
+      ReadFromKafkaDoFn.create(makeReadSourceDescriptor(exceptionConsumer), RECORDS);
 
   private ReadSourceDescriptors<String, String> makeReadSourceDescriptor(
       Consumer<byte[], byte[]> kafkaMockConsumer) {
@@ -109,6 +120,31 @@ public class ReadFromKafkaDoFnTest {
         .withBootstrapServers("bootstrap_server");
   }
 
+  private ReadSourceDescriptors<String, String> makeFailingReadSourceDescriptor(
+      Consumer<byte[], byte[]> kafkaMockConsumer) {
+    return ReadSourceDescriptors.<String, String>read()
+        .withKeyDeserializer(FailingDeserializer.class)
+        .withValueDeserializer(FailingDeserializer.class)
+        .withConsumerFactoryFn(
+            new SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>() {
+              @Override
+              public Consumer<byte[], byte[]> apply(Map<String, Object> input) {
+                return kafkaMockConsumer;
+              }
+            })
+        .withBootstrapServers("bootstrap_server");
+  }
+
+  public static class FailingDeserializer implements Deserializer<String> {
+
+    public FailingDeserializer() {}
+
+    @Override
+    public String deserialize(String topic, byte[] data) {
+      throw new SerializationException("Intentional serialization exception");
+    }
+  }
+
   private static class ExceptionMockKafkaConsumer extends MockConsumer<byte[], byte[]> {
 
     private final TopicPartition topicPartition;
@@ -254,23 +290,57 @@ public class ReadFromKafkaDoFnTest {
     }
   }
 
-  private static class MockOutputReceiver
-      implements OutputReceiver<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> {
+  private static class MockMultiOutputReceiver implements MultiOutputReceiver {
+
+    MockOutputReceiver<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> mockOutputReceiver =
+        new MockOutputReceiver<>();
+
+    MockOutputReceiver<BadRecord> badOutputReceiver = new MockOutputReceiver<>();
+
+    @Override
+    public @UnknownKeyFor @NonNull @Initialized <T> OutputReceiver<T> get(
+        @UnknownKeyFor @NonNull @Initialized TupleTag<T> tag) {
+      if (RECORDS.equals(tag)) {
+        return (OutputReceiver<T>) mockOutputReceiver;
+      } else if (BAD_RECORD_TAG.equals(tag)) {
+        return (OutputReceiver<T>) badOutputReceiver;
+      } else {
+        throw new RuntimeException("Invalid Tag");
+      }
+    }
+
+    public List<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> getGoodRecords() {
+      return mockOutputReceiver.getOutputs();
+    }
 
-    private final List<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> records =
-        new ArrayList<>();
+    public List<BadRecord> getBadRecords() {
+      return badOutputReceiver.getOutputs();
+    }
 
     @Override
-    public void output(KV<KafkaSourceDescriptor, KafkaRecord<String, String>> output) {}
+    public @UnknownKeyFor @NonNull @Initialized <T>
+        OutputReceiver<@UnknownKeyFor @NonNull @Initialized Row> getRowReceiver(
+            @UnknownKeyFor @NonNull @Initialized TupleTag<T> tag) {
+      return null;
+    }
+  }
+
+  private static class MockOutputReceiver<T> implements OutputReceiver<T> {
+
+    private final List<T> records = new ArrayList<>();
+
+    @Override
+    public void output(T output) {
+      records.add(output);
+    }
 
     @Override
     public void outputWithTimestamp(
-        KV<KafkaSourceDescriptor, KafkaRecord<String, String>> output,
-        @UnknownKeyFor @NonNull @Initialized Instant timestamp) {
+        T output, @UnknownKeyFor @NonNull @Initialized Instant timestamp) {
       records.add(output);
     }
 
-    public List<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> getOutputs() {
+    public List<T> getOutputs() {
       return this.records;
     }
   }
@@ -386,7 +456,7 @@ public class ReadFromKafkaDoFnTest {
 
   @Test
   public void testProcessElement() throws Exception {
-    MockOutputReceiver receiver = new MockOutputReceiver();
+    MockMultiOutputReceiver receiver = new MockMultiOutputReceiver();
     consumer.setNumOfRecordsPerPoll(3L);
     long startOffset = 5L;
     OffsetRangeTracker tracker =
@@ -396,7 +466,8 @@ public class ReadFromKafkaDoFnTest {
     ProcessContinuation result = dofnInstance.processElement(descriptor, tracker, null, receiver);
     assertEquals(ProcessContinuation.stop(), result);
     assertEquals(
-        createExpectedRecords(descriptor, startOffset, 3, "key", "value"), receiver.getOutputs());
+        createExpectedRecords(descriptor, startOffset, 3, "key", "value"),
+        receiver.getGoodRecords());
   }
 
   @Test
@@ -406,7 +477,7 @@ public class ReadFromKafkaDoFnTest {
     MetricsContainerImpl container = new MetricsContainerImpl("any");
     MetricsEnvironment.setCurrentContainer(container);
 
-    MockOutputReceiver receiver = new MockOutputReceiver();
+    MockMultiOutputReceiver receiver = new MockMultiOutputReceiver();
     consumer.setNumOfRecordsPerPoll(numElements);
     OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0, numElements));
     KafkaSourceDescriptor descriptor =
@@ -427,7 +498,7 @@ public class ReadFromKafkaDoFnTest {
 
   @Test
   public void testProcessElementWithEmptyPoll() throws Exception {
-    MockOutputReceiver receiver = new MockOutputReceiver();
+    MockMultiOutputReceiver receiver = new MockMultiOutputReceiver();
     consumer.setNumOfRecordsPerPoll(-1);
     OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, Long.MAX_VALUE));
     ProcessContinuation result =
@@ -437,12 +508,12 @@ public class ReadFromKafkaDoFnTest {
             null,
             receiver);
     assertEquals(ProcessContinuation.resume(), result);
-    assertTrue(receiver.getOutputs().isEmpty());
+    assertTrue(receiver.getGoodRecords().isEmpty());
   }
 
   @Test
   public void testProcessElementWhenTopicPartitionIsRemoved() throws Exception {
-    MockOutputReceiver receiver = new MockOutputReceiver();
+    MockMultiOutputReceiver receiver = new MockMultiOutputReceiver();
     consumer.setRemoved();
     consumer.setNumOfRecordsPerPoll(10);
     OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, Long.MAX_VALUE));
@@ -457,7 +528,7 @@ public class ReadFromKafkaDoFnTest {
 
   @Test
   public void testProcessElementWhenTopicPartitionIsStopped() throws Exception {
-    MockOutputReceiver receiver = new MockOutputReceiver();
+    MockMultiOutputReceiver receiver = new MockMultiOutputReceiver();
     ReadFromKafkaDoFn<String, String> instance =
         ReadFromKafkaDoFn.create(
             makeReadSourceDescriptor(consumer)
@@ -470,7 +541,8 @@ public class ReadFromKafkaDoFnTest {
                         return true;
                       }
                     })
-                .build());
+                .build(),
+            RECORDS);
     instance.setup();
     consumer.setNumOfRecordsPerPoll(10);
     OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, Long.MAX_VALUE));
@@ -489,7 +561,7 @@ public class ReadFromKafkaDoFnTest {
     thrown.expect(KafkaException.class);
     thrown.expectMessage("SeekException");
 
-    MockOutputReceiver receiver = new MockOutputReceiver();
+    MockMultiOutputReceiver receiver = new MockMultiOutputReceiver();
     OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, Long.MAX_VALUE));
 
     exceptionDofnInstance.processElement(
@@ -499,6 +571,61 @@ public class ReadFromKafkaDoFnTest {
         receiver);
   }
 
+  @Test
+  public void testProcessElementWithDeserializationExceptionDefaultRecordHandler()
+      throws Exception {
+    thrown.expect(SerializationException.class);
+    thrown.expectMessage("Intentional serialization exception");
+
+    MockMultiOutputReceiver receiver = new MockMultiOutputReceiver();
+    OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, Long.MAX_VALUE));
+
+    consumer.setNumOfRecordsPerPoll(1);
+
+    ReadFromKafkaDoFn<String, String> dofnInstance =
+        ReadFromKafkaDoFn.create(makeFailingReadSourceDescriptor(consumer), RECORDS);
+
+    dofnInstance.setup();
+
+    dofnInstance.processElement(
+        KafkaSourceDescriptor.of(topicPartition, null, null, null, null, null),
+        tracker,
+        null,
+        receiver);
+
+    Assert.assertEquals("OutputRecordSize", 0, receiver.getGoodRecords().size());
+    Assert.assertEquals("OutputErrorSize", 0, receiver.getBadRecords().size());
+  }
+
+  @Test
+  public void testProcessElementWithDeserializationExceptionRecordingRecordHandler()
+      throws Exception {
+    MockMultiOutputReceiver receiver = new MockMultiOutputReceiver();
+    OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, 1L));
+
+    consumer.setNumOfRecordsPerPoll(1);
+
+    // Because we never actually execute the pipeline, no data will actually make it to the error
+    // handler. This will just configure the ReadSourceDescriptors to route the errors to the
+    // output PCollection instead of rethrowing.
+    ReadSourceDescriptors<String, String> descriptors =
+        makeFailingReadSourceDescriptor(consumer)
+            .withBadRecordErrorHandler(new DefaultErrorHandler<>());
+
+    ReadFromKafkaDoFn<String, String> dofnInstance = ReadFromKafkaDoFn.create(descriptors, RECORDS);
+
+    dofnInstance.setup();
+
+    dofnInstance.processElement(
+        KafkaSourceDescriptor.of(topicPartition, null, null, null, null, null),
+        tracker,
+        null,
+        receiver);
+
+    Assert.assertEquals("OutputRecordSize", 0, receiver.getGoodRecords().size());
+    Assert.assertEquals("OutputErrorSize", 1, receiver.getBadRecords().size());
+  }
+
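For completeness, a minimal sketch of enabling the same routing from a user pipeline rather than by invoking the DoFn directly. This is not code from the commit: it assumes KafkaIO.Read exposes a withBadRecordErrorHandler option mirroring the ReadSourceDescriptors option used above, reuses the ErrorSinkTransform shown earlier in this diff, and uses placeholder server and topic names.

  // Minimal sketch of read-side dead-letter routing (placeholders throughout).
  static void readWithDeadLetterQueue() throws IOException {
    Pipeline pipeline = Pipeline.create();

    BadRecordErrorHandler<PCollection<Long>> errorHandler =
        pipeline.registerBadRecordErrorHandler(new ErrorSinkTransform());

    pipeline.apply(
        "Read from Kafka",
        KafkaIO.<String, String>read()
            .withBootstrapServers("localhost:9092") // placeholder
            .withTopic("my-topic") // placeholder
            .withKeyDeserializer(StringDeserializer.class)
            .withValueDeserializer(StringDeserializer.class)
            .withBadRecordErrorHandler(errorHandler) // assumed read-side option
            .withoutMetadata());

    // Records whose key or value fails to deserialize are handed to the
    // registered sink instead of failing the read.
    errorHandler.close();

    pipeline.run().waitUntilFinish();
  }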
   private static final TypeDescriptor<KafkaSourceDescriptor>
       KAFKA_SOURCE_DESCRIPTOR_TYPE_DESCRIPTOR = new TypeDescriptor<KafkaSourceDescriptor>() {};
 
@@ -522,7 +649,8 @@ public class ReadFromKafkaDoFnTest {
         .apply(
             ParDo.of(
                 ReadFromKafkaDoFn.<String, String>create(
-                    readSourceDescriptorsDecorator.apply(makeReadSourceDescriptor(consumer)))))
+                    readSourceDescriptorsDecorator.apply(makeReadSourceDescriptor(consumer)),
+                    RECORDS)))
         .setCoder(
             KvCoder.of(
                 SerializableCoder.of(KafkaSourceDescriptor.class),