Posted to commits@beam.apache.org by re...@apache.org on 2022/10/22 17:37:24 UTC

[beam] branch master updated: Merge pull request #23556: Forward failed storage-api row inserts to the failedStorageApiInserts PCollection addresses #23628

This is an automated email from the ASF dual-hosted git repository.

reuvenlax pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 8df6f67c65b Merge pull request #23556: Forward failed storage-api row inserts to the failedStorageApiInserts PCollection addresses #23628
8df6f67c65b is described below

commit 8df6f67c65b4888c45c31e088fb463972c4ec76b
Author: Reuven Lax <re...@google.com>
AuthorDate: Sat Oct 22 10:37:18 2022 -0700

    Merge pull request #23556: Forward failed storage-api row inserts to the failedStorageApiInserts PCollection addresses #23628
---
 .../org/apache/beam/gradle/BeamModulePlugin.groovy |   2 +-
 .../beam/sdk/io/gcp/bigquery/BigQueryOptions.java  |   6 +
 .../beam/sdk/io/gcp/bigquery/StorageApiLoads.java  | 100 +++---
 .../StorageApiWriteRecordsInconsistent.java        |  50 +--
 .../bigquery/StorageApiWriteUnshardedRecords.java  | 277 +++++++++++++----
 .../bigquery/StorageApiWritesShardedRecords.java   | 342 ++++++++++++++-------
 .../beam/sdk/io/gcp/testing/BigqueryClient.java    |   4 +-
 .../sdk/io/gcp/testing/FakeDatasetService.java     |  32 +-
 .../sdk/io/gcp/bigquery/BigQueryIOWriteTest.java   |  21 +-
 .../io/gcp/bigquery/BigQueryNestedRecordsIT.java   |   5 +-
 .../gcp/bigquery/StorageApiSinkFailedRowsIT.java   | 266 ++++++++++++++++
 .../gcp/bigquery/TableRowToStorageApiProtoIT.java  |   8 +-
 12 files changed, 865 insertions(+), 248 deletions(-)
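
For users, the net effect of this change is that rows rejected by the Storage Write API no longer fail the bundle; they are forwarded to a failed-rows output of the WriteResult. Below is a minimal sketch of consuming that output, assuming the accessor WriteResult#getFailedStorageApiInserts() backs the failedRowsTag wired up here and that BigQueryStorageApiInsertError exposes getRow()/getErrorMessage(); the table spec, step names, and class name are illustrative, and the target table is assumed to already exist.

    import com.google.api.services.bigquery.model.TableRow;
    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO;
    import org.apache.beam.sdk.io.gcp.bigquery.BigQueryStorageApiInsertError;
    import org.apache.beam.sdk.io.gcp.bigquery.TableRowJsonCoder;
    import org.apache.beam.sdk.io.gcp.bigquery.WriteResult;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;
    import org.apache.beam.sdk.transforms.Create;
    import org.apache.beam.sdk.transforms.MapElements;
    import org.apache.beam.sdk.values.TypeDescriptors;

    public class FailedRowsExample {
      public static void main(String[] args) {
        Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

        WriteResult result =
            p.apply("CreateRows",
                    Create.of(new TableRow().set("name", "example"))
                        .withCoder(TableRowJsonCoder.of()))
                .apply("WriteToBQ",
                    BigQueryIO.writeTableRows()
                        .to("my-project:my_dataset.my_table") // illustrative; table assumed to exist
                        .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_NEVER)
                        .withMethod(BigQueryIO.Write.Method.STORAGE_WRITE_API));

        // Rows rejected by the Storage Write API are forwarded here instead of
        // failing the bundle (assumed accessor backing failedRowsTag).
        result.getFailedStorageApiInserts()
            .apply("FormatFailures",
                MapElements.into(TypeDescriptors.strings())
                    .via((BigQueryStorageApiInsertError err) ->
                        err.getRow() + " failed: " + err.getErrorMessage()));

        p.run();
      }
    }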

diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
index 1f1fe4589ff..7f6ac755d6b 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
@@ -603,7 +603,7 @@ class BeamModulePlugin implements Plugin<Project> {
         google_cloud_pubsub                         : "com.google.cloud:google-cloud-pubsub", // google_cloud_platform_libraries_bom sets version
         google_cloud_pubsublite                     : "com.google.cloud:google-cloud-pubsublite",  // google_cloud_platform_libraries_bom sets version
         // The GCP Libraries BOM dashboard shows the versions set by the BOM:
-        // https://storage.googleapis.com/cloud-opensource-java-dashboard/com.google.cloud/libraries-bom/25.2.0/artifact_details.html
+        // https://storage.googleapis.com/cloud-opensource-java-dashboard/com.google.cloud/libraries-bom/26.1.3/artifact_details.html
         // Update libraries-bom version on sdks/java/container/license_scripts/dep_urls_java.yaml
         google_cloud_platform_libraries_bom         : "com.google.cloud:libraries-bom:26.1.3",
         google_cloud_spanner                        : "com.google.cloud:google-cloud-spanner", // google_cloud_platform_libraries_bom sets version
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryOptions.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryOptions.java
index 953d1237d9c..53cb2713641 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryOptions.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryOptions.java
@@ -150,4 +150,10 @@ public interface BigQueryOptions
   Integer getStorageApiAppendThresholdRecordCount();
 
   void setStorageApiAppendThresholdRecordCount(Integer value);
+
+  @Description("Maximum request size allowed by the storage write API. ")
+  @Default.Long(10 * 1000 * 1000)
+  Long getStorageWriteApiMaxRequestSize();
+
+  void setStorageWriteApiMaxRequestSize(Long value);
 }
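
The BigQueryOptions change above adds a tunable cap on Storage Write API request size, defaulting to 10 * 1000 * 1000 bytes. A small sketch of setting it programmatically, assuming standard PipelineOptionsFactory handling; the 8 MB value and class name are illustrative.

    import org.apache.beam.sdk.io.gcp.bigquery.BigQueryOptions;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;

    public class MaxRequestSizeExample {
      public static void main(String[] args) {
        // Equivalent command-line form: --storageWriteApiMaxRequestSize=8000000
        BigQueryOptions options = PipelineOptionsFactory.fromArgs(args).as(BigQueryOptions.class);
        // Lower the cap from the 10 * 1000 * 1000 byte default added above.
        options.setStorageWriteApiMaxRequestSize(8_000_000L);
        System.out.println("max request size: " + options.getStorageWriteApiMaxRequestSize());
      }
    }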
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java
index e48b9a19690..20ab251c9c0 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java
@@ -24,6 +24,7 @@ import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition;
 import org.apache.beam.sdk.schemas.NoSuchSchemaException;
 import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.GroupIntoBatches;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
@@ -32,6 +33,7 @@ import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.util.ShardedKey;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionList;
 import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.TupleTag;
 import org.joda.time.Duration;
@@ -101,7 +103,7 @@ public class StorageApiLoads<DestinationT, ElementT>
     PCollection<KV<DestinationT, ElementT>> inputInGlobalWindow =
         input.apply("rewindowIntoGlobal", Window.into(new GlobalWindows()));
 
-    PCollectionTuple convertedRecords =
+    PCollectionTuple convertMessagesResult =
         inputInGlobalWindow
             .apply(
                 "CreateTables",
@@ -116,20 +118,23 @@ public class StorageApiLoads<DestinationT, ElementT>
                     successfulRowsTag,
                     BigQueryStorageApiInsertErrorCoder.of(),
                     successCoder));
-    convertedRecords
-        .get(successfulRowsTag)
-        .apply(
-            "StorageApiWriteInconsistent",
-            new StorageApiWriteRecordsInconsistent<>(dynamicDestinations, bqServices));
+    PCollectionTuple writeRecordsResult =
+        convertMessagesResult
+            .get(successfulRowsTag)
+            .apply(
+                "StorageApiWriteInconsistent",
+                new StorageApiWriteRecordsInconsistent<>(
+                    dynamicDestinations,
+                    bqServices,
+                    failedRowsTag,
+                    BigQueryStorageApiInsertErrorCoder.of()));
+
+    PCollection<BigQueryStorageApiInsertError> insertErrors =
+        PCollectionList.of(convertMessagesResult.get(failedRowsTag))
+            .and(writeRecordsResult.get(failedRowsTag))
+            .apply("flattenErrors", Flatten.pCollections());
     return WriteResult.in(
-        input.getPipeline(),
-        null,
-        null,
-        null,
-        null,
-        null,
-        failedRowsTag,
-        convertedRecords.get(failedRowsTag));
+        input.getPipeline(), null, null, null, null, null, failedRowsTag, insertErrors);
   }
 
   public WriteResult expandTriggered(
@@ -139,7 +144,7 @@ public class StorageApiLoads<DestinationT, ElementT>
     // Handle triggered, low-latency loads into BigQuery.
     PCollection<KV<DestinationT, ElementT>> inputInGlobalWindow =
         input.apply("rewindowIntoGlobal", Window.into(new GlobalWindows()));
-    PCollectionTuple result =
+    PCollectionTuple convertMessagesResult =
         inputInGlobalWindow
             .apply(
                 "CreateTables",
@@ -159,7 +164,7 @@ public class StorageApiLoads<DestinationT, ElementT>
 
     if (this.allowAutosharding) {
       groupedRecords =
-          result
+          convertMessagesResult
               .get(successfulRowsTag)
               .apply(
                   "GroupIntoBatches",
@@ -171,7 +176,7 @@ public class StorageApiLoads<DestinationT, ElementT>
 
     } else {
       PCollection<KV<ShardedKey<DestinationT>, StorageApiWritePayload>> shardedRecords =
-          createShardedKeyValuePairs(result)
+          createShardedKeyValuePairs(convertMessagesResult)
               .setCoder(KvCoder.of(ShardedKey.Coder.of(destinationCoder), payloadCoder));
       groupedRecords =
           shardedRecords.apply(
@@ -181,20 +186,25 @@ public class StorageApiLoads<DestinationT, ElementT>
                       (StorageApiWritePayload e) -> (long) e.getPayload().length)
                   .withMaxBufferingDuration(triggeringFrequency));
     }
-    groupedRecords.apply(
-        "StorageApiWriteSharded",
-        new StorageApiWritesShardedRecords<>(
-            dynamicDestinations, createDisposition, kmsKey, bqServices, destinationCoder));
+    PCollectionTuple writeRecordsResult =
+        groupedRecords.apply(
+            "StorageApiWriteSharded",
+            new StorageApiWritesShardedRecords<>(
+                dynamicDestinations,
+                createDisposition,
+                kmsKey,
+                bqServices,
+                destinationCoder,
+                BigQueryStorageApiInsertErrorCoder.of(),
+                failedRowsTag));
+
+    PCollection<BigQueryStorageApiInsertError> insertErrors =
+        PCollectionList.of(convertMessagesResult.get(failedRowsTag))
+            .and(writeRecordsResult.get(failedRowsTag))
+            .apply("flattenErrors", Flatten.pCollections());
 
     return WriteResult.in(
-        input.getPipeline(),
-        null,
-        null,
-        null,
-        null,
-        null,
-        failedRowsTag,
-        result.get(failedRowsTag));
+        input.getPipeline(), null, null, null, null, null, failedRowsTag, insertErrors);
   }
 
   private PCollection<KV<ShardedKey<DestinationT>, StorageApiWritePayload>>
@@ -232,7 +242,7 @@ public class StorageApiLoads<DestinationT, ElementT>
     PCollection<KV<DestinationT, ElementT>> inputInGlobalWindow =
         input.apply(
             "rewindowIntoGlobal", Window.<KV<DestinationT, ElementT>>into(new GlobalWindows()));
-    PCollectionTuple convertedRecords =
+    PCollectionTuple convertMessagesResult =
         inputInGlobalWindow
             .apply(
                 "CreateTables",
@@ -247,20 +257,24 @@ public class StorageApiLoads<DestinationT, ElementT>
                     successfulRowsTag,
                     BigQueryStorageApiInsertErrorCoder.of(),
                     successCoder));
-    convertedRecords
-        .get(successfulRowsTag)
-        .apply(
-            "StorageApiWriteUnsharded",
-            new StorageApiWriteUnshardedRecords<>(dynamicDestinations, bqServices));
+
+    PCollectionTuple writeRecordsResult =
+        convertMessagesResult
+            .get(successfulRowsTag)
+            .apply(
+                "StorageApiWriteUnsharded",
+                new StorageApiWriteUnshardedRecords<>(
+                    dynamicDestinations,
+                    bqServices,
+                    failedRowsTag,
+                    BigQueryStorageApiInsertErrorCoder.of()));
+
+    PCollection<BigQueryStorageApiInsertError> insertErrors =
+        PCollectionList.of(convertMessagesResult.get(failedRowsTag))
+            .and(writeRecordsResult.get(failedRowsTag))
+            .apply("flattenErrors", Flatten.pCollections());
 
     return WriteResult.in(
-        input.getPipeline(),
-        null,
-        null,
-        null,
-        null,
-        null,
-        failedRowsTag,
-        convertedRecords.get(failedRowsTag));
+        input.getPipeline(), null, null, null, null, null, failedRowsTag, insertErrors);
   }
 }
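
In each expand variant, StorageApiLoads now merges two failure sources, conversion failures and write failures, into a single failed-rows output using PCollectionList plus Flatten. A stand-alone sketch of that pattern, with plain strings standing in for BigQueryStorageApiInsertError; class and step names are illustrative.

    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;
    import org.apache.beam.sdk.transforms.Create;
    import org.apache.beam.sdk.transforms.Flatten;
    import org.apache.beam.sdk.values.PCollection;
    import org.apache.beam.sdk.values.PCollectionList;

    public class FlattenErrorsSketch {
      public static void main(String[] args) {
        Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

        // Stand-ins for the two failure sources in StorageApiLoads: rows that failed
        // proto conversion and rows rejected by the Storage API write itself.
        PCollection<String> conversionFailures =
            p.apply("ConversionFailures", Create.of("bad-row-1").withCoder(StringUtf8Coder.of()));
        PCollection<String> writeFailures =
            p.apply("WriteFailures", Create.of("bad-row-2").withCoder(StringUtf8Coder.of()));

        // Same shape as the diff: both branches feed one failed-rows PCollection.
        PCollection<String> allFailures =
            PCollectionList.of(conversionFailures)
                .and(writeFailures)
                .apply("flattenErrors", Flatten.pCollections());

        p.run();
      }
    }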
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteRecordsInconsistent.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteRecordsInconsistent.java
index 35b3ddfd080..190525925ae 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteRecordsInconsistent.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteRecordsInconsistent.java
@@ -17,12 +17,14 @@
  */
 package org.apache.beam.sdk.io.gcp.bigquery;
 
-import org.apache.beam.sdk.coders.VoidCoder;
-import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
 
 /**
  * A transform to write sharded records to BigQuery using the Storage API. This transform uses the
@@ -32,34 +34,46 @@ import org.apache.beam.sdk.values.PCollection;
  */
 @SuppressWarnings("FutureReturnValueIgnored")
 public class StorageApiWriteRecordsInconsistent<DestinationT, ElementT>
-    extends PTransform<PCollection<KV<DestinationT, StorageApiWritePayload>>, PCollection<Void>> {
+    extends PTransform<PCollection<KV<DestinationT, StorageApiWritePayload>>, PCollectionTuple> {
   private final StorageApiDynamicDestinations<ElementT, DestinationT> dynamicDestinations;
   private final BigQueryServices bqServices;
+  private final TupleTag<BigQueryStorageApiInsertError> failedRowsTag;
+  private final TupleTag<KV<String, String>> finalizeTag = new TupleTag<>("finalizeTag");
+  private final Coder<BigQueryStorageApiInsertError> failedRowsCoder;
 
   public StorageApiWriteRecordsInconsistent(
       StorageApiDynamicDestinations<ElementT, DestinationT> dynamicDestinations,
-      BigQueryServices bqServices) {
+      BigQueryServices bqServices,
+      TupleTag<BigQueryStorageApiInsertError> failedRowsTag,
+      Coder<BigQueryStorageApiInsertError> failedRowsCoder) {
     this.dynamicDestinations = dynamicDestinations;
     this.bqServices = bqServices;
+    this.failedRowsTag = failedRowsTag;
+    this.failedRowsCoder = failedRowsCoder;
   }
 
   @Override
-  public PCollection<Void> expand(PCollection<KV<DestinationT, StorageApiWritePayload>> input) {
+  public PCollectionTuple expand(PCollection<KV<DestinationT, StorageApiWritePayload>> input) {
     String operationName = input.getName() + "/" + getName();
     BigQueryOptions bigQueryOptions = input.getPipeline().getOptions().as(BigQueryOptions.class);
     // Append records to the Storage API streams.
-    input.apply(
-        "Write Records",
-        ParDo.of(
-                new StorageApiWriteUnshardedRecords.WriteRecordsDoFn<>(
-                    operationName,
-                    dynamicDestinations,
-                    bqServices,
-                    true,
-                    bigQueryOptions.getStorageApiAppendThresholdBytes(),
-                    bigQueryOptions.getStorageApiAppendThresholdRecordCount(),
-                    bigQueryOptions.getNumStorageWriteApiStreamAppendClients()))
-            .withSideInputs(dynamicDestinations.getSideInputs()));
-    return input.getPipeline().apply("voids", Create.empty(VoidCoder.of()));
+    PCollectionTuple result =
+        input.apply(
+            "Write Records",
+            ParDo.of(
+                    new StorageApiWriteUnshardedRecords.WriteRecordsDoFn<>(
+                        operationName,
+                        dynamicDestinations,
+                        bqServices,
+                        true,
+                        bigQueryOptions.getStorageApiAppendThresholdBytes(),
+                        bigQueryOptions.getStorageApiAppendThresholdRecordCount(),
+                        bigQueryOptions.getNumStorageWriteApiStreamAppendClients(),
+                        finalizeTag,
+                        failedRowsTag))
+                .withOutputTags(finalizeTag, TupleTagList.of(failedRowsTag))
+                .withSideInputs(dynamicDestinations.getSideInputs()));
+    result.get(failedRowsTag).setCoder(failedRowsCoder);
+    return result;
   }
 }
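
The transform above now returns a PCollectionTuple: the finalize output stays on the main tag and failed rows go to a second TupleTag via withOutputTags. A generic sketch of this multi-output ParDo pattern, with hypothetical tags and string elements standing in for the real payload and error types.

    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;
    import org.apache.beam.sdk.transforms.Create;
    import org.apache.beam.sdk.transforms.DoFn;
    import org.apache.beam.sdk.transforms.ParDo;
    import org.apache.beam.sdk.values.PCollection;
    import org.apache.beam.sdk.values.PCollectionTuple;
    import org.apache.beam.sdk.values.TupleTag;
    import org.apache.beam.sdk.values.TupleTagList;

    public class MultiOutputSketch {
      // Anonymous subclasses keep element type information for coder inference.
      static final TupleTag<String> MAIN_TAG = new TupleTag<String>() {};
      static final TupleTag<String> FAILED_TAG = new TupleTag<String>() {};

      public static void main(String[] args) {
        Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

        PCollectionTuple results =
            p.apply(Create.of("ok-1", "bad-2").withCoder(StringUtf8Coder.of()))
                .apply("SplitGoodAndBad",
                    ParDo.of(
                            new DoFn<String, String>() {
                              @ProcessElement
                              public void process(@Element String row, MultiOutputReceiver out) {
                                if (row.startsWith("bad")) {
                                  out.get(FAILED_TAG).output(row); // failed-rows style output
                                } else {
                                  out.get(MAIN_TAG).output(row); // main output
                                }
                              }
                            })
                        .withOutputTags(MAIN_TAG, TupleTagList.of(FAILED_TAG)));

        // Each tagged output is pulled from the tuple and can be given its own coder,
        // mirroring result.get(failedRowsTag).setCoder(failedRowsCoder) above.
        PCollection<String> failed = results.get(FAILED_TAG).setCoder(StringUtf8Coder.of());
        results.get(MAIN_TAG).setCoder(StringUtf8Coder.of());

        p.run();
      }
    }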
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords.java
index 871fc73698a..0f86b8871f0 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords.java
@@ -20,26 +20,31 @@ package org.apache.beam.sdk.io.gcp.bigquery;
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
 
 import com.google.api.core.ApiFuture;
+import com.google.api.core.ApiFutures;
+import com.google.api.services.bigquery.model.TableRow;
 import com.google.cloud.bigquery.storage.v1.AppendRowsResponse;
+import com.google.cloud.bigquery.storage.v1.Exceptions;
 import com.google.cloud.bigquery.storage.v1.ProtoRows;
 import com.google.cloud.bigquery.storage.v1.WriteStream.Type;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.DynamicMessage;
+import com.google.protobuf.InvalidProtocolBufferException;
 import java.io.IOException;
 import java.time.Instant;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.Random;
+import java.util.Set;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
-import java.util.stream.StreamSupport;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.StreamAppendClient;
-import org.apache.beam.sdk.io.gcp.bigquery.RetryManager.Operation.Context;
 import org.apache.beam.sdk.io.gcp.bigquery.RetryManager.RetryType;
 import org.apache.beam.sdk.io.gcp.bigquery.StorageApiDynamicDestinations.DescriptorWrapper;
 import org.apache.beam.sdk.io.gcp.bigquery.StorageApiDynamicDestinations.MessageConverter;
@@ -51,14 +56,18 @@ import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Reshuffle;
-import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.util.Preconditions;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.Cache;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheBuilder;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.RemovalNotification;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
 import org.checkerframework.checker.nullness.qual.NonNull;
@@ -75,11 +84,14 @@ import org.slf4j.LoggerFactory;
  */
 @SuppressWarnings({"FutureReturnValueIgnored"})
 public class StorageApiWriteUnshardedRecords<DestinationT, ElementT>
-    extends PTransform<PCollection<KV<DestinationT, StorageApiWritePayload>>, PCollection<Void>> {
+    extends PTransform<PCollection<KV<DestinationT, StorageApiWritePayload>>, PCollectionTuple> {
   private static final Logger LOG = LoggerFactory.getLogger(StorageApiWriteUnshardedRecords.class);
 
   private final StorageApiDynamicDestinations<ElementT, DestinationT> dynamicDestinations;
   private final BigQueryServices bqServices;
+  private final TupleTag<BigQueryStorageApiInsertError> failedRowsTag;
+  private final TupleTag<KV<String, String>> finalizeTag = new TupleTag<>("finalizeTag");
+  private final Coder<BigQueryStorageApiInsertError> failedRowsCoder;
   private static final ExecutorService closeWriterExecutor = Executors.newCachedThreadPool();
 
   /**
@@ -87,6 +99,8 @@ public class StorageApiWriteUnshardedRecords<DestinationT, ElementT>
    * StreamAppendClient after looking up the cache, and we must ensure that the cache is not
    * accessed in between the lookup and the pin (any access of the cache could trigger element
   * expiration). Therefore most uses of APPEND_CLIENTS should synchronize.
+   *
+   * <p>TODO(reuvenlax); Once all uses of StreamWriter are using
    */
   private static final Cache<String, StreamAppendClient> APPEND_CLIENTS =
       CacheBuilder.newBuilder()
@@ -122,20 +136,24 @@ public class StorageApiWriteUnshardedRecords<DestinationT, ElementT>
 
   public StorageApiWriteUnshardedRecords(
       StorageApiDynamicDestinations<ElementT, DestinationT> dynamicDestinations,
-      BigQueryServices bqServices) {
+      BigQueryServices bqServices,
+      TupleTag<BigQueryStorageApiInsertError> failedRowsTag,
+      Coder<BigQueryStorageApiInsertError> failedRowsCoder) {
     this.dynamicDestinations = dynamicDestinations;
     this.bqServices = bqServices;
+    this.failedRowsTag = failedRowsTag;
+    this.failedRowsCoder = failedRowsCoder;
   }
 
   @Override
-  public PCollection<Void> expand(PCollection<KV<DestinationT, StorageApiWritePayload>> input) {
+  public PCollectionTuple expand(PCollection<KV<DestinationT, StorageApiWritePayload>> input) {
     String operationName = input.getName() + "/" + getName();
     BigQueryOptions options = input.getPipeline().getOptions().as(BigQueryOptions.class);
     org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument(
         !options.getUseStorageApiConnectionPool(),
         "useStorageApiConnectionPool only supported " + "when using STORAGE_API_AT_LEAST_ONCE");
-    return input
-        .apply(
+    PCollectionTuple writeResults =
+        input.apply(
             "Write Records",
             ParDo.of(
                     new WriteRecordsDoFn<>(
@@ -145,19 +163,39 @@ public class StorageApiWriteUnshardedRecords<DestinationT, ElementT>
                         false,
                         options.getStorageApiAppendThresholdBytes(),
                         options.getStorageApiAppendThresholdRecordCount(),
-                        options.getNumStorageWriteApiStreamAppendClients()))
-                .withSideInputs(dynamicDestinations.getSideInputs()))
+                        options.getNumStorageWriteApiStreamAppendClients(),
+                        finalizeTag,
+                        failedRowsTag))
+                .withOutputTags(finalizeTag, TupleTagList.of(failedRowsTag))
+                .withSideInputs(dynamicDestinations.getSideInputs()));
+
+    writeResults
+        .get(finalizeTag)
         .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
         // Calling Reshuffle makes the output stable - once this completes, the append operations
         // will not retry.
         // TODO(reuvenlax): This should use RequiresStableInput instead.
         .apply("Reshuffle", Reshuffle.of())
         .apply("Finalize writes", ParDo.of(new StorageApiFinalizeWritesDoFn(bqServices)));
+    writeResults.get(failedRowsTag).setCoder(failedRowsCoder);
+    return writeResults;
   }
 
   static class WriteRecordsDoFn<DestinationT extends @NonNull Object, ElementT>
       extends DoFn<KV<DestinationT, StorageApiWritePayload>, KV<String, String>> {
     private final Counter forcedFlushes = Metrics.counter(WriteRecordsDoFn.class, "forcedFlushes");
+    private final TupleTag<KV<String, String>> finalizeTag;
+    private final TupleTag<BigQueryStorageApiInsertError> failedRowsTag;
+
+    static class AppendRowsContext extends RetryManager.Operation.Context<AppendRowsResponse> {
+      long offset;
+      ProtoRows protoRows;
+
+      public AppendRowsContext(long offset, ProtoRows protoRows) {
+        this.offset = offset;
+        this.protoRows = protoRows;
+      }
+    }
 
     class DestinationState {
       private final String tableUrn;
@@ -175,11 +213,17 @@ public class StorageApiWriteUnshardedRecords<DestinationT, ElementT>
           Metrics.counter(WriteRecordsDoFn.class, "schemaMismatches");
       private final Distribution inflightWaitSecondsDistribution =
           Metrics.distribution(WriteRecordsDoFn.class, "streamWriterWaitSeconds");
+      private final Counter rowsSentToFailedRowsCollection =
+          Metrics.counter(
+              StorageApiWritesShardedRecords.WriteRecordsDoFn.class,
+              "rowsSentToFailedRowsCollection");
+
       private final boolean useDefaultStream;
       private DescriptorWrapper descriptorWrapper;
       private Instant nextCacheTickle = Instant.MAX;
       private final int clientNumber;
       private final boolean usingMultiplexing;
+      private final long maxRequestSize;
 
       public DestinationState(
           String tableUrn,
@@ -187,7 +231,8 @@ public class StorageApiWriteUnshardedRecords<DestinationT, ElementT>
           DatasetService datasetService,
           boolean useDefaultStream,
           int streamAppendClientCount,
-          BigQueryOptions bigQueryOptions) {
+          boolean usingMultiplexing,
+          long maxRequestSize) {
         this.tableUrn = tableUrn;
         this.messageConverter = messageConverter;
         this.pendingMessages = Lists.newArrayList();
@@ -195,7 +240,8 @@ public class StorageApiWriteUnshardedRecords<DestinationT, ElementT>
         this.useDefaultStream = useDefaultStream;
         this.descriptorWrapper = messageConverter.getSchemaDescriptor();
         this.clientNumber = new Random().nextInt(streamAppendClientCount);
-        this.usingMultiplexing = bigQueryOptions.getUseStorageApiConnectionPool();
+        this.usingMultiplexing = usingMultiplexing;
+        this.maxRequestSize = maxRequestSize;
       }
 
       void teardown() {
@@ -217,7 +263,7 @@ public class StorageApiWriteUnshardedRecords<DestinationT, ElementT>
         return this.streamName;
       }
 
-      String createStreamIfNeeded() {
+      String getOrCreateStreamName() {
         try {
           if (!useDefaultStream) {
             this.streamName =
@@ -242,7 +288,7 @@ public class StorageApiWriteUnshardedRecords<DestinationT, ElementT>
       StreamAppendClient getStreamAppendClient(boolean lookupCache) {
         try {
           if (this.streamAppendClient == null) {
-            createStreamIfNeeded();
+            getOrCreateStreamName();
             final StreamAppendClient newStreamAppendClient;
             synchronized (APPEND_CLIENTS) {
               if (lookupCache) {
@@ -313,7 +359,8 @@ public class StorageApiWriteUnshardedRecords<DestinationT, ElementT>
           invalidateWriteStream();
           if (useDefaultStream) {
             // Since the default stream client is shared across many bundles and threads, we can't
-            // simply look it upfrom the cache, as another thread may have recreated it with the old
+            // simply look it up from the cache, as another thread may have recreated it with the
+            // old
             // schema.
             getStreamAppendClient(false);
           }
@@ -328,29 +375,62 @@ public class StorageApiWriteUnshardedRecords<DestinationT, ElementT>
         pendingMessages.add(ByteString.copyFrom(payload.getPayload()));
       }
 
-      void flush(RetryManager<AppendRowsResponse, Context<AppendRowsResponse>> retryManager)
+      long flush(
+          RetryManager<AppendRowsResponse, AppendRowsContext> retryManager,
+          OutputReceiver<BigQueryStorageApiInsertError> failedRowsReceiver)
           throws Exception {
         if (pendingMessages.isEmpty()) {
-          return;
+          return 0;
         }
-        final ProtoRows.Builder inserts = ProtoRows.newBuilder();
-        inserts.addAllSerializedRows(pendingMessages);
 
-        ProtoRows protoRows = inserts.build();
+        final ProtoRows.Builder insertsBuilder = ProtoRows.newBuilder();
+        insertsBuilder.addAllSerializedRows(pendingMessages);
+        final ProtoRows inserts = insertsBuilder.build();
         pendingMessages.clear();
 
+        // Handle the case where the request is too large.
+        if (inserts.getSerializedSize() >= maxRequestSize) {
+          if (inserts.getSerializedRowsCount() > 1) {
+            // TODO(reuvenlax): Is it worth trying to handle this case by splitting the protoRows?
+            // Given that we split
+            // the ProtoRows iterable at 2MB and the max request size is 10MB, this scenario seems
+            // nearly impossible.
+            LOG.error(
+                "A request containing more than one row is over the request size limit of "
+                    + maxRequestSize
+                    + ". This is unexpected. All rows in the request will be sent to the failed-rows PCollection.");
+          }
+          for (ByteString rowBytes : inserts.getSerializedRowsList()) {
+            TableRow failedRow =
+                TableRowToStorageApiProto.tableRowFromMessage(
+                    DynamicMessage.parseFrom(descriptorWrapper.descriptor, rowBytes));
+            failedRowsReceiver.output(
+                new BigQueryStorageApiInsertError(
+                    failedRow, "Row payload too large. Maximum size " + maxRequestSize));
+          }
+          return 0;
+        }
+
+        long offset = -1;
+        if (!this.useDefaultStream) {
+          offset = this.currentOffset;
+          this.currentOffset += inserts.getSerializedRowsCount();
+        }
+        AppendRowsContext appendRowsContext = new AppendRowsContext(offset, inserts);
+
         retryManager.addOperation(
             c -> {
+              if (c.protoRows.getSerializedRowsCount() == 0) {
+                // This might happen if all rows in a batch failed and were sent to the failed-rows
+                // PCollection.
+                return ApiFutures.immediateFuture(AppendRowsResponse.newBuilder().build());
+              }
               try {
                 StreamAppendClient writeStream = getStreamAppendClient(true);
-                long offset = -1;
-                if (!this.useDefaultStream) {
-                  offset = this.currentOffset;
-                  this.currentOffset += inserts.getSerializedRowsCount();
-                }
-                ApiFuture<AppendRowsResponse> response = writeStream.appendRows(offset, protoRows);
+                ApiFuture<AppendRowsResponse> response =
+                    writeStream.appendRows(c.offset, c.protoRows);
+                inflightWaitSecondsDistribution.update(writeStream.getInflightWaitSeconds());
                 if (!usingMultiplexing) {
-                  inflightWaitSecondsDistribution.update(writeStream.getInflightWaitSeconds());
                   if (writeStream.getInflightWaitSeconds() > 5) {
                     LOG.warn(
                         "Storage Api write delay more than {} seconds.",
@@ -363,33 +443,78 @@ public class StorageApiWriteUnshardedRecords<DestinationT, ElementT>
               }
             },
             contexts -> {
+              AppendRowsContext failedContext =
+                  Preconditions.checkStateNotNull(Iterables.getFirst(contexts, null));
+              if (failedContext.getError() != null
+                  && failedContext.getError() instanceof Exceptions.AppendSerializtionError) {
+                Exceptions.AppendSerializtionError error =
+                    Preconditions.checkStateNotNull(
+                        (Exceptions.AppendSerializtionError) failedContext.getError());
+                Set<Integer> failedRowIndices = error.getRowIndexToErrorMessage().keySet();
+                for (int failedIndex : failedRowIndices) {
+                  // Convert the message to a TableRow and send it to the failedRows collection.
+                  ByteString protoBytes = failedContext.protoRows.getSerializedRows(failedIndex);
+                  try {
+                    TableRow failedRow =
+                        TableRowToStorageApiProto.tableRowFromMessage(
+                            DynamicMessage.parseFrom(descriptorWrapper.descriptor, protoBytes));
+                    new BigQueryStorageApiInsertError(
+                        failedRow, error.getRowIndexToErrorMessage().get(failedIndex));
+                    failedRowsReceiver.output(
+                        new BigQueryStorageApiInsertError(
+                            failedRow, error.getRowIndexToErrorMessage().get(failedIndex)));
+                  } catch (InvalidProtocolBufferException e) {
+                    LOG.error("Failed to insert row and could not parse the result!");
+                  }
+                }
+                rowsSentToFailedRowsCollection.inc(failedRowIndices.size());
+
+                // Remove the failed row from the payload, so we retry the batch without the failed
+                // rows.
+                ProtoRows.Builder retryRows = ProtoRows.newBuilder();
+                for (int i = 0; i < failedContext.protoRows.getSerializedRowsCount(); ++i) {
+                  if (!failedRowIndices.contains(i)) {
+                    ByteString rowBytes = failedContext.protoRows.getSerializedRows(i);
+                    retryRows.addSerializedRows(rowBytes);
+                  }
+                }
+                failedContext.protoRows = retryRows.build();
+
+                // Since we removed rows, we need to update the insert offsets for all remaining
+                // rows.
+                long newOffset = failedContext.offset;
+                for (AppendRowsContext context : contexts) {
+                  context.offset = newOffset;
+                  newOffset += context.protoRows.getSerializedRowsCount();
+                }
+                this.currentOffset = newOffset;
+                return RetryType.RETRY_ALL_OPERATIONS;
+              }
+
               LOG.warn(
                   "Append to stream {} by client #{} failed with error, operations will be retried. Details: {}",
                   streamName,
                   clientNumber,
-                  retrieveErrorDetails(contexts));
+                  retrieveErrorDetails(failedContext));
               invalidateWriteStream();
               appendFailures.inc();
               return RetryType.RETRY_ALL_OPERATIONS;
             },
-            response -> {
-              recordsAppended.inc(protoRows.getSerializedRowsCount());
+            c -> {
+              recordsAppended.inc(c.protoRows.getSerializedRowsCount());
             },
-            new Context<>());
+            appendRowsContext);
         maybeTickleCache();
+        return inserts.getSerializedRowsCount();
       }
 
-      String retrieveErrorDetails(Iterable<Context<AppendRowsResponse>> contexts) {
-        return StreamSupport.stream(contexts.spliterator(), false)
-            .<@Nullable Throwable>map(ctx -> ctx.getError())
-            .map(
-                err ->
-                    (err == null)
-                        ? "no error"
-                        : Lists.newArrayList(err.getStackTrace()).stream()
-                            .map(se -> se.toString())
-                            .collect(Collectors.joining("\n")))
-            .collect(Collectors.joining(","));
+      String retrieveErrorDetails(AppendRowsContext failedContext) {
+        return (failedContext.getError() != null)
+            ? Arrays.stream(
+                    Preconditions.checkStateNotNull(failedContext.getError()).getStackTrace())
+                .map(StackTraceElement::toString)
+                .collect(Collectors.joining("\n"))
+            : "no execption";
       }
     }
 
@@ -412,7 +537,9 @@ public class StorageApiWriteUnshardedRecords<DestinationT, ElementT>
         boolean useDefaultStream,
         int flushThresholdBytes,
         int flushThresholdCount,
-        int streamAppendClientCount) {
+        int streamAppendClientCount,
+        TupleTag<KV<String, String>> finalizeTag,
+        TupleTag<BigQueryStorageApiInsertError> failedRowsTag) {
       this.messageConverters = new TwoLevelMessageConverterCache<>(operationName);
       this.dynamicDestinations = dynamicDestinations;
       this.bqServices = bqServices;
@@ -420,31 +547,47 @@ public class StorageApiWriteUnshardedRecords<DestinationT, ElementT>
       this.flushThresholdBytes = flushThresholdBytes;
       this.flushThresholdCount = flushThresholdCount;
       this.streamAppendClientCount = streamAppendClientCount;
+      this.finalizeTag = finalizeTag;
+      this.failedRowsTag = failedRowsTag;
     }
 
     boolean shouldFlush() {
       return numPendingRecords > flushThresholdCount || numPendingRecordBytes > flushThresholdBytes;
     }
 
-    void flushIfNecessary() throws Exception {
+    void flushIfNecessary(OutputReceiver<BigQueryStorageApiInsertError> failedRowsReceiver)
+        throws Exception {
       if (shouldFlush()) {
         forcedFlushes.inc();
         // Too much memory being used. Flush the state and wait for it to drain out.
         // TODO(reuvenlax): Consider waiting for memory usage to drop instead of waiting for all the
         // appends to finish.
-        flushAll();
+        flushAll(failedRowsReceiver);
       }
     }
 
-    void flushAll() throws Exception {
-      RetryManager<AppendRowsResponse, RetryManager.Operation.Context<AppendRowsResponse>>
-          retryManager =
-              new RetryManager<>(Duration.standardSeconds(1), Duration.standardSeconds(10), 1000);
-      Preconditions.checkStateNotNull(destinations);
-      for (DestinationState destinationState : destinations.values()) {
-        destinationState.flush(retryManager);
+    void flushAll(OutputReceiver<BigQueryStorageApiInsertError> failedRowsReceiver)
+        throws Exception {
+      List<RetryManager<AppendRowsResponse, AppendRowsContext>> retryManagers =
+          Lists.newArrayListWithCapacity(Preconditions.checkStateNotNull(destinations).size());
+      long numRowsWritten = 0;
+      for (DestinationState destinationState :
+          Preconditions.checkStateNotNull(destinations).values()) {
+        RetryManager<AppendRowsResponse, AppendRowsContext> retryManager =
+            new RetryManager<>(Duration.standardSeconds(1), Duration.standardSeconds(10), 1000);
+        retryManagers.add(retryManager);
+        numRowsWritten += destinationState.flush(retryManager, failedRowsReceiver);
+        retryManager.run(false);
+      }
+      if (numRowsWritten > 0) {
+        // TODO(reuvenlax): Can we await in parallel instead? Failure retries aren't triggered until
+        // await is called, so
+        // this approach means that if one call fails, it has to wait for all prior calls to complete
+        // before a retry happens.
+        for (RetryManager<AppendRowsResponse, AppendRowsContext> retryManager : retryManagers) {
+          retryManager.await();
+        }
       }
-      retryManager.run(true);
       numPendingRecords = 0;
       numPendingRecordBytes = 0;
     }
@@ -488,14 +631,16 @@ public class StorageApiWriteUnshardedRecords<DestinationT, ElementT>
           datasetService,
           useDefaultStream,
           streamAppendClientCount,
-          bigQueryOptions);
+          bigQueryOptions.getUseStorageApiConnectionPool(),
+          bigQueryOptions.getStorageWriteApiMaxRequestSize());
     }
 
     @ProcessElement
     public void process(
         ProcessContext c,
         PipelineOptions pipelineOptions,
-        @Element KV<DestinationT, StorageApiWritePayload> element)
+        @Element KV<DestinationT, StorageApiWritePayload> element,
+        MultiOutputReceiver o)
         throws Exception {
       DatasetService initializedDatasetService = initializeDatasetService(pipelineOptions);
       dynamicDestinations.setSideInputAccessorFromProcessContext(c);
@@ -506,7 +651,7 @@ public class StorageApiWriteUnshardedRecords<DestinationT, ElementT>
               k ->
                   createDestinationState(
                       c, k, initializedDatasetService, pipelineOptions.as(BigQueryOptions.class)));
-      flushIfNecessary();
+      flushIfNecessary(o.get(failedRowsTag));
       state.addMessage(element.getValue());
       ++numPendingRecords;
       numPendingRecordBytes += element.getValue().getPayload().length;
@@ -514,14 +659,28 @@ public class StorageApiWriteUnshardedRecords<DestinationT, ElementT>
 
     @FinishBundle
     public void finishBundle(FinishBundleContext context) throws Exception {
-      flushAll();
+      flushAll(
+          new OutputReceiver<BigQueryStorageApiInsertError>() {
+            @Override
+            public void output(BigQueryStorageApiInsertError output) {
+              outputWithTimestamp(output, GlobalWindow.INSTANCE.maxTimestamp());
+            }
+
+            @Override
+            public void outputWithTimestamp(
+                BigQueryStorageApiInsertError output, org.joda.time.Instant timestamp) {
+              context.output(failedRowsTag, output, timestamp, GlobalWindow.INSTANCE);
+            }
+          });
+
       final Map<DestinationT, DestinationState> destinations =
           Preconditions.checkStateNotNull(this.destinations);
       for (DestinationState state : destinations.values()) {
-        if (!useDefaultStream) {
+        if (!useDefaultStream && !Strings.isNullOrEmpty(state.streamName)) {
           context.output(
+              finalizeTag,
               KV.of(state.tableUrn, state.streamName),
-              BoundedWindow.TIMESTAMP_MAX_VALUE.minus(Duration.millis(1)),
+              GlobalWindow.INSTANCE.maxTimestamp(),
               GlobalWindow.INSTANCE);
         }
         state.teardown();
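
When an append fails with Exceptions.AppendSerializtionError, the DoFn above (and the sharded variant below) now routes the offending rows to the failed-rows output, rebuilds the ProtoRows payload without them, shifts the offsets of the remaining contexts, and retries. The row-filtering step is shown here factored into a stand-alone helper for illustration; FailedRowFilter is a hypothetical name, not part of the commit.

    import com.google.cloud.bigquery.storage.v1.ProtoRows;
    import com.google.protobuf.ByteString;
    import java.util.Set;

    final class FailedRowFilter { // hypothetical helper, not part of the commit
      /** Drop the rows BigQuery rejected so the remaining rows can be re-appended. */
      static ProtoRows removeFailedRows(ProtoRows original, Set<Integer> failedRowIndices) {
        ProtoRows.Builder retryRows = ProtoRows.newBuilder();
        for (int i = 0; i < original.getSerializedRowsCount(); ++i) {
          if (!failedRowIndices.contains(i)) {
            ByteString rowBytes = original.getSerializedRows(i);
            retryRows.addSerializedRows(rowBytes);
          }
        }
        return retryRows.build();
      }

      private FailedRowFilter() {}
    }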
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java
index c8bb805b6e8..af0ae5169bc 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java
@@ -20,16 +20,23 @@ package org.apache.beam.sdk.io.gcp.bigquery;
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
 
 import com.google.api.core.ApiFuture;
+import com.google.api.core.ApiFutures;
+import com.google.api.services.bigquery.model.TableRow;
 import com.google.cloud.bigquery.storage.v1.AppendRowsResponse;
+import com.google.cloud.bigquery.storage.v1.Exceptions;
 import com.google.cloud.bigquery.storage.v1.Exceptions.StreamFinalizedException;
 import com.google.cloud.bigquery.storage.v1.ProtoRows;
 import com.google.cloud.bigquery.storage.v1.WriteStream.Type;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.DynamicMessage;
+import com.google.protobuf.InvalidProtocolBufferException;
 import io.grpc.Status;
 import io.grpc.Status.Code;
 import java.io.IOException;
 import java.time.Instant;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
@@ -74,6 +81,9 @@ import org.apache.beam.sdk.util.Preconditions;
 import org.apache.beam.sdk.util.ShardedKey;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
 import org.apache.beam.sdk.values.TypeDescriptor;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
@@ -99,7 +109,7 @@ import org.slf4j.LoggerFactory;
 public class StorageApiWritesShardedRecords<DestinationT extends @NonNull Object, ElementT>
     extends PTransform<
         PCollection<KV<ShardedKey<DestinationT>, Iterable<StorageApiWritePayload>>>,
-        PCollection<Void>> {
+        PCollectionTuple> {
   private static final Logger LOG = LoggerFactory.getLogger(StorageApiWritesShardedRecords.class);
   private static final Duration DEFAULT_STREAM_IDLE_TIME = Duration.standardHours(1);
 
@@ -108,7 +118,10 @@ public class StorageApiWritesShardedRecords<DestinationT extends @NonNull Object
   private final String kmsKey;
   private final BigQueryServices bqServices;
   private final Coder<DestinationT> destinationCoder;
+  private final Coder<BigQueryStorageApiInsertError> failedRowsCoder;
   private final Duration streamIdleTime = DEFAULT_STREAM_IDLE_TIME;
+  private final TupleTag<BigQueryStorageApiInsertError> failedRowsTag;
+  private final TupleTag<KV<String, Operation>> flushTag = new TupleTag<>("flushTag");
   private static final ExecutorService closeWriterExecutor = Executors.newCachedThreadPool();
 
   private static final Cache<String, StreamAppendClient> APPEND_CLIENTS =
@@ -147,24 +160,29 @@ public class StorageApiWritesShardedRecords<DestinationT extends @NonNull Object
       CreateDisposition createDisposition,
       String kmsKey,
       BigQueryServices bqServices,
-      Coder<DestinationT> destinationCoder) {
+      Coder<DestinationT> destinationCoder,
+      Coder<BigQueryStorageApiInsertError> failedRowsCoder,
+      TupleTag<BigQueryStorageApiInsertError> failedRowsTag) {
     this.dynamicDestinations = dynamicDestinations;
     this.createDisposition = createDisposition;
     this.kmsKey = kmsKey;
     this.bqServices = bqServices;
     this.destinationCoder = destinationCoder;
+    this.failedRowsCoder = failedRowsCoder;
+    this.failedRowsTag = failedRowsTag;
   }
 
   @Override
-  public PCollection<Void> expand(
+  public PCollectionTuple expand(
       PCollection<KV<ShardedKey<DestinationT>, Iterable<StorageApiWritePayload>>> input) {
     String operationName = input.getName() + "/" + getName();
     // Append records to the Storage API streams.
-    PCollection<KV<String, Operation>> written =
+    PCollectionTuple writeRecordsResult =
         input.apply(
             "Write Records",
             ParDo.of(new WriteRecordsDoFn(operationName, streamIdleTime))
-                .withSideInputs(dynamicDestinations.getSideInputs()));
+                .withSideInputs(dynamicDestinations.getSideInputs())
+                .withOutputTags(flushTag, TupleTagList.of(failedRowsTag)));
 
     SchemaCoder<Operation> operationCoder;
     try {
@@ -180,7 +198,8 @@ public class StorageApiWritesShardedRecords<DestinationT extends @NonNull Object
     }
 
     // Send all successful writes to be flushed.
-    return written
+    writeRecordsResult
+        .get(flushTag)
         .setCoder(KvCoder.of(StringUtf8Coder.of(), operationCoder))
         .apply(
             Window.<KV<String, Operation>>configure()
@@ -192,6 +211,8 @@ public class StorageApiWritesShardedRecords<DestinationT extends @NonNull Object
         .apply("maxFlushPosition", Combine.perKey(Max.naturalOrder(new Operation(-1, false))))
         .apply(
             "Flush and finalize writes", ParDo.of(new StorageApiFlushAndFinalizeDoFn(bqServices)));
+    writeRecordsResult.get(failedRowsTag).setCoder(failedRowsCoder);
+    return writeRecordsResult;
   }
 
   class WriteRecordsDoFn
@@ -215,6 +236,8 @@ public class StorageApiWritesShardedRecords<DestinationT extends @NonNull Object
         Metrics.distribution(WriteRecordsDoFn.class, "appendSizeDistribution");
     private final Distribution appendSplitDistribution =
         Metrics.distribution(WriteRecordsDoFn.class, "appendSplitDistribution");
+    private final Counter rowsSentToFailedRowsCollection =
+        Metrics.counter(WriteRecordsDoFn.class, "rowsSentToFailedRowsCollection");
 
     private TwoLevelMessageConverterCache<DestinationT, ElementT> messageConverters;
 
@@ -297,8 +320,10 @@ public class StorageApiWritesShardedRecords<DestinationT extends @NonNull Object
         final @AlwaysFetched @StateId("streamName") ValueState<String> streamName,
         final @AlwaysFetched @StateId("streamOffset") ValueState<Long> streamOffset,
         @TimerId("idleTimer") Timer idleTimer,
-        final OutputReceiver<KV<String, Operation>> o)
+        final MultiOutputReceiver o)
         throws Exception {
+      BigQueryOptions bigQueryOptions = pipelineOptions.as(BigQueryOptions.class);
+
       dynamicDestinations.setSideInputAccessorFromProcessContext(c);
       TableDestination tableDestination =
           destinations.computeIfAbsent(
@@ -323,7 +348,7 @@ public class StorageApiWritesShardedRecords<DestinationT extends @NonNull Object
       // Each ProtoRows object contains at most 1MB of rows.
       // TODO: Push messageFromTableRow up to top level. That way we can skip TableRow entirely if
       // already proto or already schema.
-      final long oneMb = 1024 * 1024;
+      final long splitSize = bigQueryOptions.getStorageApiAppendThresholdBytes();
       // Called if the schema does not match.
       Function<Long, DescriptorWrapper> updateSchemaHash =
           (Long expectedHash) -> {
@@ -343,7 +368,7 @@ public class StorageApiWritesShardedRecords<DestinationT extends @NonNull Object
             }
           };
       Iterable<ProtoRows> messages =
-          new SplittingIterable(element.getValue(), oneMb, descriptor.get(), updateSchemaHash);
+          new SplittingIterable(element.getValue(), splitSize, descriptor.get(), updateSchemaHash);
 
       class AppendRowsContext extends RetryManager.Operation.Context<AppendRowsResponse> {
         final ShardedKey<DestinationT> key;
@@ -352,9 +377,11 @@ public class StorageApiWritesShardedRecords<DestinationT extends @NonNull Object
         long offset = -1;
         long numRows = 0;
         long tryIteration = 0;
+        ProtoRows protoRows;
 
-        AppendRowsContext(ShardedKey<DestinationT> key) {
+        AppendRowsContext(ShardedKey<DestinationT> key, ProtoRows protoRows) {
           this.key = key;
+          this.protoRows = protoRows;
         }
 
         @Override
@@ -396,7 +423,7 @@ public class StorageApiWritesShardedRecords<DestinationT extends @NonNull Object
                 context.client = appendClient;
                 context.offset = streamOffset.read();
                 ++context.tryIteration;
-                streamOffset.write(context.offset + context.numRows);
+                streamOffset.write(context.offset + context.protoRows.getSerializedRowsCount());
               }
             } catch (Exception e) {
               throw new RuntimeException(e);
@@ -415,114 +442,200 @@ public class StorageApiWritesShardedRecords<DestinationT extends @NonNull Object
             }
           };
 
-      Instant now = Instant.now();
-      List<AppendRowsContext> contexts = Lists.newArrayList();
-      RetryManager<AppendRowsResponse, AppendRowsContext> retryManager =
-          new RetryManager<>(Duration.standardSeconds(1), Duration.standardSeconds(10), 1000);
-      int numSplits = 0;
-      for (ProtoRows protoRows : messages) {
-        ++numSplits;
-        Function<AppendRowsContext, ApiFuture<AppendRowsResponse>> run =
-            context -> {
-              try {
-                StreamAppendClient appendClient =
-                    APPEND_CLIENTS.get(
-                        context.streamName,
-                        () ->
-                            datasetService.getStreamAppendClient(
-                                context.streamName, descriptor.get().descriptor, false));
-                return appendClient.appendRows(context.offset, protoRows);
-              } catch (Exception e) {
-                throw new RuntimeException(e);
+      Function<AppendRowsContext, ApiFuture<AppendRowsResponse>> runOperation =
+          context -> {
+            if (context.protoRows.getSerializedRowsCount() == 0) {
+              // This might happen if all rows in a batch failed and were sent to the failed-rows
+              // PCollection.
+              return ApiFutures.immediateFuture(AppendRowsResponse.newBuilder().build());
+            }
+            try {
+              StreamAppendClient appendClient =
+                  APPEND_CLIENTS.get(
+                      context.streamName,
+                      () ->
+                          datasetService.getStreamAppendClient(
+                              context.streamName, descriptor.get().descriptor, false));
+              return appendClient.appendRows(context.offset, context.protoRows);
+            } catch (Exception e) {
+              throw new RuntimeException(e);
+            }
+          };
+
+      Function<Iterable<AppendRowsContext>, RetryType> onError =
+          failedContexts -> {
+            // The first context is always the one that fails.
+            AppendRowsContext failedContext =
+                Preconditions.checkStateNotNull(Iterables.getFirst(failedContexts, null));
+
+            // AppendSerializationError means that BigQuery detected errors on individual rows, e.g.
+            // a row not conforming
+            // to bigQuery invariants. These errors are persistent, so we redirect those rows to the
+            // failedInserts
+            // PCollection, and retry with the remaining rows.
+            if (failedContext.getError() != null
+                && failedContext.getError() instanceof Exceptions.AppendSerializtionError) {
+              Exceptions.AppendSerializtionError error =
+                  Preconditions.checkArgumentNotNull(
+                      (Exceptions.AppendSerializtionError) failedContext.getError());
+              Set<Integer> failedRowIndices = error.getRowIndexToErrorMessage().keySet();
+              for (int failedIndex : failedRowIndices) {
+                // Convert the message to a TableRow and send it to the failedRows collection.
+                ByteString protoBytes = failedContext.protoRows.getSerializedRows(failedIndex);
+                try {
+                  TableRow failedRow =
+                      TableRowToStorageApiProto.tableRowFromMessage(
+                          DynamicMessage.parseFrom(descriptor.get().descriptor, protoBytes));
+                  new BigQueryStorageApiInsertError(
+                      failedRow, error.getRowIndexToErrorMessage().get(failedIndex));
+                  o.get(failedRowsTag)
+                      .output(
+                          new BigQueryStorageApiInsertError(
+                              failedRow, error.getRowIndexToErrorMessage().get(failedIndex)));
+                } catch (InvalidProtocolBufferException e) {
+                  LOG.error("Failed to insert row and could not parse the result!");
+                }
               }
-            };
-
-        // RetryManager
-        Function<Iterable<AppendRowsContext>, RetryType> onError =
-            failedContexts -> {
-              // The first context is always the one that fails.
-              AppendRowsContext failedContext =
-                  Preconditions.checkStateNotNull(Iterables.getFirst(failedContexts, null));
-              // Invalidate the StreamWriter and force a new one to be created.
-              LOG.error(
-                  "Got error " + failedContext.getError() + " closing " + failedContext.streamName);
-              clearClients.accept(contexts);
-              appendFailures.inc();
-
-              boolean explicitStreamFinalized =
-                  failedContext.getError() instanceof StreamFinalizedException;
-              Throwable error = Preconditions.checkStateNotNull(failedContext.getError());
-              Status.Code statusCode = Status.fromThrowable(error).getCode();
-              // This means that the offset we have stored does not match the current end of
-              // the stream in the Storage API. Usually this happens because a crash or a bundle
-              // failure
-              // happened after an append but before the worker could checkpoint it's
-              // state. The records that were appended in a failed bundle will be retried,
-              // meaning that the unflushed tail of the stream must be discarded to prevent
-              // duplicates.
-              boolean offsetMismatch =
-                  statusCode.equals(Code.OUT_OF_RANGE) || statusCode.equals(Code.ALREADY_EXISTS);
-              // This implies that the stream doesn't exist or has already been finalized. In this
-              // case we have no choice but to create a new stream.
-              boolean streamDoesNotExist =
-                  explicitStreamFinalized
-                      || statusCode.equals(Code.INVALID_ARGUMENT)
-                      || statusCode.equals(Code.NOT_FOUND)
-                      || statusCode.equals(Code.FAILED_PRECONDITION);
-              if (offsetMismatch || streamDoesNotExist) {
-                appendOffsetFailures.inc();
-                LOG.warn(
-                    "Append to "
-                        + failedContext
-                        + " failed with "
-                        + failedContext.getError()
-                        + " Will retry with a new stream");
-                // Finalize the stream and clear streamName so a new stream will be created.
-                o.output(
-                    KV.of(failedContext.streamName, new Operation(failedContext.offset - 1, true)));
-                // Reinitialize all contexts with the new stream and new offsets.
-                initializeContexts.accept(failedContexts, true);
-
-                // Offset failures imply that all subsequent parallel appends will also fail.
-                // Retry them all.
-                return RetryType.RETRY_ALL_OPERATIONS;
+              rowsSentToFailedRowsCollection.inc(failedRowIndices.size());
+
+              // Remove the failed rows from the payload, so we retry the batch without them.
+              ProtoRows.Builder retryRows = ProtoRows.newBuilder();
+              for (int i = 0; i < failedContext.protoRows.getSerializedRowsCount(); ++i) {
+                if (!failedRowIndices.contains(i)) {
+                  ByteString rowBytes = failedContext.protoRows.getSerializedRows(i);
+                  retryRows.addSerializedRows(rowBytes);
+                }
               }
+              failedContext.protoRows = retryRows.build();
 
+              // Since we removed rows, we need to update the insert offsets for all remaining rows.
+              long offset = failedContext.offset;
+              for (AppendRowsContext context : failedContexts) {
+                context.offset = offset;
+                offset += context.protoRows.getSerializedRowsCount();
+              }
+              streamOffset.write(offset);
               return RetryType.RETRY_ALL_OPERATIONS;
-            };
+            }
 
-        Consumer<AppendRowsContext> onSuccess =
-            context -> {
-              o.output(
-                  KV.of(
-                      context.streamName,
-                      new Operation(context.offset + context.numRows - 1, false)));
-              flushesScheduled.inc(protoRows.getSerializedRowsCount());
-            };
-
-        AppendRowsContext context = new AppendRowsContext(element.getKey());
-        context.numRows = protoRows.getSerializedRowsCount();
-        contexts.add(context);
-        retryManager.addOperation(run, onError, onSuccess, context);
-        recordsAppended.inc(protoRows.getSerializedRowsCount());
-        appendSizeDistribution.update(context.numRows);
-      }
-      initializeContexts.accept(contexts, false);
+            // Invalidate the StreamWriter and force a new one to be created.
+            LOG.error(
+                "Got error " + failedContext.getError() + " closing " + failedContext.streamName);
+            clearClients.accept(failedContexts);
+            appendFailures.inc();
+
+            boolean explicitStreamFinalized =
+                failedContext.getError() instanceof StreamFinalizedException;
+            Throwable error = Preconditions.checkStateNotNull(failedContext.getError());
+            Status.Code statusCode = Status.fromThrowable(error).getCode();
+            // This means that the offset we have stored does not match the current end of
+            // the stream in the Storage API. Usually this happens because a crash or a bundle
+            // failure happened after an append but before the worker could checkpoint its
+            // state. The records that were appended in a failed bundle will be retried,
+            // meaning that the unflushed tail of the stream must be discarded to prevent
+            // duplicates.
+            boolean offsetMismatch =
+                statusCode.equals(Code.OUT_OF_RANGE) || statusCode.equals(Code.ALREADY_EXISTS);
+            // This implies that the stream doesn't exist or has already been finalized. In this
+            // case we have no choice but to create a new stream.
+            boolean streamDoesNotExist =
+                explicitStreamFinalized
+                    || statusCode.equals(Code.INVALID_ARGUMENT)
+                    || statusCode.equals(Code.NOT_FOUND)
+                    || statusCode.equals(Code.FAILED_PRECONDITION);
+            if (offsetMismatch || streamDoesNotExist) {
+              appendOffsetFailures.inc();
+              LOG.warn(
+                  "Append to "
+                      + failedContext
+                      + " failed with "
+                      + failedContext.getError()
+                      + " Will retry with a new stream");
+              // Finalize the stream and clear streamName so a new stream will be created.
+              o.get(flushTag)
+                  .output(
+                      KV.of(
+                          failedContext.streamName, new Operation(failedContext.offset - 1, true)));
+              // Reinitialize all contexts with the new stream and new offsets.
+              initializeContexts.accept(failedContexts, true);
+
+              // Offset failures imply that all subsequent parallel appends will also fail.
+              // Retry them all.
+              return RetryType.RETRY_ALL_OPERATIONS;
+            }
 
-      try {
-        retryManager.run(true);
-      } finally {
-        // Make sure that all pins are removed.
-        for (AppendRowsContext context : contexts) {
-          if (context.client != null) {
-            runAsyncIgnoreFailure(closeWriterExecutor, context.client::unpin);
+            return RetryType.RETRY_ALL_OPERATIONS;
+          };
+
+      Consumer<AppendRowsContext> onSuccess =
+          context -> {
+            o.get(flushTag)
+                .output(
+                    KV.of(
+                        context.streamName,
+                        new Operation(
+                            context.offset + context.protoRows.getSerializedRowsCount() - 1,
+                            false)));
+            flushesScheduled.inc(context.protoRows.getSerializedRowsCount());
+          };
+      long maxRequestSize = bigQueryOptions.getStorageWriteApiMaxRequestSize();
+      Instant now = Instant.now();
+      List<AppendRowsContext> contexts = Lists.newArrayList();
+      RetryManager<AppendRowsResponse, AppendRowsContext> retryManager =
+          new RetryManager<>(Duration.standardSeconds(1), Duration.standardSeconds(10), 1000);
+      int numAppends = 0;
+      for (ProtoRows protoRows : messages) {
+        // Handle the case of a row that is too large.
+        if (protoRows.getSerializedSize() >= maxRequestSize) {
+          if (protoRows.getSerializedRowsCount() > 1) {
+            // TODO(reuvenlax): Is it worth trying to handle this case by splitting the protoRows?
+            // Given that we split the ProtoRows iterable at 2MB and the max request size is
+            // 10MB, this scenario seems nearly impossible.
+            LOG.error(
+                "A request containing more than one row is over the request size limit of "
+                    + maxRequestSize
+                    + ". This is unexpected. All rows in the request will be sent to the failed-rows PCollection.");
+          }
+          for (ByteString rowBytes : protoRows.getSerializedRowsList()) {
+            TableRow failedRow =
+                TableRowToStorageApiProto.tableRowFromMessage(
+                    DynamicMessage.parseFrom(descriptor.get().descriptor, rowBytes));
+            o.get(failedRowsTag)
+                .output(
+                    new BigQueryStorageApiInsertError(
+                        failedRow, "Row payload too large. Maximum size " + maxRequestSize));
           }
+        } else {
+          ++numAppends;
+          // RetryManager
+          AppendRowsContext context = new AppendRowsContext(element.getKey(), protoRows);
+          contexts.add(context);
+          retryManager.addOperation(runOperation, onError, onSuccess, context);
+          recordsAppended.inc(protoRows.getSerializedRowsCount());
+          appendSizeDistribution.update(context.protoRows.getSerializedRowsCount());
         }
       }
-      appendSplitDistribution.update(numSplits);
 
-      java.time.Duration timeElapsed = java.time.Duration.between(now, Instant.now());
-      appendLatencyDistribution.update(timeElapsed.toMillis());
+      if (numAppends > 0) {
+        initializeContexts.accept(contexts, false);
+        try {
+          retryManager.run(true);
+        } finally {
+          // Make sure that all pins are removed.
+          for (AppendRowsContext context : contexts) {
+            if (context.client != null) {
+              runAsyncIgnoreFailure(closeWriterExecutor, context.client::unpin);
+            }
+          }
+        }
+        appendSplitDistribution.update(numAppends);
+
+        java.time.Duration timeElapsed = java.time.Duration.between(now, Instant.now());
+        appendLatencyDistribution.update(timeElapsed.toMillis());
+      }
       idleTimer.offset(streamIdleTime).withNoOutputTimestamp().setRelative();
     }
 
@@ -530,15 +643,16 @@ public class StorageApiWritesShardedRecords<DestinationT extends @NonNull Object
     private void finalizeStream(
         @AlwaysFetched @StateId("streamName") ValueState<String> streamName,
         @AlwaysFetched @StateId("streamOffset") ValueState<Long> streamOffset,
-        OutputReceiver<KV<String, Operation>> o,
+        MultiOutputReceiver o,
         org.joda.time.Instant finalizeElementTs) {
       String stream = MoreObjects.firstNonNull(streamName.read(), "");
 
       if (!Strings.isNullOrEmpty(stream)) {
         // Finalize the stream
         long nextOffset = MoreObjects.firstNonNull(streamOffset.read(), 0L);
-        o.outputWithTimestamp(
-            KV.of(stream, new Operation(nextOffset - 1, true)), finalizeElementTs);
+        o.get(flushTag)
+            .outputWithTimestamp(
+                KV.of(stream, new Operation(nextOffset - 1, true)), finalizeElementTs);
         streamName.clear();
         streamOffset.clear();
         // Make sure that the stream object is closed.
@@ -550,7 +664,7 @@ public class StorageApiWritesShardedRecords<DestinationT extends @NonNull Object
     public void onTimer(
         @AlwaysFetched @StateId("streamName") ValueState<String> streamName,
         @AlwaysFetched @StateId("streamOffset") ValueState<Long> streamOffset,
-        OutputReceiver<KV<String, Operation>> o,
+        MultiOutputReceiver o,
         BoundedWindow window) {
       // Stream is idle - clear it.
      // Note: this is best effort. We are explicitly emitting a timestamp that is before
@@ -566,7 +680,7 @@ public class StorageApiWritesShardedRecords<DestinationT extends @NonNull Object
     public void onWindowExpiration(
         @AlwaysFetched @StateId("streamName") ValueState<String> streamName,
         @AlwaysFetched @StateId("streamOffset") ValueState<Long> streamOffset,
-        OutputReceiver<KV<String, Operation>> o,
+        MultiOutputReceiver o,
         BoundedWindow window) {
       // Window is done - usually because the pipeline has been drained. Make sure to clean up
       // streams so that they are not leaked.
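
For readers wiring this change into a pipeline: rows that BigQuery rejects at append time are now forwarded to the PCollection returned by WriteResult.getFailedStorageApiInserts() instead of being retried indefinitely. A minimal consumption sketch follows; the input PCollection "rows", the table spec, and the schema are placeholders, and the pattern mirrors the StorageApiSinkFailedRowsIT test added further down in this commit.

    WriteResult result =
        rows.apply(
            "WriteToBigQuery",
            BigQueryIO.writeTableRows()
                .to("myproject.my_dataset.my_table") // placeholder table spec
                .withSchema(schema)                  // placeholder TableSchema
                .withMethod(BigQueryIO.Write.Method.STORAGE_WRITE_API)
                .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_NEVER));

    // Each failed insert carries the offending TableRow plus an error message.
    PCollection<TableRow> failedRows =
        result
            .getFailedStorageApiInserts()
            .apply(
                MapElements.into(TypeDescriptor.of(TableRow.class))
                    .via(BigQueryStorageApiInsertError::getRow));
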
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/BigqueryClient.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/BigqueryClient.java
index 6224729aa91..f5752797acd 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/BigqueryClient.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/BigqueryClient.java
@@ -288,7 +288,8 @@ public class BigqueryClient {
 
   /** Performs a query without flattening results. */
   @Nonnull
-  public List<TableRow> queryUnflattened(String query, String projectId, boolean typed)
+  public List<TableRow> queryUnflattened(
+      String query, String projectId, boolean typed, boolean useStandardSql)
       throws IOException, InterruptedException {
     Random rnd = new Random(System.currentTimeMillis());
     String temporaryDatasetId = "_dataflow_temporary_dataset_" + rnd.nextInt(1000000);
@@ -308,6 +309,7 @@ public class BigqueryClient {
             .setFlattenResults(false)
             .setAllowLargeResults(true)
             .setDestinationTable(tempTableReference)
+            .setUseLegacySql(!useStandardSql)
             .setQuery(query);
     JobConfiguration jc = new JobConfiguration().setQuery(jcQuery);
 
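Call sites of queryUnflattened now pass a fourth argument selecting the SQL dialect; with useStandardSql = true the query job is configured with setUseLegacySql(false). A minimal sketch of the updated call, where the client label, project, and table are placeholders:

    BigqueryClient client = new BigqueryClient("MyIntegrationTest"); // placeholder label
    List<TableRow> rows =
        client.queryUnflattened(
            "SELECT COUNT(*) FROM myproject.my_dataset.my_table", // placeholder query
            "myproject",
            /* typed= */ true,
            /* useStandardSql= */ true);
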
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeDatasetService.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeDatasetService.java
index 44f73bd56cb..948c75cb756 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeDatasetService.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeDatasetService.java
@@ -32,6 +32,7 @@ import com.google.api.services.bigquery.model.TableRow;
 import com.google.api.services.bigquery.model.TableSchema;
 import com.google.cloud.bigquery.storage.v1.AppendRowsResponse;
 import com.google.cloud.bigquery.storage.v1.BatchCommitWriteStreamsResponse;
+import com.google.cloud.bigquery.storage.v1.Exceptions;
 import com.google.cloud.bigquery.storage.v1.FinalizeWriteStreamResponse;
 import com.google.cloud.bigquery.storage.v1.FlushRowsResponse;
 import com.google.cloud.bigquery.storage.v1.ProtoRows;
@@ -43,6 +44,7 @@ import com.google.protobuf.ByteString;
 import com.google.protobuf.Descriptors.Descriptor;
 import com.google.protobuf.DynamicMessage;
 import com.google.protobuf.Timestamp;
+import com.google.rpc.Code;
 import java.io.IOException;
 import java.io.Serializable;
 import java.util.HashMap;
@@ -50,6 +52,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
 import java.util.regex.Pattern;
 import javax.annotation.Nullable;
 import org.apache.beam.sdk.annotations.Internal;
@@ -148,6 +151,8 @@ public class FakeDatasetService implements DatasetService, Serializable {
     }
   }
 
+  Function<TableRow, Boolean> shouldFailRow =
+      (Function<TableRow, Boolean> & Serializable) tr -> false;
   Map<String, List<String>> insertErrors = Maps.newHashMap();
 
   // The counter for the number of insertions performed.
@@ -162,6 +167,10 @@ public class FakeDatasetService implements DatasetService, Serializable {
     }
   }
 
+  public void setShouldFailRow(Function<TableRow, Boolean> shouldFailRow) {
+    this.shouldFailRow = shouldFailRow;
+  }
+
   @Override
   public Table getTable(TableReference tableRef) throws InterruptedException, IOException {
     if (tableRef.getProjectId() == null) {
@@ -504,6 +513,7 @@ public class FakeDatasetService implements DatasetService, Serializable {
       @Override
       public ApiFuture<AppendRowsResponse> appendRows(long offset, ProtoRows rows)
           throws Exception {
+        AppendRowsResponse.Builder responseBuilder = AppendRowsResponse.newBuilder();
         synchronized (FakeDatasetService.class) {
           Stream stream = writeStreams.get(streamName);
           if (stream == null) {
@@ -511,18 +521,32 @@ public class FakeDatasetService implements DatasetService, Serializable {
           }
           List<TableRow> tableRows =
               Lists.newArrayListWithExpectedSize(rows.getSerializedRowsCount());
-          for (ByteString bytes : rows.getSerializedRowsList()) {
+          Map<Integer, String> rowIndexToErrorMessage = Maps.newHashMap();
+          for (int i = 0; i < rows.getSerializedRowsCount(); ++i) {
+            ByteString bytes = rows.getSerializedRows(i);
             DynamicMessage msg = DynamicMessage.parseFrom(protoDescriptor, bytes);
             if (msg.getUnknownFields() != null && !msg.getUnknownFields().asMap().isEmpty()) {
               throw new RuntimeException("Unknown fields set in append! " + msg.getUnknownFields());
             }
-            tableRows.add(
+            TableRow tableRow =
                 TableRowToStorageApiProto.tableRowFromMessage(
-                    DynamicMessage.parseFrom(protoDescriptor, bytes)));
+                    DynamicMessage.parseFrom(protoDescriptor, bytes));
+            if (shouldFailRow.apply(tableRow)) {
+              rowIndexToErrorMessage.put(i, "Failing row " + tableRow.toPrettyString());
+            }
+            tableRows.add(tableRow);
+          }
+          if (!rowIndexToErrorMessage.isEmpty()) {
+            return ApiFutures.immediateFailedFuture(
+                new Exceptions.AppendSerializtionError(
+                    Code.INVALID_ARGUMENT.getNumber(),
+                    "Append serialization failed for writer: " + streamName,
+                    stream.streamName,
+                    rowIndexToErrorMessage));
           }
           stream.appendRows(offset, tableRows);
         }
-        return ApiFutures.immediateFuture(AppendRowsResponse.newBuilder().build());
+        return ApiFutures.immediateFuture(responseBuilder.build());
       }
 
       @Override
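The new setShouldFailRow hook lets tests simulate BigQuery rejecting individual rows: when the predicate matches, the fake returns an Exceptions.AppendSerializtionError whose rowIndexToErrorMessage map names the offending rows, mirroring the real Storage Write API. A minimal sketch using an example predicate on a hypothetical "name" field (BigQueryIOWriteTest below uses the same pattern):

    FakeDatasetService fakeDatasetService = new FakeDatasetService();
    // Rows matching the predicate are rejected by the fake and should surface on the
    // failedStorageApiInserts PCollection instead of being written to the fake table.
    fakeDatasetService.setShouldFailRow(
        (Function<TableRow, Boolean> & Serializable)
            tr -> tr.containsKey("name") && tr.get("name").equals("failme"));
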
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
index 7f529bfa348..1e1749e8569 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
@@ -64,6 +64,7 @@ import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ThreadLocalRandom;
+import java.util.function.Function;
 import java.util.function.LongFunction;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
@@ -2583,11 +2584,15 @@ public class BigQueryIOWriteTest implements Serializable {
     TableRow goodNested = new TableRow().set("number", "42");
     TableRow badNested = new TableRow().set("number", "nAn");
 
+    final String failValue = "failme";
     List<TableRow> goodRows =
         ImmutableList.of(
             new TableRow().set("name", "n1").set("number", "1"),
+            new TableRow().set("name", failValue).set("number", "1"),
             new TableRow().set("name", "n2").set("number", "2"),
-            new TableRow().set("name", "parent1").set("nested", goodNested));
+            new TableRow().set("name", failValue).set("number", "2"),
+            new TableRow().set("name", "parent1").set("nested", goodNested),
+            new TableRow().set("name", failValue).set("number", "1"));
     List<TableRow> badRows =
         ImmutableList.of(
             // Unknown field.
@@ -2614,6 +2619,11 @@ public class BigQueryIOWriteTest implements Serializable {
             // Invalid nested row
             new TableRow().set("name", "parent2").set("nested", badNested));
 
+    Function<TableRow, Boolean> shouldFailRow =
+        (Function<TableRow, Boolean> & Serializable)
+            tr -> tr.containsKey("name") && tr.get("name").equals(failValue);
+    fakeDatasetService.setShouldFailRow(shouldFailRow);
+
     WriteResult result =
         p.apply(Create.of(Iterables.concat(goodRows, badRows)))
             .apply(
@@ -2632,12 +2642,17 @@ public class BigQueryIOWriteTest implements Serializable {
             .apply(
                 MapElements.into(TypeDescriptor.of(TableRow.class))
                     .via(BigQueryStorageApiInsertError::getRow));
-    PAssert.that(deadRows).containsInAnyOrder(badRows);
+
+    PAssert.that(deadRows)
+        .containsInAnyOrder(
+            Iterables.concat(badRows, Iterables.filter(goodRows, shouldFailRow::apply)));
     p.run();
 
     assertThat(
         fakeDatasetService.getAllRows("project-id", "dataset-id", "table"),
-        containsInAnyOrder(Iterables.toArray(goodRows, TableRow.class)));
+        containsInAnyOrder(
+            Iterables.toArray(
+                Iterables.filter(goodRows, r -> !shouldFailRow.apply(r)), TableRow.class)));
   }
 
   @Test
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryNestedRecordsIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryNestedRecordsIT.java
index 698ef660293..b85dc62c5fe 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryNestedRecordsIT.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryNestedRecordsIT.java
@@ -97,12 +97,13 @@ public class BigQueryNestedRecordsIT {
 
     TableRow queryUnflattened =
         bigQueryClient
-            .queryUnflattened(options.getInput(), bigQueryOptions.getProject(), true)
+            .queryUnflattened(options.getInput(), bigQueryOptions.getProject(), true, false)
             .get(0);
 
     TableRow queryUnflattenable =
         bigQueryClient
-            .queryUnflattened(options.getUnflattenableInput(), bigQueryOptions.getProject(), true)
+            .queryUnflattened(
+                options.getUnflattenableInput(), bigQueryOptions.getProject(), true, false)
             .get(0);
 
     // Verify that the results are the same.
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkFailedRowsIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkFailedRowsIT.java
new file mode 100644
index 00000000000..465bebbf138
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkFailedRowsIT.java
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigquery;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import com.google.api.services.bigquery.model.Table;
+import com.google.api.services.bigquery.model.TableFieldSchema;
+import com.google.api.services.bigquery.model.TableReference;
+import com.google.api.services.bigquery.model.TableRow;
+import com.google.api.services.bigquery.model.TableSchema;
+import java.io.IOException;
+import java.util.List;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
+import org.apache.beam.sdk.io.gcp.testing.BigqueryClient;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
+import org.hamcrest.Matchers;
+import org.joda.time.Duration;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Integration test for failed-rows handling when using the storage API. */
+@RunWith(Parameterized.class)
+public class StorageApiSinkFailedRowsIT {
+  @Parameterized.Parameters
+  public static Iterable<Object[]> data() {
+    return ImmutableList.of(
+        new Object[] {true, false, false},
+        new Object[] {false, true, false},
+        new Object[] {false, false, true},
+        new Object[] {true, false, true});
+  }
+
+  @Parameterized.Parameter(0)
+  public boolean useStreamingExactlyOnce;
+
+  @Parameterized.Parameter(1)
+  public boolean useAtLeastOnce;
+
+  @Parameterized.Parameter(2)
+  public boolean useBatch;
+
+  private static final Logger LOG = LoggerFactory.getLogger(StorageApiSinkFailedRowsIT.class);
+  private static final BigqueryClient BQ_CLIENT = new BigqueryClient("StorageApiSinkFailedRowsIT");
+  private static final String PROJECT =
+      TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject();
+  private static final String BIG_QUERY_DATASET_ID =
+      "storage_api_sink_failed_rows" + System.nanoTime();
+
+  private static final List<TableFieldSchema> FIELDS =
+      ImmutableList.<TableFieldSchema>builder()
+          .add(new TableFieldSchema().setType("STRING").setName("str"))
+          .add(new TableFieldSchema().setType("INT64").setName("i64"))
+          .add(new TableFieldSchema().setType("DATE").setName("date"))
+          .add(new TableFieldSchema().setType("STRING").setMaxLength(1L).setName("strone"))
+          .add(new TableFieldSchema().setType("BYTES").setName("bytes"))
+          .add(new TableFieldSchema().setType("JSON").setName("json"))
+          .add(
+              new TableFieldSchema()
+                  .setType("STRING")
+                  .setMaxLength(1L)
+                  .setMode("REPEATED")
+                  .setName("stronearray"))
+          .build();
+
+  private static final TableSchema BASE_TABLE_SCHEMA =
+      new TableSchema()
+          .setFields(
+              ImmutableList.<TableFieldSchema>builder()
+                  .addAll(FIELDS)
+                  .add(new TableFieldSchema().setType("STRUCT").setFields(FIELDS).setName("inner"))
+                  .build());
+
+  private static final byte[] BIG_BYTES = new byte[11 * 1024 * 1024];
+
+  private BigQueryIO.Write.Method getMethod() {
+    return useAtLeastOnce
+        ? BigQueryIO.Write.Method.STORAGE_API_AT_LEAST_ONCE
+        : BigQueryIO.Write.Method.STORAGE_WRITE_API;
+  }
+
+  @BeforeClass
+  public static void setUpTestEnvironment() throws IOException, InterruptedException {
+    // Create one BQ dataset for all test cases.
+    BQ_CLIENT.createNewDataset(PROJECT, BIG_QUERY_DATASET_ID);
+  }
+
+  @AfterClass
+  public static void cleanup() {
+    LOG.info("Start to clean up tables and datasets.");
+    BQ_CLIENT.deleteDataset(PROJECT, BIG_QUERY_DATASET_ID);
+  }
+
+  @Test
+  public void testSchemaMismatchCaughtByBeam() throws IOException, InterruptedException {
+    String tableSpec = createTable(BASE_TABLE_SCHEMA);
+    TableRow good1 = new TableRow().set("str", "foo").set("i64", "42");
+    TableRow good2 = new TableRow().set("str", "foo").set("i64", "43");
+    Iterable<TableRow> goodRows =
+        ImmutableList.of(
+            good1.clone().set("inner", new TableRow()),
+            good2.clone().set("inner", new TableRow()),
+            new TableRow().set("inner", good1),
+            new TableRow().set("inner", good2));
+
+    TableRow bad1 = new TableRow().set("str", "foo").set("i64", "baad");
+    TableRow bad2 = new TableRow().set("str", "foo").set("i64", "42").set("unknown", "foobar");
+    Iterable<TableRow> badRows =
+        ImmutableList.of(
+            bad1, bad2, new TableRow().set("inner", bad1), new TableRow().set("inner", bad2));
+
+    runPipeline(
+        getMethod(),
+        useStreamingExactlyOnce,
+        tableSpec,
+        Iterables.concat(goodRows, badRows),
+        badRows);
+    assertGoodRowsWritten(tableSpec, goodRows);
+  }
+
+  @Test
+  public void testInvalidRowCaughtByBigquery() throws IOException, InterruptedException {
+    String tableSpec = createTable(BASE_TABLE_SCHEMA);
+
+    TableRow good1 =
+        new TableRow()
+            .set("str", "foo")
+            .set("i64", "42")
+            .set("date", "2022-08-16")
+            .set("stronearray", Lists.newArrayList());
+    TableRow good2 =
+        new TableRow().set("str", "foo").set("i64", "43").set("stronearray", Lists.newArrayList());
+    Iterable<TableRow> goodRows =
+        ImmutableList.of(
+            good1.clone().set("inner", new TableRow().set("stronearray", Lists.newArrayList())),
+            good2.clone().set("inner", new TableRow().set("stronearray", Lists.newArrayList())),
+            new TableRow().set("inner", good1).set("stronearray", Lists.newArrayList()),
+            new TableRow().set("inner", good2).set("stronearray", Lists.newArrayList()));
+
+    TableRow bad1 = new TableRow().set("str", "foo").set("i64", "42").set("date", "10001-08-16");
+    TableRow bad2 = new TableRow().set("str", "foo").set("i64", "42").set("strone", "ab");
+    TableRow bad3 = new TableRow().set("str", "foo").set("i64", "42").set("json", "BAADF00D");
+    TableRow bad4 =
+        new TableRow()
+            .set("str", "foo")
+            .set("i64", "42")
+            .set("stronearray", Lists.newArrayList("toolong"));
+    TableRow bad5 = new TableRow().set("bytes", BIG_BYTES);
+    Iterable<TableRow> badRows =
+        ImmutableList.of(
+            bad1,
+            bad2,
+            bad3,
+            bad4,
+            bad5,
+            new TableRow().set("inner", bad1),
+            new TableRow().set("inner", bad2),
+            new TableRow().set("inner", bad3));
+
+    runPipeline(
+        getMethod(),
+        useStreamingExactlyOnce,
+        tableSpec,
+        Iterables.concat(goodRows, badRows),
+        badRows);
+    assertGoodRowsWritten(tableSpec, goodRows);
+  }
+
+  private static String createTable(TableSchema tableSchema)
+      throws IOException, InterruptedException {
+    String table = "table" + System.nanoTime();
+    BQ_CLIENT.deleteTable(PROJECT, BIG_QUERY_DATASET_ID, table);
+    BQ_CLIENT.createNewTable(
+        PROJECT,
+        BIG_QUERY_DATASET_ID,
+        new Table()
+            .setSchema(tableSchema)
+            .setTableReference(
+                new TableReference()
+                    .setTableId(table)
+                    .setDatasetId(BIG_QUERY_DATASET_ID)
+                    .setProjectId(PROJECT)));
+    return PROJECT + "." + BIG_QUERY_DATASET_ID + "." + table;
+  }
+
+  private void assertGoodRowsWritten(String tableSpec, Iterable<TableRow> goodRows)
+      throws IOException, InterruptedException {
+    TableRow queryResponse =
+        Iterables.getOnlyElement(
+            BQ_CLIENT.queryUnflattened(
+                String.format("SELECT COUNT(*) FROM %s", tableSpec), PROJECT, true, true));
+    int numRowsWritten = Integer.parseInt((String) queryResponse.get("f0_"));
+    if (useAtLeastOnce) {
+      assertThat(numRowsWritten, Matchers.greaterThanOrEqualTo(Iterables.size(goodRows)));
+    } else {
+      assertThat(numRowsWritten, Matchers.equalTo(Iterables.size(goodRows)));
+    }
+  }
+
+  private static void runPipeline(
+      BigQueryIO.Write.Method method,
+      boolean triggered,
+      String tableSpec,
+      Iterable<TableRow> tableRows,
+      Iterable<TableRow> expectedFailedRows) {
+    Pipeline p = Pipeline.create();
+
+    BigQueryIO.Write<TableRow> write =
+        BigQueryIO.writeTableRows()
+            .to(tableSpec)
+            .withSchema(BASE_TABLE_SCHEMA)
+            .withMethod(method)
+            .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_NEVER);
+    if (method == BigQueryIO.Write.Method.STORAGE_WRITE_API) {
+      write = write.withNumStorageWriteApiStreams(1);
+      if (triggered) {
+        write = write.withTriggeringFrequency(Duration.standardSeconds(1));
+      }
+    }
+    PCollection<TableRow> input = p.apply("Create test cases", Create.of(tableRows));
+    if (triggered) {
+      input = input.setIsBoundedInternal(PCollection.IsBounded.UNBOUNDED);
+    }
+    WriteResult result = input.apply("Write using Storage Write API", write);
+
+    PCollection<TableRow> failedRows =
+        result
+            .getFailedStorageApiInserts()
+            .apply(
+                MapElements.into(TypeDescriptor.of(TableRow.class))
+                    .via(BigQueryStorageApiInsertError::getRow));
+
+    PAssert.that(failedRows).containsInAnyOrder(expectedFailedRows);
+
+    p.run().waitUntilFinish();
+  }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProtoIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProtoIT.java
index b2d9e04ffe2..5f488da0210 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProtoIT.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProtoIT.java
@@ -337,7 +337,8 @@ public class TableRowToStorageApiProtoIT {
     runPipeline(tableSpec, Collections.singleton(BASE_TABLE_ROW));
 
     List<TableRow> actualTableRows =
-        BQ_CLIENT.queryUnflattened(String.format("SELECT * FROM [%s]", tableSpec), PROJECT, true);
+        BQ_CLIENT.queryUnflattened(
+            String.format("SELECT * FROM %s", tableSpec), PROJECT, true, true);
 
     assertEquals(1, actualTableRows.size());
     assertEquals(BASE_TABLE_ROW_EXPECTED, actualTableRows.get(0));
@@ -362,7 +363,8 @@ public class TableRowToStorageApiProtoIT {
     runPipeline(tableSpec, Collections.singleton(tableRow));
 
     List<TableRow> actualTableRows =
-        BQ_CLIENT.queryUnflattened(String.format("SELECT * FROM [%s]", tableSpec), PROJECT, true);
+        BQ_CLIENT.queryUnflattened(
+            String.format("SELECT * FROM %s", tableSpec), PROJECT, true, true);
 
     assertEquals(1, actualTableRows.size());
     assertEquals(BASE_TABLE_ROW_EXPECTED, actualTableRows.get(0).get("nestedValue1"));
@@ -391,7 +393,7 @@ public class TableRowToStorageApiProtoIT {
                     .setTableId(table)
                     .setDatasetId(BIG_QUERY_DATASET_ID)
                     .setProjectId(PROJECT)));
-    return PROJECT + ":" + BIG_QUERY_DATASET_ID + "." + table;
+    return PROJECT + "." + BIG_QUERY_DATASET_ID + "." + table;
   }
 
   private static void runPipeline(String tableSpec, Iterable<TableRow> tableRows) {